# Exercise: Variational Autoencoders

In [None]:
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor, Compose, Normalize
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
if torch.backends.cuda.is_built():
    device = torch.device('cuda')
elif torch.backends.mps.is_built():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print('Using device:', device)

In [None]:
# MNIST train/test splits, each image converted to a tensor by ToTensor.
mnist_train = MNIST(
    './data',
    train=True,
    download=True,
    transform=Compose([ToTensor()]),
)
mnist_test = MNIST(
    './data',
    train=False,
    download=True,
    transform=Compose([ToTensor()]),
)

# Mini-batches of 64 images, reshuffled on every pass over the data.
mnist_train_loader = DataLoader(dataset=mnist_train, batch_size=64, shuffle=True)
mnist_test_loader = DataLoader(dataset=mnist_test, batch_size=64, shuffle=True)

# Exercise description

The objective of this exercise is to develop a variational autoencoder, plot reconstructions of input images, visualize how the examples are distributed in the embedding space, and generate and plot new images.

The autoencoder will be based on a convolutional architecture with the same spec as the one used in the previous exercise. 

## Encoder Layers:

- encoder1: Convolutional layer with 1 input channel, 16 output channels, a kernel size of 3, stride of 2, and padding of 1. Relu activation.
- encoder2: Convolutional layer with 16 input channels, 32 output channels, a kernel size of 3, stride of 2, and padding of 1. Relu activation.
- encoder3: Convolutional layer with 32 input channels, 64 output channels, a kernel size of 7, stride of 1, and no padding. Relu activation.
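- mu: Fully connected (linear) layer mapping the flattened encoder3 output (64 features, since encoder3 produces a 64×1×1 feature map for 28×28 inputs) to z_size. 
- sigma: Fully connected (linear) layer mapping the flattened encoder3 output (64 features) to z_size. 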


## Decoder Layers:

- decoder1: Fully connected (linear) layer increasing the dimensionality from z_size to 64. Relu activation. Its output is reshaped to 64×1×1 before being fed to decoder2.
- decoder2: Transposed convolutional layer with 64 input channels, 32 output channels, a kernel size of 7, stride of 1, and no padding. Relu activation.
- decoder3: Transposed convolutional layer with 32 input channels, 16 output channels, a kernel size of 3, stride of 2, padding of 1, and output padding of 1. Relu activation.
- decoder4: Transposed convolutional layer with 16 input channels, 1 output channel, a kernel size of 3, stride of 2, padding of 1, and output padding of 1. Relu activation.

where z_size is a parameter of the model specifying the size of the embedding space.

In [None]:
class VAEModel(torch.nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    Layers follow the exercise spec:
      encoder: 1x28x28 -> 16x14x14 -> 32x7x7 -> 64x1x1 (conv + ReLU each),
               then two linear heads (mu, sigma) from the 64 features to z_size;
      decoder: linear z_size -> 64 (ReLU), reshaped to 64x1x1, then
               64x1x1 -> 32x7x7 -> 16x14x14 -> 1x28x28 (transposed conv + ReLU each).

    NOTE(review): the ``sigma`` head is interpreted here as the log-variance
    (log sigma^2) of the approximate posterior, which keeps the variance
    positive without constraints. Confirm this matches how any saved
    checkpoint was trained.
    """

    def __init__(self, z_size=3):
        super(VAEModel, self).__init__()
        # Stored because generate() needs the latent dimensionality.
        self.z_size = z_size

        # Encoder
        self.encoder1 = torch.nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.encoder2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.encoder3 = torch.nn.Conv2d(32, 64, kernel_size=7, stride=1, padding=0)
        self.mu = torch.nn.Linear(64, z_size)
        self.sigma = torch.nn.Linear(64, z_size)

        # Decoder
        self.decoder1 = torch.nn.Linear(z_size, 64)
        self.decoder2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=7, stride=1, padding=0)
        self.decoder3 = torch.nn.ConvTranspose2d(
            32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder4 = torch.nn.ConvTranspose2d(
            16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def reparameterize(self, mu, sigma):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick).

        Args:
            mu: posterior means, shape (N, z_size).
            sigma: posterior log-variances, shape (N, z_size).

        Returns:
            Sampled latent vectors, same shape as ``mu``.
        """
        std = torch.exp(0.5 * sigma)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """Map images to the parameters of q(z|x).

        Args:
            x: images shaped (N, 1, 28, 28). A single unbatched image shaped
               (1, 28, 28), as returned by indexing the dataset, is treated as N=1.

        Returns:
            Tuple (mu, sigma), each shaped (N, z_size); sigma is the log-variance.
        """
        if x.dim() == 3:
            # Unbatched (C, H, W) input: add the batch dimension so flatten works.
            x = x.unsqueeze(0)
        h = torch.relu(self.encoder1(x))
        h = torch.relu(self.encoder2(h))
        h = torch.relu(self.encoder3(h))
        h = torch.flatten(h, start_dim=1)  # (N, 64, 1, 1) -> (N, 64)
        return self.mu(h), self.sigma(h)

    def decode(self, z):
        """Map latent vectors shaped (N, z_size) to images shaped (N, 1, 28, 28)."""
        h = torch.relu(self.decoder1(z))
        h = h.view(-1, 64, 1, 1)
        h = torch.relu(self.decoder2(h))
        h = torch.relu(self.decoder3(h))
        # The spec puts a ReLU on the output layer too, so pixels are non-negative.
        return torch.relu(self.decoder4(h))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            Tuple (reconstruction, mu, sigma); reconstruction is (N, 1, 28, 28).
        """
        mu, sigma = self.encode(x)
        z = self.reparameterize(mu, sigma)
        return self.decode(z), mu, sigma

    def generate(self, n=1):
        """Decode ``n`` latent vectors drawn from N(0, I) into (n, 1, 28, 28) images."""
        # Sample on the same device as the model's weights.
        param_device = next(self.parameters()).device
        z = torch.randn(n, self.z_size, device=param_device)
        return self.decode(z)


The loss function will be the sum of the reconstruction loss and the KL divergence loss.

In [None]:

def loss_fn(x, y):
    """Return the VAE loss: MSE reconstruction loss plus KL divergence.

    Args:
        x: batch of input images (the reconstruction target).
        y: tuple ``(recon, mu, sigma)``: the reconstruction, the posterior
           means and the posterior log-variances.
           NOTE(review): this convention assumes the model's forward returns
           that tuple with ``sigma`` as log-variance; confirm against the
           model and the training loop.

    Returns:
        Scalar tensor: summed MSE between ``recon`` and ``x`` plus
        KL(N(mu, exp(sigma)) || N(0, I)), both summed over the batch.
    """
    recon, mu, sigma = y
    # Sum (not mean) reduction keeps the reconstruction and KL terms on the same scale.
    reconstruction = F.mse_loss(recon, x.reshape(recon.shape), reduction='sum')
    # Closed-form KL divergence between a diagonal Gaussian and the standard normal.
    kl = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return reconstruction + kl

In [None]:
# Model with a 3-dimensional latent space, moved to the selected device,
# optimized with Adam at a learning rate of 1e-3.
model = VAEModel(z_size=3).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

Training is standard: the Adam optimizer minimizes the loss defined above, i.e. the MSE reconstruction loss plus the KL divergence term.

In [None]:
# Here the code to train the model
# ...


# Reconstruction

In [None]:
# torch.save(model.state_dict(), "vae_model.pth")
# map_location remaps tensors saved on another device (e.g. CUDA) onto the
# device selected for this session; weights_only restricts unpickling to
# tensors and plain containers, since torch.load is otherwise pickle-based.
model.load_state_dict(torch.load("vae_model.pth", map_location=device, weights_only=True))

In [None]:
plt.figure(figsize=(5, 5))
plt.gray()

imshape = (28, 28)

# Reconstruct the first 25 test images; no gradients are needed for plotting.
preds = []
for idx in range(25):
    image, _label = mnist_test[idx]
    with torch.no_grad():
        reconstruction = model(image.to(device))[0]
    preds.append(reconstruction.cpu().detach().float().numpy())

# Show the reconstructions on a 5x5 grid with the axes hidden.
for position, pred in enumerate(preds, start=1):
    ax = plt.subplot(5, 5, position)
    ax.imshow(pred.reshape(imshape), cmap='gray')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

plt.show()

# Generation

In [None]:
plt.figure(figsize=(5, 5))
plt.gray()

imshape = (28, 28)

# Decode 25 samples via model.generate(); no_grad already yields tensors
# without autograd history, so no explicit detach() is needed.
with torch.no_grad():
    preds = [model.generate().cpu().float().numpy() for _ in range(25)]

# Show the generated images on a 5x5 grid with the axes hidden.
for i, pred in enumerate(preds):
    ax = plt.subplot(5, 5, i + 1)
    ax.imshow(pred.reshape(imshape), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

# Render explicitly, like the reconstruction cell; required outside inline Jupyter.
plt.show()