## Variational Autoencoder

A variational autoencoder (VAE) models a complicated distribution as a deterministic transformation of a simple distribution. Like many other methods (one of which was discussed in the first notebook), a VAE tries to maximize the log-likelihood of samples from the target distribution. In particular, a VAE maximizes the variational lower bound on the log-likelihood. During training, it also learns an encoder, which is used to estimate the variational lower bound.
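
For reference, the variational lower bound is commonly written as below. The notation is assumed here and does not come from the notebook: $q_\phi(z \mid x)$ stands for the encoder, $p_\theta(x \mid z)$ for the decoder, and $p(z)$ for the simple prior over the latent variable.

$$
\log p_\theta(x) \;\ge\; \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] \;-\; D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p(z)\big)
$$

The first term rewards accurate reconstruction of $x$ from a sampled $z$, and the second keeps the encoder's distribution close to the prior.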

A VAE can be represented as follows:

<img src="images/VAE.png" width="900" align="left">



Let's train a VAE that generates digits from the MNIST dataset.

In [11]:
################
# import modules
################

import numpy as np
import matplotlib.pyplot as plt


import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

In [12]:
# gpu or cpu
cuda = torch.cuda.is_available()

# if gpu is used
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

seed = 1

# dataset related
n_classes = 10
z_dim = 2
X_dim = 784
train_batch_size = 100
val_batch_size = train_batch_size
N = 1000
epochs = 5

# load into a dictionary
params = {}
params['cuda'] = cuda
params['n_classes'] = n_classes
params['z_dim'] = z_dim
params['X_dim'] = X_dim
params['train_batch_size'] = train_batch_size
params['val_batch_size'] = val_batch_size
params['N'] = N
params['epochs'] = epochs

In [17]:
# loading the dataset

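# ToTensor must come first: Normalize operates on tensors, not PIL images.
# Note that after Normalize, pixel values are no longer confined to [0, 1].
transform = torchvision.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

# pass the transform so the loaders yield tensors rather than PIL images
train_dataset = torchvision.datasets.MNIST(root='datasets/',
                                            train=True,
                                            transform=transform,
                                            target_transform=None,
                                            download=True)
val_dataset = torchvision.datasets.MNIST(root='datasets/',
                                          train=False,
                                          transform=transform,
                                          target_transform=None,
                                          download=True)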

train_labeled_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=train_batch_size,
                                                   shuffle=True, **kwargs)
val_labeled_loader = torch.utils.data.DataLoader(val_dataset,
                                                   batch_size=val_batch_size,
                                                   shuffle=True, **kwargs)
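
The networks defined next take flat vectors of length `X_dim`, while MNIST images converted with `ToTensor` have shape `(1, 28, 28)`. A minimal sketch of the reshaping step, using a synthetic tensor as a stand-in for one batch from the loader (the names `images` and `flat` are illustrative only):

```python
import torch

X_dim = 784  # 28 * 28, matching the notebook's X_dim

# Synthetic stand-in for one batch: (batch, channels, height, width).
images = torch.rand(100, 1, 28, 28)

# Flatten each image into a vector of length X_dim, keeping the batch dimension.
flat = images.view(images.size(0), -1)
```

The same `view` call would presumably be applied to each real batch before it is passed to the encoder.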

In [19]:
# defining the encoder and decoder networks


class QNet(nn.Module):
    def __init__(self):
        super(QNet, self).__init__()
        self.layer1 = nn.Linear(X_dim, N)
        self.layer2 = nn.Linear(N,N)
        self.layer3_mean = nn.Linear(N,z_dim)
        self.layer3_var = nn.Linear(N,z_dim)
        
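    def forward(self, x):
        # apply each linear layer to x, then dropout (only while training) and ReLU
        x = F.dropout(self.layer1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.layer2(x), p=0.2, training=self.training)
        x = F.relu(x)
        
        # two heads: one for the mean and one for the variance of the latent code
        x_mean = self.layer3_mean(x)
        x_var = self.layer3_var(x)
        return x_mean, x_var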
        
class PNet(nn.Module):
    def __init__(self):
        super(PNet, self).__init__()
        self.layer1 = nn.Linear(z_dim, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, X_dim)
    
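    def forward(self, x):
        x = F.dropout(self.layer1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.layer2(x), p=0.2, training=self.training)
        x = F.relu(x)
        # sigmoid keeps each output pixel in (0, 1)
        x = torch.sigmoid(self.layer3(x))
        return x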
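
The encoder returns two tensors and the decoder maps a latent code back to pixel space. One common way to connect them is the reparameterization trick together with a loss made of a reconstruction term and a KL term. The sketch below is not necessarily what this notebook's training code does. It rests on two assumptions: the encoder's second output is read as the log-variance, and the targets lie in [0, 1], which would not hold after the `Normalize` step. The helper names `reparameterize` and `negative_elbo` and the tiny linear decoder are hypothetical.

```python
import torch
import torch.nn.functional as F


def reparameterize(mean, log_var):
    """Sample z = mean + std * eps with eps ~ N(0, I).

    Assumption: log_var holds log(sigma^2) for each latent dimension.
    """
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mean + std * eps


def negative_elbo(x_recon, x, mean, log_var):
    """Negative variational lower bound, summed over the batch."""
    # Bernoulli reconstruction term; assumes x and x_recon lie in [0, 1].
    recon = F.binary_cross_entropy(x_recon, x, reduction="sum")
    # Closed-form KL divergence between N(mean, sigma^2) and N(0, I).
    kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return recon + kl


# Toy check: random tensors stand in for a batch of 4 flattened images.
torch.manual_seed(0)
x = torch.rand(4, 784)
mean = torch.zeros(4, 2, requires_grad=True)
log_var = torch.zeros(4, 2, requires_grad=True)

z = reparameterize(mean, log_var)
toy_decoder = torch.nn.Linear(2, 784)  # hypothetical stand-in for PNet
x_recon = torch.sigmoid(toy_decoder(z))
loss = negative_elbo(x_recon, x, mean, log_var)
```

Because the noise `eps` is drawn separately from the encoder outputs, gradients of the loss can flow back through `z` into `mean` and `log_var`.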
        

In [20]:
# train function