In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import time
import torch
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torchvision import transforms
from collections import defaultdict
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import numpy as np

from vae_model import CVAEDataset, VAE

# Prefer the GPU whenever torch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

ts = time.time()  # wall-clock start time; not read in the visible cells
# NOTE(review): unpickling can execute arbitrary code — only load this file
# from a trusted source.
new_data = pd.read_pickle("new_data_cvae.pkl")

In [57]:
# Forward-fill missing values. DataFrame.fillna(method=...) is deprecated since
# pandas 2.1; .ffill() is the supported, equivalent spelling.
# NOTE(review): this fills over the whole frame in row order and runs before the
# train/test split, so a gap can be filled from a neighbouring 'location' group's
# row — consider new_data.groupby('location').ffill() if rows are grouped that way.
new_data = new_data.ffill()

In [58]:
# Fixed random_state so the split (and every metric computed downstream) is
# reproducible under Restart Kernel -> Run All.
train, test = train_test_split(new_data, test_size=0.2, random_state=42)

In [59]:
# Wrap each split in the project's CVAE dataset (train is built first).
train_dataset, test_dataset = CVAEDataset(train), CVAEDataset(test)

In [60]:
batch_size = 128

# Both loaders share batch size and worker count; only the training split is
# shuffled so evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


In [61]:
def loss_fn(recon_x, x, mean, log_var, beta=1.0):
    """Per-sample VAE loss: summed squared reconstruction error plus weighted KL.

    Args:
        recon_x: Decoder reconstruction, same shape as ``x``.
        x: Input batch; its first dimension is the batch size.
        mean: Posterior means from the encoder.
        log_var: Posterior log-variances, same shape as ``mean``.
        beta: Weight on the KL term. The default ``1.0`` is the standard ELBO;
            other values give a beta-VAE style trade-off.

    Returns:
        Scalar tensor ``(sum_squared_error + beta * KLD) / batch_size``.
    """
    # Reconstruction term is MSE (summed), not binary cross-entropy.
    recon_loss = torch.nn.functional.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL(N(mean, exp(log_var)) || N(0, I)), summed over all elements.
    kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return (recon_loss + beta * kld) / x.size(0)

In [65]:
# Encoder widths are derived from the column count of `train`; the "- 1"
# presumably excludes the 'location' label column (TODO confirm vs CVAEDataset).
encoder_layer_sizes = [train.shape[1] - 1, train.shape[1] - 5, train.shape[1] // 2]
# The decoder mirrors the encoder widths in reverse order.
decoder_layer_sizes = encoder_layer_sizes[::-1]
latent_size = 7
conditional = 1  # truthy flag: condition the VAE on the label
num_condition = len(train['location'].unique())  # distinct 'location' values in train
learning_rate = 0.002
epochs = 1
print_every = 10  # log every N training iterations

In [66]:
vae = VAE(
    encoder_layer_sizes=encoder_layer_sizes,
    latent_size=latent_size,
    decoder_layer_sizes=decoder_layer_sizes,
    conditional=conditional,
    num_labels=num_condition if conditional else 0,
    num_condition=num_condition).to(device)

optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

# Per-iteration training loss history.
logs = defaultdict(list)

for epoch in range(epochs):

    for iteration, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)

        if conditional:
            recon_x, mean, log_var, z = vae(x, y)
        else:
            recon_x, mean, log_var, z = vae(x)

        loss = loss_fn(recon_x.float(), x.float(), mean.float(), log_var.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs['loss'].append(loss.item())

        if iteration % print_every == 0 or iteration == len(train_loader)-1:
            print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                epoch, epochs, iteration, len(train_loader)-1, loss.item()))

            # Smoke-test generation (one sample per condition). Switch to eval
            # mode and disable autograd so this pass neither updates BatchNorm
            # running statistics with generated data nor builds a graph; restore
            # train mode before the next optimisation step.
            vae.eval()
            with torch.no_grad():
                if conditional:
                    # Condition labels must live on the same device as the model.
                    c = torch.arange(0, num_condition, device=device).long().unsqueeze(1)
                    samples = vae.inference(n=c.size(0), c=c)
                else:
                    samples = vae.inference(n=num_condition)
            vae.train()

# Encode the whole test split and plot the first two latent dimensions,
# coloured by condition label.
vae.eval()  # BatchNorm must use running statistics, not test-batch statistics
latent_batches = []
label_batches = []
with torch.no_grad():
    for x, y in test_loader:
        test_x = x.to(device)
        test_y = y.to(device)
        # NOTE(review): assumes vae.encode(x, c) returns a (batch, latent_size)
        # tensor — confirm against vae_model.VAE.
        latent_batches.append(vae.encode(test_x, test_y))
        label_batches.append(y)

# Keep one latent vector per test row; averaging over dim 0 would collapse the
# whole split to a single vector and break the 2-D slice.
z = torch.cat(latent_batches, dim=0).cpu().numpy()
test_z = z[:, :2]
# Labels from every batch (not only the last) so colours align with points.
# NOTE(review): ravel assumes labels are index-shaped (N,) or (N, 1) — confirm.
test_labels = torch.cat(label_batches, dim=0).cpu().numpy().ravel()

fig, ax = plt.subplots()
points = ax.scatter(test_z[:, 0], test_z[:, 1], c=test_labels, cmap='tab10', alpha=0.3)
fig.colorbar(points, ax=ax, label='condition')
ax.set(title='Test-set latent space (first two dims)', xlabel='z[0]', ylabel='z[1]')
# fig.savefig('./plot/latent_space.png', format='png')
plt.show()

model: VAE(
  (encoder): Encoder(
    (MLP): Sequential(
      (L0): Linear(in_features=21, out_features=14, bias=True)
      (BN0): BatchNorm1d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (A0): ReLU()
      (L1): Linear(in_features=14, out_features=9, bias=True)
      (BN1): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (A1): ReLU()
    )
    (linear_means): Linear(in_features=9, out_features=7, bias=True)
    (linear_log_var): Linear(in_features=9, out_features=7, bias=True)
  )
  (decoder): Decoder(
    (MLP): Sequential(
      (L0): Linear(in_features=10, out_features=9, bias=True)
      (A0): ReLU()
      (BN0): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (L1): Linear(in_features=9, out_features=14, bias=True)
      (A1): ReLU()
      (BN1): BatchNorm1d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (L2): Linear(in_features=14, out_features=18, b

KeyboardInterrupt: 