# Variational Auto-Encoder

To build a new Variational Auto-Encoder, you need two networks:

* An encoder that takes an image as input and computes the parameters of a list of Normal distributions
* A decoder that takes a sample from each Normal distribution and outputs an image

For simplicity we will assume that:

* each network has a single hidden layer of size 100
* the latent space has only 2 dimensions

(The code cells below actually use a hidden layer of size 400 and a 20-dimensional latent space.)

To understand exactly what a VAE is, you can:

* check the slides of Michèle Sebag
* check this tutorial: https://arxiv.org/abs/1606.05908

## 1.2. Encoder

* Compute a hidden representation: $z = \mathrm{relu}(W_1 x + b_1)$
* Compute the means of the normal distributions from it: $\mu = W_2 z + b_2$
* Compute the log variance of the normal distributions from it: $\log \sigma^2 = W_3 z + b_3$

In [18]:
import os
import numpy as np
import matplotlib.pyplot as plt
import math
import dataset_loader
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

%matplotlib inline

In [19]:
# Download the MNIST pickle if it is not already present, then load its three splits.
import urllib.request

mnist_path = "./mnist.pkl.gz"
# NOTE(review): confirm this URL is still reachable; mirror the file locally if not.
mnist_url = "http://deeplearning.net/data/mnist/mnist.pkl.gz"

# Check the same path that is loaded below (instead of a separately spelled literal),
# and download with the standard library so the cell also runs outside IPython (`!wget`).
if not os.path.exists(mnist_path):
    urllib.request.urlretrieve(mnist_url, mnist_path)

# load the 3 splits
train_data, dev_data, test_data = dataset_loader.load_mnist(mnist_path)

In [20]:
def getActivationFunction( activation ):
    """Return a new activation module for the given name.

    Args:
        activation: one of "ReLU", "Tanh" or "Sigmoid" (case-sensitive).

    Returns:
        A freshly constructed ``nn.Module`` implementing that activation.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
            ValueError is a subclass of Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    factories = {
        "ReLU": nn.ReLU,
        "Tanh": nn.Tanh,
        "Sigmoid": nn.Sigmoid,
    }
    try:
        factory = factories[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list), which are equally invalid names.
        raise ValueError(
            'Wrong name for the activation function: {!r}'.format(activation)
        ) from None
    # Build a new instance on every call so modules are never shared between layers.
    return factory()

In [21]:
class Encoder(nn.Module):
    """VAE encoder: maps an input vector to the parameters of a diagonal Gaussian.

    ``seq`` is Linear(input_dim -> hidden) -> activation [-> Dropout]. Two parallel
    linear heads read the resulting hidden representation:
    ``meanSeq`` outputs the mean of each latent Normal, and ``varSeq`` outputs
    its log-variance log(sigma^2), not the variance itself.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, activation, dropout):
        """Build the encoder layers.

        Args:
            input_dim: size of each flattened input vector.
            hidden_dim: hidden layer size, either a plain int or a sequence whose
                first element is used (only one hidden layer is built).
            latent_dim: number of latent Normal distributions.
            activation: name passed to getActivationFunction ("ReLU", "Tanh", "Sigmoid").
            dropout: dropout probability applied after the hidden activation; a falsy
                value (0, 0.0, None) adds no Dropout layer.
        """
        super( Encoder, self ).__init__()

        self.activation = activation

        # Accept both the original list form (``[400]``) and a plain int (``400``).
        try:
            hidden_size = hidden_dim[0]
        except TypeError:
            hidden_size = hidden_dim

        self.functions = OrderedDict()
        self.functions["affine_0"] = nn.Linear(input_dim, hidden_size)
        self.functions["activation_0"] = getActivationFunction( activation )
        if dropout:
            self.functions["dropout_0"] = nn.Dropout(dropout)

        self.seq = nn.Sequential( self.functions )

        self.meanFunctions = OrderedDict()
        self.varFunctions = OrderedDict()

        # Module names "mean"/"var" are kept unchanged for state_dict compatibility;
        # the "var" head produces the log-variance.
        self.meanFunctions["mean"] = nn.Linear( hidden_size, latent_dim )
        self.meanSeq = nn.Sequential( self.meanFunctions )

        self.varFunctions["var"] = nn.Linear( hidden_size, latent_dim )
        self.varSeq = nn.Sequential( self.varFunctions )

        self.init_parameters()

    def init_parameters(self):
        """Apply Xavier-uniform init to every linear layer when the activation is Tanh.

        This includes the hidden layer whose output goes through the tanh, as well
        as the two output heads.
        """
        if self.activation == "Tanh":
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)

    def forward(self, x):
        """Return ``(mean, log_variance)`` for the batch ``x``."""
        # Compute the hidden representation once and feed it to both heads
        # (with dropout enabled, two passes would also give inconsistent samples).
        hidden = self.seq(x)
        return self.meanSeq(hidden), self.varSeq(hidden)
    
    
    

In [22]:
# Instantiate a sample encoder (784 inputs, one 400-unit hidden layer, 20 latent dims)
# and display its module structure.
network = Encoder(
    input_dim=784,
    hidden_dim=[400],
    latent_dim=20,
    activation="ReLU",
    dropout=0.0,
)
print(network)

Encoder(
  (seq): Sequential(
    (affine_0): Linear(in_features=784, out_features=400, bias=True)
    (activation_0): ReLU()
  )
  (meanSeq): Sequential(
    (mean): Linear(in_features=400, out_features=20, bias=True)
  )
  (varSeq): Sequential(
    (var): Linear(in_features=400, out_features=20, bias=True)
  )
)


In [23]:
class Decoder(nn.Module):
    """VAE decoder: maps a latent sample to a vector of per-feature probabilities.

    ``seq`` is Linear(latent_dim -> hidden) -> activation [-> Dropout]
    -> Linear(hidden -> input_dim) -> Sigmoid.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, activation, dropout):
        """Build the decoder layers.

        Args:
            input_dim: size of the reconstructed (flattened) output vector.
            hidden_dim: hidden layer size, either a plain int or a sequence whose
                first element is used (only one hidden layer is built).
            latent_dim: size of the latent sample given to ``forward``.
            activation: name passed to getActivationFunction ("ReLU", "Tanh", "Sigmoid").
            dropout: dropout probability applied after the hidden activation; a falsy
                value (0, 0.0, None) adds no Dropout layer.
        """
        # Must name this class (not Encoder): super(Encoder, self) raises TypeError here.
        super( Decoder, self ).__init__()

        self.activation = activation

        # Accept both the list form (``[400]``) and a plain int (``400``).
        try:
            hidden_size = hidden_dim[0]
        except TypeError:
            hidden_size = hidden_dim

        self.functions = OrderedDict()
        self.functions["affine_0"] = nn.Linear(latent_dim, hidden_size)
        self.functions["activation_0"] = getActivationFunction( activation )
        if dropout:
            self.functions["dropout_0"] = nn.Dropout(dropout)
        self.functions["output"] = nn.Linear( hidden_size, input_dim )
        # Sigmoid, not Softmax: each output is an independent probability in [0, 1],
        # as required by the per-element binary_cross_entropy reconstruction loss.
        # Softmax would force all outputs of a row to sum to 1.
        self.functions["sigmoid"] = nn.Sigmoid()

        # Wrapping in nn.Sequential registers the layers (and their parameters) on
        # this module and provides the ``seq`` used by forward().
        self.seq = nn.Sequential( self.functions )

        self.init_parameters()

    def init_parameters(self):
        """Apply Xavier-uniform init to every linear layer when the activation is Tanh."""
        if self.activation == "Tanh":
            for module in self.seq.modules():
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)

    def forward(self,x):
        """Decode latent samples ``x`` into per-feature probabilities."""
        return self.seq(x)
    
    
    
    

In [24]:
class VAE(nn.Module):
    """Variational auto-encoder: encoder -> reparameterized latent sample -> decoder."""

    def __init__(self, encoder, decoder):
        """
        Args:
            encoder: module mapping ``x`` to ``(z_mean, z_log_var)``.
            decoder: module mapping a latent sample to a reconstruction of ``x``.
        """
        # nn.Module.__init__ must run before submodules can be assigned as attributes.
        super(VAE, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Encode ``x``, draw one latent sample per row and decode it.

        Returns:
            ``(predicted, z_mean, z_var)``. ``z_var`` is the encoder's log-variance:
            it is turned into a standard deviation via ``exp(z_var / 2)``. The name
            is kept for compatibility.
        """
        z_mean, z_var = self.encoder(x)

        # Reparameterization trick: z = mu + sigma * eps with sigma = exp(log_var / 2).
        # The randomness lives in eps, so gradients flow through mu and log_var.
        std = torch.exp(z_var / 2)
        eps = torch.randn_like(std)
        z_sample = z_mean + eps * std

        predicted = self.decoder(z_sample)

        return predicted, z_mean, z_var

In [None]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(reconstruction_x, x, z_mean, z_var):
    """Return the VAE loss (negative ELBO): reconstruction BCE plus KL divergence.

    Both terms are summed over all elements and over the batch.

    Args:
        reconstruction_x: decoder output; probabilities in [0, 1], as required by
            ``F.binary_cross_entropy``.
        x: target batch; reshaped to ``reconstruction_x``'s shape (e.g. (B, 784)).
        z_mean: mean of the approximate posterior q(z|x).
        z_var: log-variance log(sigma^2) of q(z|x), not the variance itself.

    Returns:
        A scalar tensor.
    """
    # Match the target to the reconstruction's shape instead of hard-coding 784.
    BCE = F.binary_cross_entropy(
        reconstruction_x, x.reshape(reconstruction_x.shape), reduction='sum'
    )
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with z_var = log(sigma^2).
    KLD = -0.5 * torch.sum(1 + z_var - z_mean.pow(2) - z_var.exp())

    return BCE + KLD




In [None]:
def training_loop(nertwork, criterion, train_data, dev_data, test_data, n_epochs, batch_size, leatning_rate=0.01, optimizer=None):
    """Train a VAE-style network with shuffled mini-batches.

    The (misspelled) parameter names ``nertwork`` and ``leatning_rate`` are kept
    so existing keyword callers continue to work.

    Args:
        nertwork: module whose forward returns ``(reconstruction, z_mean, z_log_var)``.
        criterion: callable ``criterion(reconstruction, batch_input, z_mean, z_log_var)``
            returning a scalar loss tensor (e.g. ``loss_function``).
        train_data: pair ``(images, labels)`` where ``images`` is a numpy array whose
            first axis indexes examples; labels are not used.
        dev_data: accepted for interface compatibility; not used.
        test_data: accepted for interface compatibility; not used.
        n_epochs: number of passes over the training set.
        batch_size: examples per mini-batch (the last batch may be smaller).
        leatning_rate: learning rate of the Adam optimizer built when ``optimizer`` is None.
        optimizer: optional pre-built optimizer over ``nertwork``'s parameters.

    Returns:
        list of float: for each epoch, the sum of the mini-batch losses.
    """
    network = nertwork
    if optimizer is None:
        optimizer = optim.Adam(network.parameters(), lr=leatning_rate)

    images = train_data[0]
    training_size = images.shape[0]

    network.train()

    epoch_losses = []
    for epoch in range(n_epochs):

        print('>> epoch ', epoch)

        # Visit the training examples in a fresh random order each epoch.
        indexes = np.random.permutation(training_size)

        sum_loss = 0.0
        for first in range(0, training_size, batch_size):
            batch_indexes = indexes[first:first + batch_size]
            # One flattened row per example: shape (batch, features).
            batch_input = torch.from_numpy(images[batch_indexes]).reshape(len(batch_indexes), -1)

            # zero the parameter gradients
            optimizer.zero_grad()

            # Forward propagation
            reconstruction, z_mean, z_log_var = network(batch_input)

            # Compute loss, backpropagate and update the parameters
            loss = criterion(reconstruction, batch_input, z_mean, z_log_var)
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()

        print('   loss: ', sum_loss)
        epoch_losses.append(sum_loss)

    return epoch_losses

In [None]:
def test_loop(nertwork, criterion, test_data, batch_size, leatning_rate=0.01):
    """Evaluate a VAE-style network on ``test_data`` without updating its parameters.

    The (misspelled) parameter names ``nertwork`` and ``leatning_rate`` are kept
    so existing keyword callers continue to work.

    Args:
        nertwork: module whose forward returns ``(reconstruction, z_mean, z_log_var)``.
            It is switched to eval mode and left there.
        criterion: callable ``criterion(reconstruction, batch_input, z_mean, z_log_var)``
            returning a scalar loss tensor (e.g. ``loss_function``).
        test_data: pair ``(images, labels)`` where ``images`` is a numpy array whose
            first axis indexes examples; labels are not used.
        batch_size: examples per evaluation batch (the last batch may be smaller).
        leatning_rate: accepted for interface compatibility; unused, since evaluation
            performs no optimization.

    Returns:
        float: the sum of ``criterion`` over all evaluation batches.
    """
    network = nertwork
    images = test_data[0]
    test_size = images.shape[0]

    network.eval()

    sum_loss = 0.0
    # No gradients, no backward pass and no optimizer step: evaluation must not
    # change the model.
    with torch.no_grad():
        for first in range(0, test_size, batch_size):
            batch = images[first:first + batch_size]
            # One flattened row per example: shape (batch, features).
            batch_input = torch.from_numpy(batch).reshape(batch.shape[0], -1)

            reconstruction, z_mean, z_log_var = network(batch_input)
            sum_loss += criterion(reconstruction, batch_input, z_mean, z_log_var).item()

    return sum_loss


# The name starts with "test_", so tell pytest not to collect it as a test if this
# module is ever imported by a test suite.
test_loop.__test__ = False

In [None]:

# Build the model: 784-dimensional inputs, one hidden layer of 400 units, 20 latent dims.
# hidden_dim is passed as a list because the Encoder/Decoder constructors read hidden_dim[0].
encoder = Encoder( 784, [400], 20, "ReLU", 0.0 )
decoder = Decoder( 784, [400], 20, "ReLU", 0.0 )


vaeModel = VAE( encoder, decoder )
optimizer = optim.Adam(vaeModel.parameters(), lr=0.001)

