# Variational Auto-Encoders

In this exercise, I build a Variational Auto-Encoder (VAE) with the PyTorch library. The VAE is then trained on the MNIST dataset, followed by an analysis of the results.

Special thanks to the **Datacamp** website for its crash course on PyTorch.
Thanks also to the **Yandex Data Science School**, whose Deep Learning course was helpful.
Finally, thanks to the **PyTorch development team** for this example: https://github.com/pytorch/examples/blob/master/vae/main.py

*Disclaimer*: The VAE class implemented here is an almost identical copy of the VAE class found at the link above. My main goals were to expand my knowledge of PyTorch, analyze an example of good PyTorch usage, and understand VAEs, NOT to create an original piece of software.

In [12]:
# make necessary imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import save_image

## Load MNIST dataset

In [13]:
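# mean and std of the MNIST training set (commonly quoted normalization constants).
# They are not applied below: the binary cross-entropy loss needs pixel values in [0, 1],
# which is what ToTensor() alone produces.
MNIST_mean = 0.1307
MNIST_std  = 0.3081

# number of images per mini-batch, changeable
MNIST_batch_size = 32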

MNIST_transform = transforms.Compose(
    [transforms.ToTensor()])

# prepare train and test sets
trainset = torchvision.datasets.MNIST("mnist", train=True, download=True, transform=MNIST_transform)
testset  = torchvision.datasets.MNIST("mnist", train=False, download=True, transform=MNIST_transform)

# prepare data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=MNIST_batch_size, shuffle=True)
testloader  = torch.utils.data.DataLoader(testset, batch_size=MNIST_batch_size, shuffle=False)

## Creating VAE net
In this section we implement the VAE in PyTorch. It may be updated with further experiments to improve performance.

In [14]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        
        self.fc1  = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3  = nn.Linear(20, 400)
        self.fc4  = nn.Linear(400, 784)
    
    
    def encode(self, X):
        """
        encodes the input image into two vectors: mean and log-variance
        :param X: input image in torch Tensor format
        :returns: mu and logvar
        """
        hidden1 = F.relu(self.fc1(X))
        return self.fc21(hidden1), self.fc22(hidden1)
    
    
    def reparameterize(self, mu, logvar):
        """
        implementation of the reparameterization trick, allowing for training with random sampling
        :param mu: mean values tensor
        :param logvar: log-variance tensor
        :returns: random tensor from the Gaussian distribution
        """
        # get standard deviation from the log-variance
        std = torch.exp(0.5*logvar)
        
        # get random tensor from normal distribution of mean 0 and var 1 of size like std
        eps = torch.randn_like(std)
        return mu+eps*std
    
    
    def decode(self, z):
        """
        project a tensor from the latent space back into original coordinates
        :param z: tensor in the latent space to be decoded
        """
        hidden3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden3))
    
    
    def forward(self, x):
        """
        forward function of VAE NN
        :param x: input image in torch Tensor format
        :returns: x decoded from latent space along with mean and logvar tensors
        """
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar 

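The reparameterization step is what lets gradients reach the encoder even though `z` is sampled at random. The snippet below is a minimal sketch of that idea, not part of the model: `mu` and `logvar` are hypothetical stand-ins for encoder outputs (a batch of 4 and a 20-dimensional latent space are assumed to match the class above), and the computation mirrors `VAE.reparameterize`.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 4, latent dimension 20 (assumed values)
mu = torch.zeros(4, 20, requires_grad=True)
logvar = torch.zeros(4, 20, requires_grad=True)

# Same computation as VAE.reparameterize: z = mu + eps * std
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)  # the randomness lives in eps, which needs no gradient
z = mu + eps * std

# Backpropagate through the sample
z.sum().backward()

# dz/dmu = 1 and dz/dlogvar = 0.5 * eps * std, so both encoder outputs get gradients
grad_mu_ok = torch.allclose(mu.grad, torch.ones_like(mu))
grad_logvar_ok = torch.allclose(logvar.grad, (0.5 * eps * std).detach())
print(z.shape, grad_mu_ok, grad_logvar_ok)
```

Sampling `z` directly from a distribution parameterized by `mu` and `std` would leave no differentiable path back to those parameters; writing the sample as a deterministic function of them plus independent noise is what makes backprop possible.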
In [17]:
# instantiate a model
model = VAE()

# instantiate an Adam optimizer with L2-regularization
optimizer = optim.Adam(model.parameters(), lr = 1e-3, weight_decay = 0.001)

# define custom loss-function
def loss_func(reconstructed_x, x, mu, logvar):
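    """
    defines a loss for our VAE. Consists of
    a) Reconstruction loss (binary cross-entropy between reconstruction and input)
    b) Kullback-Leibler divergence between the encoder distribution and a standard normal
    Both terms are summed over all elements and the batch.
    """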

    reconstr_loss = F.binary_cross_entropy(reconstructed_x, x.view(-1, 784), reduction='sum')
    KLd = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstr_loss+KLd


def train_VAE(epoch):
    # put model in train mode
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(trainloader):
        # zero all gradients
        optimizer.zero_grad()
        reconstructed_batch, mu, logvar = model(data)
        loss = loss_func(reconstructed_batch, data, mu, logvar)
        
        # do backprop
        loss.backward()
        
        train_loss += loss.item()
        optimizer.step()
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(trainloader.dataset)))
    
def test_VAE(epoch):
    # put model into evaluation mode
    model.eval()
    test_loss = 0
    
    # deactivate the autograd engine: backprop becomes unavailable, but computations speed up
    # and memory usage drops; we don't update weights here, so backprop is not needed.
    with torch.no_grad():
        for i, (data, _) in enumerate(testloader):
            recon_batch, mu, logvar = model(data)
            test_loss += loss_func(recon_batch, data, mu, logvar).item()
    test_loss /= len(testloader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

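The `KLd` term in `loss_func` is the closed-form KL divergence between a diagonal Gaussian $N(\mu, \sigma^2)$ and the standard normal, $-\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$, with `logvar` playing the role of $\log\sigma^2$. As a sanity check, the sketch below compares that expression with a Monte Carlo estimate for a single latent dimension. It uses only the standard library; the function names and the values `mu=0.5`, `logvar=-1.0` are illustrative choices, not taken from the notebook.

```python
import math
import random


def kl_closed_form(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, 1) ) for one latent dimension."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))


def kl_monte_carlo(mu, logvar, n_samples=200_000, seed=0):
    """Estimate the same KL as the average of log q(z) - log p(z) over z ~ q."""
    rng = random.Random(seed)
    std = math.exp(0.5 * logvar)
    log_norm = -0.5 * math.log(2 * math.pi)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mu, std)
        log_q = log_norm - 0.5 * logvar - (z - mu) ** 2 / (2 * std ** 2)
        log_p = log_norm - 0.5 * z ** 2
        total += log_q - log_p
    return total / n_samples


exact = kl_closed_form(0.5, -1.0)
estimate = kl_monte_carlo(0.5, -1.0)
print(f"closed form: {exact:.4f}, Monte Carlo: {estimate:.4f}")
```

The two numbers should agree up to sampling noise, and the divergence is zero when `mu = 0` and `logvar = 0`, that is, when the encoder distribution already equals the prior.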
## Applying VAE to the MNIST dataset
In this section we experiment with the VAE, using everything prepared up to this point.

In [None]:
import os

EPOCH_NUM = 10

# save_image does not create missing directories, so make sure the output folder exists
os.makedirs('results', exist_ok=True)

for epoch in range(1, EPOCH_NUM + 1):
    train_VAE(epoch)
    test_VAE(epoch)
    with torch.no_grad():
        # let's try to decode random vectors drawn from the latent prior!
        sample = torch.randn(64, 20)
        sample = model.decode(sample)
        save_image(sample.view(64, 1, 28, 28), 'results/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 135.7474
====> Test set loss: 115.0611
====> Epoch: 2 Average loss: 113.1201
====> Test set loss: 110.2719
====> Epoch: 3 Average loss: 110.0302
====> Test set loss: 108.1107
====> Epoch: 4 Average loss: 108.5750
====> Test set loss: 107.5699
====> Epoch: 5 Average loss: 107.7129
====> Test set loss: 107.1990
====> Epoch: 6 Average loss: 107.1025
====> Test set loss: 106.3990
====> Epoch: 7 Average loss: 106.5550
====> Test set loss: 106.3621
====> Epoch: 8 Average loss: 106.1665
====> Test set loss: 105.8615
