In [None]:
### Imports ###

import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
from torch.utils.data import DataLoader
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from tqdm.notebook import tqdm
import time

In [None]:
### Define encoder class ###

class VariationalEncoder(nn.Module):
    """MLP encoder producing a sampled latent code and its KL penalty.

    Three ReLU hidden layers narrow the input from ``input_shape`` to
    ``int(input_shape/4)`` units. Two parallel linear heads then emit the mean
    and the log standard deviation of a diagonal Gaussian posterior q(z|x).

    Args:
        input_shape: Number of input features (int); the last dimension of the
            tensor passed to ``forward`` must have this size.
        latent_dims: Size of the latent code.

    Attributes:
        kl: KL(q(z|x) || N(0, I)) of the most recent ``forward`` call, averaged
            over every batch and latent element. It is ``0`` until ``forward``
            has been called once.
    """

    def __init__(self, input_shape, latent_dims):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(input_shape, int(input_shape/2))
        self.linear2 = nn.Linear(int(input_shape/2), int(input_shape/3))
        self.linear3 = nn.Linear(int(input_shape/3), int(input_shape/4))
        self.linear4 = nn.Linear(int(input_shape/4), latent_dims) #mu
        self.linear5 = nn.Linear(int(input_shape/4), latent_dims) #logstd

        # Kept so existing code that reads `encoder.N` keeps working. Sampling
        # no longer goes through it: a plain Normal(0, 1) lives on the CPU and
        # would not follow the module when it is moved with `.to(device)`.
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        """Encode ``x`` and return a reparameterised latent sample.

        Side effect: overwrites ``self.kl`` with the KL term for this batch.

        Args:
            x: Tensor whose last dimension has size ``input_shape``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``latent_dims``: ``mu + sigma * eps`` with
            ``eps ~ N(0, 1)``.
        """
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        mu = self.linear4(x)
        log_sigma = self.linear5(x)
        sigma = torch.exp(log_sigma)

        # Draw the noise with mu's shape, dtype and device so the
        # reparameterisation trick also works when the model is on the GPU.
        eps = torch.randn_like(mu)
        z = mu + sigma * eps

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element:
        #   0.5 * (sigma^2 + mu^2) - log(sigma) - 0.5
        # which is exactly zero when mu = 0 and sigma = 1. log(sigma) is taken
        # straight from the head output instead of log(exp(.)).
        self.kl = (0.5 * (sigma**2 + mu**2) - log_sigma - 0.5).mean()
        return z
    
### Define conditional decoder class ###
class Decoder(nn.Module):
    """Conditional MLP decoder: latent code + targets -> reconstruction.

    The layer widths grow ``int(input_shape/4) -> int(input_shape/3) ->
    int(input_shape/2) -> input_shape``. The conditioning targets are
    concatenated onto the input of every layer, so each layer's input width
    is the previous width plus ``target_shape``.

    Args:
        input_shape: Number of features to reconstruct (int).
        target_shape: Number of conditioning target columns.
        latent_dims: Size of the latent code.
    """

    def __init__(self, input_shape, target_shape, latent_dims):
        super().__init__()
        quarter, third, half = (int(input_shape / d) for d in (4, 3, 2))
        self.linear1 = nn.Linear(latent_dims + target_shape, quarter)
        self.linear2 = nn.Linear(quarter + target_shape, third)
        self.linear3 = nn.Linear(third + target_shape, half)
        self.linear4 = nn.Linear(half + target_shape, input_shape)

    def forward(self, z, targets):
        """Decode ``z`` conditioned on ``targets``.

        Returns:
            The sigmoid-activated reconstruction with ``targets`` appended
            along dimension 1 (width ``input_shape + target_shape``).
        """
        hidden = z
        # Re-inject the targets before each of the three ReLU layers.
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = F.relu(layer(torch.cat((hidden, targets), dim=1)))

        # The output layer squashes the reconstruction into (0, 1).
        recon = torch.sigmoid(self.linear4(torch.cat((hidden, targets), dim=1)))

        return torch.cat((recon, targets), dim=1)
    
### CVAE class ###
class CondVariationalAutoencoder(nn.Module):
    """Conditional VAE: a ``VariationalEncoder`` feeding a ``Decoder``.

    Only the decoder receives the conditioning targets; the encoder sees the
    input features alone.

    Args:
        input_shape: Number of input features.
        target_shape: Number of conditioning target columns.
        latent_dims: Size of the latent code.
    """

    def __init__(self, input_shape, target_shape, latent_dims):
        super().__init__()
        self.encoder = VariationalEncoder(input_shape, latent_dims)
        self.decoder = Decoder(input_shape, target_shape, latent_dims)

    def forward(self, x, targets):
        """Encode ``x`` to a latent sample, then decode it conditioned on ``targets``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent, targets)
        return reconstruction
    
### Function for KL annealing ###
def anneal_schedule(epoch):
    """Return the KL weight for ``epoch``: ``epoch * 0.1``, capped at 1."""
    weight = epoch * 0.1
    if weight > 1:
        return 1
    return weight

In [None]:
### Training function ###

def train(autoencoder, data, epochs,
          checkpoint_path="/Users/dylansmith/Desktop/CS274E/Project/saved_models/cvae_epoch%s.pt",
          checkpoint_every=50):
    """Train a conditional VAE with Adam and periodically checkpoint it.

    The per-batch loss is the mean squared reconstruction error plus the
    encoder's KL term scaled by ``anneal_schedule``. The schedule is fed the
    number of epochs already completed, so the first epoch uses the schedule's
    value at 0.

    Besides its arguments, this function reads the notebook globals ``device``,
    ``input_shape`` and ``target_shape`` and calls ``anneal_schedule``; all of
    them must be defined before ``train`` is called.

    Args:
        autoencoder: Model called as ``autoencoder(inp, tar)`` that exposes
            ``autoencoder.encoder.kl`` after each forward pass.
        data: Iterable of batches. Each batch is a 2-D tensor whose first
            ``input_shape`` columns are features and whose last
            ``target_shape`` columns are targets.
        epochs: Iterable with one item per epoch; only its length matters.
        checkpoint_path: ``%``-format template for the checkpoint file; it
            receives the 1-based number of completed epochs.
        checkpoint_every: Print progress and save a checkpoint every this many
            epochs. ``0`` or ``None`` disables both.

    Returns:
        The same ``autoencoder`` object, trained in place.
    """
    # Per-batch history stored as plain Python floats. Appending the loss
    # tensors themselves would keep every batch's autograd graph alive for the
    # whole run and pickle those tensors into each checkpoint.
    recon_list = []
    kl_list = []

    opt = torch.optim.Adam(autoencoder.parameters())
    epoch_counter = 0

    # Defaults so the progress print still works if `data` yields no batches.
    loss = float("nan")
    rsq = float("nan")

    for _ in epochs:
        # The KL weight only changes between epochs, so compute it once here.
        kl_weight = anneal_schedule(epoch_counter)

        for batch in tqdm(data, desc="Epoch Progress"):
            x = batch.to(device)
            opt.zero_grad()

            inp = x[:, :input_shape]
            tar = x[:, -target_shape:]

            x_hat = autoencoder(inp, tar) #Reconstructed samples (with targets)

            # The decoder appends the targets to its output; compare only the
            # leading feature columns against the input.
            mse = ((inp - x_hat[:, :inp.shape[1]])**2).mean()
            kl = kl_weight * autoencoder.encoder.kl
            loss = mse + kl

            loss.backward()
            opt.step()

            # R^2 is a monitoring metric only; keep it out of the autograd graph.
            with torch.no_grad():
                rsq = 1 - mse / torch.var(inp)

            recon_list.append(float(mse))
            kl_list.append(float(kl))

        epoch_counter += 1
        if checkpoint_every and epoch_counter % checkpoint_every == 0:
            # Reports the loss and R^2 of the final batch of this epoch.
            print("Epoch %s " % (epoch_counter), "Loss: ", float(loss), "R^2: ", float(rsq))

            PATH = checkpoint_path % (epoch_counter)
            torch.save({
                'epoch': epoch_counter,
                'model_state_dict': autoencoder.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'recon_loss': recon_list,
                'kl_loss': kl_list
                }, PATH)
    return autoencoder

In [None]:
### Load data (the validation split is used here for debugging) ###

path_input = '/Users/dylansmith/Desktop/CS274E/Project/val_input.npy'
path_target = '/Users/dylansmith/Desktop/CS274E/Project/val_target.npy'

data_input = torch.Tensor(np.load(path_input))
data_target = torch.Tensor(np.load(path_target))

# Each row becomes [features | targets]; train() slices the two parts apart again.
data = torch.cat((data_input, data_target), dim=1)

train_dataloader = DataLoader(data, shuffle=True, batch_size=32)

In [None]:
latent_dims = 6
# Feature and target widths are read off the loaded tensors (columns per row).
input_shape = data_input.shape[1]
target_shape = data_target.shape[1]
print("Number of data features: ", input_shape)
print("Number of data targets: ", target_shape)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device: ", device)

# Build the model, then move its parameters to the selected device.
cvae = CondVariationalAutoencoder(input_shape, target_shape, latent_dims)
cvae = cvae.to(device)

In [None]:
### Training loop ###

# Outer progress bar over epochs 1..epoch_tot; train() iterates it directly.
epoch_tot = 400
epochs = tqdm(range(1, 1 + epoch_tot), desc="Epochs")

# train() returns the module it was given, so `vae` refers to the same object as `cvae`.
vae = train(autoencoder=cvae, data=train_dataloader, epochs=epochs)

In [None]:
### Load model for further evaluation/testing ###

PATH = "/Users/dylansmith/Desktop/CS274E/Project/saved_models/cvae_epoch400.pt"
checkpoint = torch.load(PATH, map_location=device)

# Restore the model weights first, then a fresh Adam with its saved state.
cvae.load_state_dict(checkpoint['model_state_dict'])
opt = torch.optim.Adam(cvae.parameters())
opt.load_state_dict(checkpoint['optimizer_state_dict'])

# Epoch counter and per-batch loss histories saved by train().
curr_epoch, mse_loss, kl_loss = (
    checkpoint[key] for key in ('epoch', 'recon_loss', 'kl_loss')
)