# Variational Auto-Encoder

To build a new Variational Auto-Encoder, you need two networks:

* An encoder that will take as input an image and compute the parameters of list of Normal distributions
* A decoder that will take a sample from each Normal distribution and will output an image
For simplicity we will assume that:

* each network as a single hidden layer of size 100
* the latent space contains only 2 points

To understand exactly what a VAE is, you can:

* check the slides of Michèle Sebag
* check this tutorial: https://arxiv.org/abs/1606.05908

## 1.2. Encoder

* Compute an hidden representation:  𝑧=𝑟𝑒𝑙𝑢(𝑊1𝑥+𝑏1)
* Compute the means of the normal distributions:  𝑚𝑢=𝑊2𝑥+𝑏2 
* Compute the log variance of the normal distributions: 𝑙𝑜𝑔_𝑠𝑖𝑔𝑚𝑎_𝑠𝑞𝑢𝑎𝑟𝑒𝑑=𝑊3𝑥+𝑏3

In [18]:
import os
import numpy as np
import matplotlib.pyplot as plt
import math
import dataset_loader
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision import datasets, transforms

%matplotlib inline

In [19]:
BATCH_SIZE = 10 

# To transform images into Tensors
transforms2 = transforms.Compose([transforms.ToTensor()])

train_data = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transforms2)

test_data = datasets.MNIST(
    './data',
    train=False,
    download=True,
    transform=transforms2
)

# The code can run on gpu (or) cpu, we can use the gpu if available. 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)


In [20]:
BATCH_SIZE=10


# Download mnist dataset 
if("mnist.pkl.gz" not in os.listdir(".")):
    !wget http://deeplearning.net/data/mnist/mnist.pkl.gz
        
mnist_path = "./mnist.pkl.gz"

# load the 3 splits
train_data, dev_data, test_data = dataset_loader.load_mnist(mnist_path)



train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)


In [21]:
def getActivationFunction( activation ):    
    if activation == "ReLU":
        return nn.ReLU()
    if activation == "Tanh":
        return nn.Tanh()
    if activation == "Sigmoid":
        return nn.Sigmoid() 
    else:
        raise Exception('Wrong name for the activation function')

In [22]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, activation):
        super( Encoder, self ).__init__()
        
        self.activation = activation
        
        self.functions = OrderedDict()
        self.functions["affine_0"] = nn.Linear(input_dim,hidden_dim[0])
        self.functions["activation_0"] = getActivationFunction( activation )
        
        self.seq = nn.Sequential( self.functions )
        
        self.meanFunctions = OrderedDict()
        self.varFunctions = OrderedDict()
        
        self.meanFunctions["mean"] = nn.Linear( hidden_dim[0], latent_dim )
        self.meanSeq = nn.Sequential(  self.meanFunctions )
        
        self.varFunctions["var"] = nn.Linear( hidden_dim[0], latent_dim )  
        self.varSeq = nn.Sequential( self.varFunctions )
        
        self.init_parameters()
    
    def init_parameters(self):
        
        if self.activation == "Tanh":
            for idx, m in enumerate( self.meanSeq.modules() ):
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform_(m.weight.data)
                    
            for idx, m in enumerate( self.varSeq.modules() ):
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform_(m.weight.data)
            
    def forward(self,x):
        return self.meanSeq( self.seq(x) ), self.varSeq( self.seq(x) )
    
    
    

In [23]:
class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, activation):
        super( Decoder, self ).__init__()
        
        self.activation = activation
    
        self.functions = OrderedDict()
        self.functions["affine_0"] = nn.Linear(latent_dim,hidden_dim[0])
        self.functions["activation_0"] = getActivationFunction( activation )
        self.functions["output"] = nn.Linear( hidden_dim[0], input_dim )
        self.functions["softmax"] = nn.Softmax( dim=1 )
        
        
        self.seq = nn.Sequential(  self.functions )
        
        
        self.init_parameters()
    
    def init_parameters(self):
        
        if self.activation == "Tanh":
            for idx, m in enumerate( self.seq.modules() ):
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform_(m.weight.data)
                    
    def forward(self,x):
        return self.seq(x)
    
    
    
    

In [24]:
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super( VAE, self ).__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, x):
        z_mean, z_var = self.encoder(x)    
        
        std = torch.sqrt( torch.exp( z_var ) )
        eps = torch.normal(0, 1., z_mean.shape)
        
        x_sample = z_mean + eps * std
        
        predicted = self.decoder( x_sample )
        
        return predicted, z_mean, z_var

In [25]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(reconstruction_x, x, z_mean, z_var):
    BCE = F.binary_cross_entropy(reconstruction_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + z_var - z_mean.pow(2) - z_var.exp())

    return BCE + KLD


In [26]:
def training_loop(network, learning_rate=0.01):
    
    network.train()
    optimizer = optim.Adam(network.parameters(), lr=learning_rate)
    
    sumLoss = 0.0
    
    
    for i, (x, _) in enumerate(train_iterator):
        
        # reshape the data into [batch_size, 784]
        x = x.view(-1, 28 * 28)
        x = x.to(device)
                
        # zero the parameter gradients    
        optimizer.zero_grad()

        #Forward propagation
        x_sample, output_mean, output_var = network(x)


        #Compute loss, loss backward gradient and optimize parameters
        loss = loss_function( x_sample, batch_input, output_mean, output_var )
        loss.backward()

        sumLoss += loss.item()

        optimizer.step()

            
    return sumLoss


In [27]:
def test_loop(network):
    
    network.eval()
    
    
    testSize = test_data[0].shape[0]
    sumLoss = 0.0
    
    # we don't need to track the gradients, since we are not updating the parameters during evaluation / testing
    with torch.no_grad():
        
        for i, (x, _) in enumerate(test_iterator):
            
            # reshape the data
            x = x.view(-1, 28 * 28)
            x = x.to(device)
                
            #Forward propagation
            x_sample, output_mean, output_var = network(x)
            
           
            #Compute loss, loss backward gradient and optimize parameters
            loss = loss_function( x_sample, batch_input, output_mean, output_var )
            
            sumLoss += loss.item()

            
    return sumLoss


In [28]:
def executeVAE( input_dim, hidden_dim, latent_dim, activation, n_epochs = 50, batch_size = 1 ):

    encoder = Encoder(input_dim,hidden_dim,latent_dim, activation )
    decoder = Decoder(input_dim,hidden_dim,latent_dim, activation )

    vae = VAE( encoder, decoder )
    
    print(vae)
    
    best_test_loss = float('inf')

    for epoch in range(n_epochs):

        train_loss = training_loop( vae, learning_rate=0.01 )
        test_loss = test_loop( vae )

        train_loss /= len(train_dataset)
        test_loss /= len(test_dataset)

        print('>> epoch ', epoch)
        print('---> Train loss: >> ', train_loss)
        print('-----> Test loss: >> ', test_loss)
        print('')

        if best_test_loss > test_loss:
            best_test_loss = test_loss
            patience_counter = 1
        else:
            patience_counter += 1

        if patience_counter > 3:
            break
      
    return vae



In [30]:
vae = executeVAE(784,[100],2, "ReLU", 5, 10 )


VAE(
  (encoder): Encoder(
    (seq): Sequential(
      (affine_0): Linear(in_features=784, out_features=100, bias=True)
      (activation_0): ReLU()
    )
    (meanSeq): Sequential(
      (mean): Linear(in_features=100, out_features=2, bias=True)
    )
    (varSeq): Sequential(
      (var): Linear(in_features=100, out_features=2, bias=True)
    )
  )
  (decoder): Decoder(
    (seq): Sequential(
      (affine_0): Linear(in_features=2, out_features=100, bias=True)
      (activation_0): ReLU()
      (output): Linear(in_features=100, out_features=784, bias=True)
      (softmax): Softmax(dim=1)
    )
  )
)


RuntimeError: Expected object of scalar type Long but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'

In [31]:

e = torch.normal(0, 1., (10, 2))
images = vae.decoder(e).sigmoid()

for i in range(10):
    picture = images[i].clone().detach().numpy()
    plt.imshow(picture.reshape(28,28), cmap='Greys')
    plt.show()       


NameError: name 'vae' is not defined