**Problem 5:** Perform an interpolation experiment with your trained model from Problem 4.

# Imports and Definitions
We followed the official PyTorch documentation. PyTorch is an open-source Python framework for machine learning.

In [None]:
%matplotlib inline
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt

# Convert each dataset image to a float tensor; no other preprocessing is applied.
img = transforms.Compose([transforms.ToTensor()])

# MNIST training split, as required by the project (downloaded if missing).
mnist_DataSet = MNIST(
    root='./data/MNIST',
    train=True,
    download=True,
    transform=img,
)

# Fixed (unshuffled) order, eight images per batch.
data_Loader = DataLoader(mnist_DataSet, shuffle=False, batch_size=8)

# Variational Autoencoders (VAEs)
A variational autoencoder (VAE) regularizes the distribution of its encodings during training, so that its latent space has good properties and can be used to generate new data. The word "variational" comes from the close connection between this regularization and the variational inference methods of statistics.

In [None]:
# A convolutional encoder/decoder pair tends to perform better than a fully
# connected one with the same number of parameters.
class En(nn.Module):
    """Convolutional encoder producing the parameters of a 2-D Gaussian latent.

    Two stride-2 convolutions (4x4 kernels, padding 1) each halve the spatial
    size. The flattened ``(64*2) x 7 x 7`` feature map therefore corresponds to
    a 28x28 single-channel input. Two linear heads then map the features to the
    latent mean and log-variance.
    """

    def __init__(self):
        super().__init__()
        channels = 64
        flat_features = channels * 2 * 7 * 7
        # Attribute names (including the "secoundC" spelling) are kept as-is
        # because they determine the state_dict keys of saved models.
        self.firstC = nn.Conv2d(1, channels, 4, 2, 1)
        self.secoundC = nn.Conv2d(channels, channels * 2, 4, 2, 1)
        self.fc_mu = nn.Linear(flat_features, 2)
        self.fc_logvar = nn.Linear(flat_features, 2)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, 2)``, for image batch ``x``."""
        features = F.relu(self.secoundC(F.relu(self.firstC(x))))
        # Collapse (batch, C, H, W) feature maps into (batch, C*H*W) vectors.
        features = torch.flatten(features, start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder mapping a 2-D latent vector back to an image.

    A linear layer expands the latent code into ``(64*2) x 7 x 7`` feature maps.
    Two stride-2 transposed convolutions (4x4 kernels, padding 1) then upsample
    7 -> 14 -> 28 and produce a single-channel output.
    """

    def __init__(self):
        super().__init__()
        c = 64
        self.fc = nn.Linear(2, c * 2 * 7 * 7)
        # NOTE: "firstC" is the *last* layer applied (it produces the image) and
        # "secondC" the first; the names are kept to preserve state_dict keys.
        self.firstC = nn.ConvTranspose2d(c, 1, 4, 2, 1)
        self.secondC = nn.ConvTranspose2d(c * 2, c, 4, 2, 1)

    def forward(self, x):
        """Decode latent batch ``x`` of shape ``(batch, 2)`` into images in (0, 1)."""
        x = self.fc(x)
        # Unflatten to the channel count the first transposed conv expects,
        # instead of repeating the magic number 64*2 here.
        x = x.view(x.size(0), self.secondC.in_channels, 7, 7)
        x = F.relu(self.secondC(x))
        # Sigmoid keeps outputs in (0, 1), as required by the BCE reconstruction loss.
        return torch.sigmoid(self.firstC(x))

In [None]:
class VariationalAutoencoder(nn.Module):
    """VAE built from the ``En`` encoder and the ``Decoder`` decoder.

    ``forward`` returns ``(reconstruction, mu, logvar)``. In training mode the
    latent code is sampled with the reparameterization trick. In eval mode the
    mean ``mu`` is decoded directly.
    """

    def __init__(self):
        super().__init__()
        self.encoder = En()
        self.decoder = Decoder()

    def forward(self, x):
        mu, logvar = self.encoder(x)
        reconstruction = self.decoder(self.latent_sample(mu, logvar))
        return reconstruction, mu, logvar

    def latent_sample(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)`` in training mode, else ``mu``."""
        if not self.training:
            return mu
        # Reparameterization trick: the noise is drawn independently of the
        # parameters, so gradients flow through mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
def vae_loss(recon_x, x, mu, logvar, kl_weight=0.01):
    """Compute the weighted VAE loss and its two components.

    Args:
        recon_x: decoder output, probabilities in [0, 1], with the same number
            of elements as ``x``.
        x: target images with values in [0, 1].
        mu: mean of the approximate posterior.
        logvar: log-variance of the approximate posterior.
        kl_weight: multiplier applied to the KL-divergence term.

    Returns:
        ``(total, recon_loss, weighted_kl)`` with ``total = recon_loss + weighted_kl``.
        Both terms are summed over the whole batch, not averaged.
    """
    # Flatten both tensors so any image size works; with reduction='sum' the
    # element layout does not affect the result.
    recon_loss = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims and batch.
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    weighted_kl = kl_weight * kldivergence
    return recon_loss + weighted_kl, recon_loss, weighted_kl
    
    
vae = VariationalAutoencoder()

# Use a GPU when one is available; later cells move batches to `device`
# and bring results back with `.cpu()`, so either device works.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device)

# Report the number of trainable parameters.
num_params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
print('Number of parameters: %d' % num_params)


- `num_epochs` and `learning_rate` are the training hyper-parameters.

In [None]:
num_epochs = 10
# NOTE(review): 1e-5 is an unusually small Adam step size (PyTorch's default
# is 1e-3) and may explain slow convergence; consider tuning.
learning_rate = 1e-5
optimizer = torch.optim.Adam(params=vae.parameters(), lr=learning_rate, weight_decay=1e-5)

# Training mode makes VariationalAutoencoder.latent_sample draw a
# reparameterized sample instead of returning the mean.
vae.train()

# One entry per epoch: the mean over batches of the total VAE loss.
train_loss_avg = []
print('Training ...')
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for image_batch, _ in data_Loader:
        image_batch = image_batch.to(device)

        # Forward pass: reconstruction plus the posterior parameters.
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # Total loss = reconstruction BCE + weighted KL divergence.
        loss, _, _ = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        # Backpropagation followed by one optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # max(..., 1) guards against an empty loader.
    train_loss_avg.append(epoch_loss / max(num_batches, 1))
    # The reported value is the total loss, not only the reconstruction term.
    print('Epoch [%d / %d] average loss per batch: %f' % (epoch + 1, num_epochs, train_loss_avg[-1]))


Training ...
Epoch [1 / 10] average reconstruction error: 1648.601481
Epoch [2 / 10] average reconstruction error: 1474.469752


In [None]:
# # Specify a path to save the trained network
# PATH = "Trained_MNIST_Model.pt"

# torch.save(vae, PATH)

### Visualizing Reconstruction and Interpolation
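We interpolate between two source images in latent space. Each image is encoded with the encoder $q$ (taking its mean $\mu$). The two latent vectors are combined linearly as $z = \lambda z_1 + (1-\lambda) z_2$, and the interpolated latent vector is decoded back into image space by the decoder.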


In [None]:


def to_img(x):
    """Return a copy of tensor ``x`` with every value clipped into [0, 1] for display."""
    return x.clamp(min=0, max=1)

# Switch the model out of training mode (equivalent to vae.eval()).
vae.train(False)

def interpolation(lambda1, model, firstImage, secondImage):
    """Decode a linear blend of the latent means of two images.

    Args:
        lambda1: blend weight. 1.0 gives ``firstImage``'s latent code and 0.0
            gives ``secondImage``'s.
        model: module exposing ``encoder`` (returning ``(mu, logvar)``) and
            ``decoder`` sub-modules.
        firstImage: image batch accepted by ``model.encoder``.
        secondImage: image batch accepted by ``model.encoder``.

    Returns:
        The decoded image tensor for the interpolated latent, on the CPU.
    """
    # Run on whatever device the model lives on instead of a module-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        # Only the encoder mean is used, so the result is deterministic.
        latent_1, _ = model.encoder(firstImage.to(model_device))
        latent_2, _ = model.encoder(secondImage.to(model_device))

        blended_latent = lambda1 * latent_1 + (1 - lambda1) * latent_2

        return model.decoder(blended_latent).cpu()
    
# Group images from the (training-set) data_Loader by their digit label,
# stopping once at least 1000 images have been collected in total.
# digits[d] holds single-image batches (leading dimension 1) showing digit d.
digits = [[] for _ in range(10)]
for img_batch, label_batch in data_Loader:
    for single_img, label in zip(img_batch.split(1), label_batch):
        digits[int(label)].append(single_img)
    if sum(map(len, digits)) >= 1000:
        break

# Ten blend weights from 0 (pure second image) to 1 (pure first image).
lambda_range = np.linspace(0, 1, 10)

# Five panels per row; the row count follows from the number of weights.
n_cols = 5
n_rows = -(-len(lambda_range) // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)
fig.subplots_adjust(hspace=.5, wspace=.001)
axs = axs.ravel()

for ind, lam in enumerate(lambda_range):
    # Blend the latent codes of the first "7" and the first "1" collected above.
    both_Img_Latent = interpolation(float(lam), vae, digits[7][0], digits[1][0])
    both_Img_Latent = to_img(both_Img_Latent)
    image = both_Img_Latent.numpy()

    axs[ind].imshow(image[0, 0, :, :], cmap='gray')
    # Two decimals: the step is 1/9, so one decimal would mislabel panels
    # (0.44 and 0.56 would both round away from 0.5, to 0.4 and 0.6).
    axs[ind].set_title('lambda=%.2f' % lam)
    axs[ind].axis('off')

# Hide any panels left over when the weights do not fill the grid.
for spare_ax in axs[len(lambda_range):]:
    spare_ax.axis('off')

plt.show()