# Project AI

Import the required libraries and project modules.

In [None]:
%pylab inline
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch import nn, optim
from VAE import *
from train import *
import numpy as np
from collections import *
import random

## Train the models (Gaussian, Gumbel/concrete, logit, logit rank-1, logit-sigmoidal)

In [None]:
# Hyperparameters shared by the training runs below.
latent_dim = 2
batch_size = 50
epochs = 50

# MNIST training split, images converted to tensors by ToTensor().
train_data = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches for training.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Gaussian-latent VAE. The trailing 1e-3 is presumably the learning rate; confirm in run_train.
(VAE_Gaussian, loss_Gaussian, z_Gaussian,
 KL_Gaussian, log_bern_Gaussian) = run_train(
    latent_dim, epochs, 'Gaussian', train_loader, 1e-3)

In [None]:
# Gumbel (concrete) latent VAE, same settings as the Gaussian run.
(VAE_gumbel, loss_gumbel, z_gumbel,
 KL_gumbel, log_bern_gumbel) = run_train(
    latent_dim, epochs, 'Gumbel', train_loader, 1e-3)

In [None]:
# Logit-latent VAE, same settings as the Gaussian run.
(VAE_logit, loss_logit, z_logit,
 KL_logit, log_bern_logit) = run_train(
    latent_dim, epochs, 'logit', train_loader, 1e-3)

In [None]:
# Seed every RNG that can influence this run so it is reproducible.
# random.seed alone only seeds Python's `random` module; PyTorch (weight
# init, DataLoader shuffling) and NumPy keep their own generators.
random.seed(1000)
np.random.seed(1000)
torch.manual_seed(1000)

# Logit model; the trailing True presumably enables the rank-1 covariance
# approximation (see run_train) -- TODO confirm.
VAE_logit_rank1, loss_logit_rank1, z_logit_rank1, KL_logit_rank1, log_bern_rank1 = run_train(latent_dim, epochs, 'logit', train_loader, 1e-3, True)

In [None]:
# Logit-sigmoidal model.
# NOTE(review): the first argument is hard-coded to 3, whereas the other runs
# pass `latent_dim`. Confirm this is intentional.
(VAE_logit_sigmoid, loss_logit_sigmoid, z_logit_sigmoid,
 KL_logit_sigmoid, log_bern_sigmoid) = run_train(
    3, epochs, 'logit-sigmoidal', train_loader, 1e-3)

## Plot of losses

In [None]:
# One x-value per epoch, derived from the `epochs` hyperparameter instead of
# a hard-coded 50. Assumes run_train returns one loss value per epoch.
epoch_space = np.arange(1, epochs + 1)
plt.plot(epoch_space, loss_Gaussian, label='Gaussian')
plt.plot(epoch_space, loss_logit, label='logit')
plt.plot(epoch_space, loss_logit_rank1, label='logit rank1')
plt.plot(epoch_space, loss_logit_sigmoid, label='logit sigmoidal')
plt.plot(epoch_space, loss_gumbel, label='concrete')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

## 2-D scatterplot

In [None]:
def find_latent_coordinates(train_loader_scatter_plot, model, max_samples=10000):
    """Collect the first two latent components of samples, grouped by label.

    Runs ``model`` on each ``(data, label)`` batch and records components 0
    and 1 of the second element of the model's output tuple for every sample
    in the batch. Any batch size is supported.

    Args:
        train_loader_scatter_plot: iterable of ``(data, label)`` batches,
            e.g. a ``torch.utils.data.DataLoader``.
        model: callable returning a 3-tuple. Its second element is assumed
            to be a ``(batch, latent_dim)`` tensor with ``latent_dim >= 2``
            (TODO confirm for every model type).
        max_samples: maximum number of samples to collect; ``None`` means
            use the whole loader.

    Returns:
        ``(x_coordinates, y_coordinates)``: two ``defaultdict(list)``s that
        map each integer label to the list of first / second latent
        components of the samples with that label.
    """
    x_coordinates = defaultdict(list)
    y_coordinates = defaultdict(list)
    seen = 0

    # Inference only: avoid building an autograd graph for every forward pass.
    with torch.no_grad():
        for data, label in train_loader_scatter_plot:
            if max_samples is not None and seen >= max_samples:
                break
            _, z, _ = model(data)
            points = z.detach().cpu().numpy()
            digits = label.detach().cpu().numpy()
            if max_samples is not None:
                # Trim the final batch so exactly max_samples points are kept.
                points = points[:max_samples - seen]
                digits = digits[:max_samples - seen]
            for point, digit in zip(points, digits):
                x_coordinates[int(digit)].append(float(point[0]))
                y_coordinates[int(digit)].append(float(point[1]))
            seen += len(digits)

    return x_coordinates, y_coordinates

In [None]:
# Loader for the latent scatter plots: the already-loaded MNIST training set
# (no need to load an identical copy again), read one image at a time in
# dataset order.
train_data_scatter_plot = train_data
train_loader_scatter_plot = torch.utils.data.DataLoader(
    train_data_scatter_plot, batch_size=1, shuffle=False)

In [None]:
x_latent_space_Gumbel, y_latent_space_Gumbel = find_latent_coordinates(train_loader_scatter_plot, VAE_gumbel)
# One colour per class label, in sorted order, with a legend so colours can
# be mapped back to digits.
for label in sorted(x_latent_space_Gumbel):
    plt.scatter(x_latent_space_Gumbel[label], y_latent_space_Gumbel[label],
                marker='.', label=str(label))
plt.title('Latent space (Gumbel / concrete)')
plt.legend()
plt.show()

In [None]:
x_latent_space_Gaussian, y_latent_space_Gaussian = find_latent_coordinates(train_loader_scatter_plot, VAE_Gaussian)
# One colour per class label, in sorted order, with a legend so colours can
# be mapped back to digits.
for label in sorted(x_latent_space_Gaussian):
    plt.scatter(x_latent_space_Gaussian[label], y_latent_space_Gaussian[label],
                marker='.', label=str(label))
plt.title('Latent space (Gaussian)')
plt.legend()
plt.show()

In [None]:
x_latent_space_logit, y_latent_space_logit = find_latent_coordinates(train_loader_scatter_plot, VAE_logit)
# One colour per class label, in sorted order, with a legend so colours can
# be mapped back to digits.
for label in sorted(x_latent_space_logit):
    plt.scatter(x_latent_space_logit[label], y_latent_space_logit[label],
                marker='.', label=str(label))
plt.title('Latent space (logit)')
plt.legend()
plt.show()

In [None]:
# Logit-sigmoidal model. Its own variable names keep it from overwriting the
# logit model's coordinates. It was trained via run_train(3, ...), presumably
# with a 3-D latent, so only the first two components are plotted.
x_latent_space_logit_sigmoid, y_latent_space_logit_sigmoid = find_latent_coordinates(
    train_loader_scatter_plot, VAE_logit_sigmoid)
for label in sorted(x_latent_space_logit_sigmoid):
    plt.scatter(x_latent_space_logit_sigmoid[label], y_latent_space_logit_sigmoid[label],
                marker='.', label=str(label))
plt.title('Latent space (logit-sigmoidal)')
plt.legend()
plt.show()

In [None]:
# Logit rank-1 model. Its own variable names keep it from overwriting the
# logit model's coordinates.
x_latent_space_logit_rank1, y_latent_space_logit_rank1 = find_latent_coordinates(
    train_loader_scatter_plot, VAE_logit_rank1)
for label in sorted(x_latent_space_logit_rank1):
    plt.scatter(x_latent_space_logit_rank1[label], y_latent_space_logit_rank1[label],
                marker='.', label=str(label))
plt.title('Latent space (logit rank-1)')
plt.legend()
plt.show()

## Reconstructions from the Gaussian model

In [None]:
### Let's check if the reconstructions make sense
# Set model to evaluation mode (this persists for later cells).
VAE_Gaussian.eval()

# Reuse the already-loaded MNIST training set, read one image at a time in
# dataset order.
train_data_plot = train_data
train_loader_plot = torch.utils.data.DataLoader(
    train_data_plot, batch_size=1, shuffle=False)

n_reconstructions = 4
# Inference only: no autograd graph needed.
with torch.no_grad():
    for batch_idx, (data, label) in enumerate(train_loader_plot):
        if batch_idx >= n_reconstructions:
            break
        x_hat, _, _ = VAE_Gaussian(data)
        # .cpu() so this also works when the model runs on a GPU.
        plt.imshow(x_hat.view(1, 28, 28).squeeze().detach().cpu().numpy(), cmap='gray')
        # Take the label from the same batch rather than indexing a separate
        # dataset's (deprecated) `train_labels` attribute.
        plt.title('%i' % label.item())
        plt.show()

## Histograms of z

In [None]:
def plot_histogram(title, z, bins=10):
    """Plot one step-histogram per latent dimension of a collection of samples.

    Args:
        title: base title shown on every figure; the dimension index is
            appended.
        z: sequence of samples. Each ``z[j]`` must support ``size(0)`` and
            integer indexing (presumably a 1-D torch tensor of length
            latent_dim -- confirm against run_train).
        bins: number of histogram bins (defaults to the previous 10).

    Shows one figure per latent dimension. An empty ``z`` plots nothing.
    """
    if len(z) == 0:
        return
    n_dims = z[0].size(0)
    for dim in range(n_dims):
        # Gather this dimension's value across every sample.
        values = [float(sample[dim]) for sample in z]
        plt.hist(values, bins=bins, histtype='step')
        plt.title('%s (dimension %d)' % (title, dim))
        plt.show()

In [None]:

# Per-dimension histograms of the Gaussian model's z samples.
plot_histogram(title='Histogram for Gaussian Distribution', z=z_Gaussian)


In [None]:

# Per-dimension histograms of the Gumbel model's z samples.
plot_histogram(title='Histogram for the Gumbel Distribution', z=z_gumbel)


In [None]:

# Per-dimension histograms of the logit model's z samples.
plot_histogram(title='Histogram for logit Distribution', z=z_logit)


In [None]:

# Per-dimension histograms of the logit rank-1 model's z samples.
plot_histogram(
    title='Histogram for the Logit Distribution for Rank1 Covariance Approximation',
    z=z_logit_rank1,
)


In [None]:

# Per-dimension histograms of the logit-sigmoidal model's z samples.
plot_histogram(title='Histogram for logit_sigmoid', z=z_logit_sigmoid)


In [None]:
# Mean of each model's KL values as returned by run_train, one line per model.
_kl_reports = (
    ('KL mean for Gauss = ', KL_Gaussian),
    ('KL mean for Gumbel = ', KL_gumbel),
    ('KL mean for logit = ', KL_logit),
    ('KL mean for logit_rank1 = ', KL_logit_rank1),
    ('KL mean for logit_sigmoid = ', KL_logit_sigmoid),
)
for _caption, _kl_history in _kl_reports:
    print(_caption, torch.stack(_kl_history).mean().data.cpu().numpy())

In [None]:
# Mean of each model's log_bern values as returned by run_train, one line per model.
_log_bern_reports = (
    ('log_bern mean for Gaussian = ', log_bern_Gaussian),
    ('log_bern mean for Gumbel = ', log_bern_gumbel),
    ('log_bern mean for Logit = ', log_bern_logit),
    ('log_bern mean for Logit_rank1 = ', log_bern_rank1),
    ('log_bern mean for Logit_sigmoid = ', log_bern_sigmoid),
)
for _caption, _log_bern_history in _log_bern_reports:
    print(_caption, torch.stack(_log_bern_history).mean().data.cpu().numpy())