# AutoEncoder, Variational AutoEncoder and GAN

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from datetime import datetime
from utils import *

seed = 265
torch.manual_seed(seed)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on device {device}.")

## General instructions

## Introduction

In this assignment we will go through three types of unsupervised neural networks: AutoEncoder (AE), Variational AutoEncoder (VAE) and Generative Adversarial Network (GAN). In the first section we will also introduce a new type of layer, the transpose convolution, as it is widely used in these unsupervised methods.

Unsupervised methods have many advantages, including the fact that they don't need labels, but they are also harder to train. It is normal if you don't get good results.

## Contents

1. Transpose convolution
2. AutoEncoder
3. Variational AutoEncoder
4. GAN

## Related videos from the curriculum

- [Lecture 15.1 — From PCA to autoencoders](https://www.youtube.com/watch?v=PSOt7u8u23w&list=PLLssT5z_DsK_gyrQ_biidwvPYCRNGI3iv&index=69)
- [Lecture 15.2 — Deep autoencoders](https://www.youtube.com/watch?v=6jhhIPdgkp0&list=PLLssT5z_DsK_gyrQ_biidwvPYCRNGI3iv&index=70) 
- [Lecture 15.3 — Deep autoencoders for document retrieval](https://www.youtube.com/watch?v=ZCNbjpcX0yg&list=PLLssT5z_DsK_gyrQ_biidwvPYCRNGI3iv&index=71)
- [Lecture 15.6 — Shallow autoencoders for pre training](https://www.youtube.com/watch?v=xjlvVfEbhz4&list=PLLssT5z_DsK_gyrQ_biidwvPYCRNGI3iv&index=74)
- [Lecture 13 | Generative Models](https://www.youtube.com/watch?v=5WoItGTWV54)


# 2. (Reminders) Encoder and Decoder

### Modules 

The cell below defines the following modules, which we will need in this section:

1. **MyEncoder**
    - input: image
    - output: tensor `z` in latent space (lower dimension than input space)
1. **MyDecoder**
    - input: tensor `z` in latent space (lower dimension than image space)
    - output: reconstructed image


In [None]:
class MyEncoder(nn.Module):
    """
    Convolutional encoder mapping an image to a latent vector `z`.

    Architecture: three stride-1 convolutions (1->6 channels with a 6x6 kernel,
    6->5 with 5x5, 5->4 with 4x4), each followed by a ReLU; the feature maps are
    then flattened and projected to ``z_dim`` by ``fc3``, also followed by a ReLU
    (so every latent coordinate is >= 0).

    ``fc3`` expects 256 flattened features (4 channels x 8 x 8), which matches
    e.g. a 1x20x20 input image. NOTE(review): confirm the image size produced
    by ``load_MNIST``.
    """

    def __init__(self, z_dim):
        super().__init__()
        # Creation order is kept conv1 -> conv2 -> conv3 -> fc3 so that seeded
        # weight initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=6)
        self.conv2 = nn.Conv2d(6, 5, kernel_size=5)
        self.conv3 = nn.Conv2d(5, 4, kernel_size=4)
        self.fc3 = nn.Linear(256, z_dim)

    def forward(self, x):
        batch_size = x.shape[0]
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        flat = features.view(batch_size, -1)
        return torch.relu(self.fc3(flat))

class MyDecoder(nn.Module):
    """
    Decoder module:
    - input: tensor `z` in latent space, shape (N, z_dim)
    - output: reconstructed image, shape (N, 1, 20, 20), values in [0, 1]

    Pipeline: Linear(z_dim, 128) -> ReLU -> Linear(128, 18*18) -> ReLU
    -> reshape to (N, 1, 18, 18) -> ConvTranspose2d(kernel 3, stride 1) -> sigmoid.
    A stride-1 transpose convolution with a 3x3 kernel grows each spatial side
    by 2, hence 18x18 -> 20x20.
    """
    def __init__(self, z_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, 128)
        self.fc2 = nn.Linear(128, 18*18)
        self.transconv3 = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=1)

    def forward(self, x):
        N = x.shape[0]
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        # Turn the 324-dim vector into a single-channel 18x18 feature map
        out = out.view(N, 1, 18, 18)
        # Sigmoid keeps pixel intensities in [0, 1]
        return torch.sigmoid(self.transconv3(out))

### Utils 

Some useful functions:
- **plot_generated_images**: Plot images generated by a VAE
- **training_vae**: Training loop for a VAE

In [None]:
def plot_generated_images(vae, nrows=10, ncols=10):
    """
    Plot a grid of images generated by a VAE.

    ``nrows * ncols`` latent vectors are drawn from N(0, 1) (shape
    ``(nrows*ncols, vae.z_dim)``), decoded in a single batch by ``vae.decoder``
    and displayed with the 'Greys' colormap.

    Parameters
    ----------
    vae : nn.Module
        Model exposing ``z_dim`` and ``decoder`` (e.g. ``MyVAE``). It may live
        on any device; sampling happens on the decoder's device.
    nrows, ncols : int
        Grid size (default 10 x 10, i.e. 100 images).

    Returns
    -------
    fig, axs
        The matplotlib figure and the 2D array of axes.
    """
    # Remember the current mode so plotting does not silently switch training off
    was_training = vae.training
    vae.eval()
    n_img = nrows * ncols
    # squeeze=False keeps axs a 2D array even for a single row/column grid
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(13, 13), sharex=True, sharey=True,
                            tight_layout=True, squeeze=False)
    fig.suptitle("Image generation", fontsize=15)
    # Sample on the decoder's device so this also works when the model is on GPU
    decoder_device = next(vae.decoder.parameters()).device
    with torch.no_grad():
        z = torch.randn(n_img, vae.z_dim, device=decoder_device)
        # One batched forward pass instead of one pass per image; move to CPU for matplotlib
        imgs = vae.decoder(z).cpu()
    for ax, img in zip(axs.flat, imgs):
        # (C, H, W) -> (H, W, C) as expected by imshow
        ax.imshow(img.permute(1, 2, 0), cmap='Greys')
    vae.train(was_training)
    return fig, axs

def training_vae(n_epochs, optimizer, model, loss_fn, train_loader, kld_weight=True, device=None, mse_threshold=0.08,
                w_min=10e-9, w_max=10e-3):
    """
    Training loop for a VAE
    """
    if device is None:
        device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
    model.train()
    
    # Weighted version of the loss
    if kld_weight:
        w = w_min
    else:
        w = 1
    
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        loss_train_mse = 0.
        loss_train_kld = 0.

        for imgs in train_loader:

            imgs = imgs.to(device=device) 

            outputs = model(imgs)
            # Final loss is the sum of the 2 terms (with potentially a weight term)
            mse_loss, kld_loss = loss_fn(outputs, imgs, model.mu, model.logvar)
            
            # Update weight on KLD loss
            with torch.no_grad():
                if kld_weight:
                    if mse_loss < mse_threshold:
                        w = min(w*2, w_max)
                    else:
                        w = max(w/2, w_min)
                
            loss = mse_loss + w*kld_loss
        
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_train += loss.item()
            loss_train_mse += mse_loss.item()
            loss_train_kld += kld_loss.item()

        
        if epoch == 1 or epoch % 5 == 0:
            print('{}  |  Epoch {}  |  Training loss {:.3f}  |  MSE loss {:.3f}  |  KLD loss {:.5}  | KLD weight {:.8f}'.format(
                datetime.now().strftime("%H:%M:%S"), 
                epoch,
                loss_train / len(train_loader),
                loss_train_mse / len(train_loader),
                loss_train_kld / len(train_loader),
                w,
            ))


In [None]:
# Load the data

# Training a VAE is hard enough already: a subset of MNIST (labels 0 to 3) is more than enough.
labels_kept = list(range(4))
n_labels = len(labels_kept)
data_train, data_val, data_test = load_MNIST(label_list=labels_kept)
# Each item unpacks into two elements; only the first (the image) is kept for unsupervised training.
imgs_train = [image for image, _label in data_train]
imgs_val = [image for image, _label in data_val]

# 3. Variational AutoEncoder

*related videos from the curriculum*

- [Lecture 13 | Generative Models](https://www.youtube.com/watch?v=5WoItGTWV54&list=PL3FW7Lu3i5JvHM8ljYj-zLfQRF3EO8sYv&index=17) 
  - from 27:05 to 31:05: Introducing VAE
  - *Let's forget about tractability :) *
  - from 40:55 to 44:00: VAE loss and VAE training
  - from 44:00 to 49:00: Generating data using VAE and summary 

**Introduction to VAE**

A Variational AutoEncoder (VAE) is a neural network that is similar to an AE in its structure, as it is composed of 2 sub-networks: an Encoder and a Decoder. However, our main objective is no longer to efficiently represent some data lying on a non-linear manifold. Instead, we want to generate new data that looks like the training data but is not just a copy of a training input! We want **new** data.

To do so, we want our model to first get a good representation of what real data look like and then to be able to generate new plausible instances. The *get a good representation* part seems similar to what an Encoder can do, and the *generate plausible instances* part similar to what a Decoder can do. However, a Decoder alone cannot really *generate new* data; it can only *reconstruct*. The *generate* part is actually what makes a VAE different from an AE, and it is achieved in 2 steps:

- a reparameterization step in the forward pass between the Encoder and Decoder
- a KL-divergence term added to the loss function 

These 2 steps aim at forcing the elements of latent space to look like normally distributed samples. Once the training is finished, this forcing will allow us to generate new data by simply giving a random normally distributed sample to the Decoder (and the Encoder can be thrown away, so this is the opposite of AE where we could throw away the Decoder and keep the Encoder).

**Reparameterization**

The reparameterization step consists of defining the latent vector ``z`` not as the output of the Encoder but as a random sample from $\mathcal{N}$(``mu``, ``std``), where ``mu`` and ``std`` are the actual outputs of the Encoder (for computational reasons the Encoder actually returns ``mu`` and ``logvar (=log(std**2))``).

**KL-divergence**

The second difference between a VAE and an AE is in the loss function. In addition to the reconstruction term we want to force the Encoder to learn the parameters of a normal distribution. There exists a measure for that, the Kullback–Leibler divergence or simply KL-divergence ([Wikipedia](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence), [PyTorch](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html?highlight=kld#torch.nn.KLDivLoss)). The KL-divergence measures how one probability distribution $P$ is different from a second $Q$ by computing the following:

$$D_{KL}(P||Q) = \int_{-\infty}^{+\infty} p(x)\log\Big(\frac{p(x)}{q(x)}\Big) \,dx$$

Where $p$ and $q$ are the probability densities of the probability distributions $P$ and $Q$. This measure can be used for different purposes and has as many interpretations (see [Interpretations](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Interpretations) section from the Wikipedia page), but in the context of Variational AutoEncoders we will interpret $Q$ as our prior and $P$ as our *true* distribution, and $D_{KL}(P||Q)$ can then be interpreted as the information lost when our prior $Q$ is used to approximate $P$. In our very specific case where $Q \sim \mathcal{N}(0, 1)$ and $P \sim \mathcal{N}$(``mu``, ``std``), $D_{KL}(P||Q)$ can be formulated as follows (see [Multivariate normal distributions](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Multivariate_normal_distributions) section from the Wikipedia page):

$$D_{KL}(P||Q) = \frac{1}{2} \sum_{i=0}^{z_{dim}-1} \Big( \sigma_i^2 + \mu_i^2 -1 -log(\sigma_i^2) \Big) \qquad \text{with std} = \begin{bmatrix}\sigma_0 \cdots \sigma_{z_{dim}-1} \end{bmatrix} \quad \text{and mu} = \begin{bmatrix}\mu_0 \cdots \mu_{z_{dim}-1} \end{bmatrix}  $$


--------------------
## TODO

### MyVAE class: A variational AutoEncoder

Complete the ``MyVAE`` class below (that is the variational counterpart of the ``MyAE`` class implemented in section 2). You don't have to start from scratch, you can re-use the ``MyEncoder`` and the ``MyDecoder`` classes from section 2. However a few details must be adapted:

1. Encoder must return 2 tensors of shape ``(N, z_dim)``: one for ``mu`` and one for ``logvar (=log(var)=log(std**2))`` **or equivalently** use ``MyEncoder`` with ``z_dim = 2*z_dim`` and then define ``mu`` and ``logvar`` in the forward method as follows: 

  ```
  self.mu_logvar = self.encoder(x)               # Output of the encoder, shape=(N, 2*z_dim)
  self.mu = self.mu_logvar[:,:self.z_dim]        # mu                   , shape=(N, z_dim)   (1st half of the encoded vector)
  self.logvar = self.mu_logvar[:,self.z_dim:]    # logvar               , shape=(N, z_dim)   (2nd half of the encoded vector)
  ```

2. On top of returning the reconstructed image, the forward pass of your VAE must store ``mu`` and ``logvar`` (as suggested in the lines above) because we'll need them when computing the KL-divergence term of the loss function.

3. A ``reparameterize`` method must be added that draws a sample $z$ (of shape ``(N, z_dim)``) from $\mathcal{N}(0,1)$ and returns $z \times std + mu$, so that it corresponds to a sample from $\mathcal{N}$(``mu``, ``std``) (Reminder: ``logvar = log(std**2)``, so ``std = exp(logvar/2)``). **Hint** you can use [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html?highlight=randn#torch.randn) or [torch.randn_like](https://pytorch.org/docs/stable/generated/torch.randn_like.html?highlight=randn#torch.randn_like) or [torch.nn.init.normal_](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.normal_). This method is called between the Encoder and the Decoder during training, and can be thrown away once training is complete.

4. Decoder is exactly same as for an AutoEncoder

5. Write a ``generate`` method that takes as parameter an integer ``N_imgs`` defining the number of images to generate and returns ``imgs_generated`` (shape ``(N_imgs, C_in, H_in, W_in)``), the images generated by the VAE. **Hint** you can use [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html?highlight=randn#torch.randn)

### loss_VAE: A loss adapted to VAE

Complete the ``loss_VAE`` function below where 
- ``inputs`` is the original images (shape ``(N, C_in, H_in, W_in)``)
- ``outputs`` is the reconstructed images (shape ``(N, C_in, H_in, W_in)``), 
- ``mu`` and ``logvar (=log(std**2))`` are the outputs of the Encoder representing the parameters of our normal distribution $P$ (both of shape ``(N, z_dim)``).

It returns the 2 terms of the VAE loss function:
- ``mse_loss``, the reconstruction term. **Hint** you can use [F.mse_loss](https://pytorch.org/docs/stable/nn.functional.html?highlight=mse_loss#torch.nn.functional.mse_loss) with ``reduction="mean"``
- ``kld_loss``, the KL-divergence term defined in the cell above. To adapt the formula to batch computations, we need to re-write it as follows:

$$D_{KL}(P||Q) =\text{mean}\Big( \frac{1}{2} \sum_{i=0}^{z_{dim}-1} \Big( \sigma_{:,i}^2 + \mu_{:,i}^2 -1 -log(\sigma_{:,i}^2) \Big) \Big) \qquad \text{with std} = \begin{bmatrix}\sigma_{:,0} \cdots \sigma_{:, z_{dim}-1} \end{bmatrix} \quad \text{and mu} = \begin{bmatrix}\mu_{:,0} \cdots \mu_{:, z_{dim}-1} \end{bmatrix}  $$




In [None]:
class MyVAE(nn.Module):
    """
    Variational AutoEncoder built from ``MyEncoder`` and ``MyDecoder``.

    The encoder produces a vector of size ``2*z_dim``: its first half is used
    as ``mu`` and its second half as ``logvar = log(std**2)``. A latent sample
    ``z ~ N(mu, std)`` is drawn with the reparameterization trick and decoded
    into an image.

    After every forward pass the following attributes hold the values for the
    latest batch (they are read by the training loop for the KL term):
    ``mu_logvar`` (N, 2*z_dim), ``mu`` (N, z_dim), ``logvar`` (N, z_dim).
    """

    def __init__(self, z_dim):
        super().__init__()
        # Latent space dimension
        self.z_dim = z_dim
        # Encoder similar to what we used for the AE but used to encode both mu and logvar
        # NOTE(review): MyEncoder ends with a ReLU, which constrains mu >= 0 and
        # logvar >= 0 (std >= 1); a linear output layer would be less restrictive.
        self.encoder = MyEncoder(z_dim=2 * z_dim)
        # There is no difference between a VAE and AE decoder
        self.decoder = MyDecoder(z_dim=z_dim)
        # Filled in by forward()
        self.mu_logvar = None
        self.mu = None
        self.logvar = None

    def reparameterize(self):
        """
        Reparameterization: draw a sample z of shape (N, z_dim) (z~N(mu, logvar))

        mu and logvar should be accessible via self.mu and self.logvar
        """
        # std = exp(logvar / 2) since logvar = log(std**2)
        std = torch.exp(0.5 * self.logvar)
        # Sample from N(0, 1) with the same shape/device/dtype as std
        eps = torch.randn_like(std)
        # Shift and scale so that z is a sample drawn from N(mu, std)
        z = self.mu + eps * std
        return z

    def forward(self, x):
        # Encode the data and output the estimated parameters mu and logvar
        self.mu_logvar = self.encoder(x)
        # mu is the first half of the encoder output
        self.mu = self.mu_logvar[:, :self.z_dim]
        # logvar is the second half of the encoder output
        self.logvar = self.mu_logvar[:, self.z_dim:]

        # Reparameterization: draw a sample z of shape (N, z_dim) (z~N(mu, logvar))
        z = self.reparameterize()

        # Generate (decode) an image from the sample z
        out = self.decoder(z)
        return out

    def generate(self, N_imgs):
        """
        Generate new images by giving sampled latent vectors from N(0,1) to the decoder

        Returns a tensor of N_imgs decoded images, computed without gradient tracking.
        """
        # Sample on the same device as the decoder's weights
        decoder_device = next(self.decoder.parameters()).device
        z = torch.randn(N_imgs, self.z_dim, device=decoder_device)
        with torch.no_grad():
            imgs_generated = self.decoder(z)
        return imgs_generated

    def generate_images(self, N_imgs):
        """Alias of ``generate`` (the name used in the assignment instructions)."""
        return self.generate(N_imgs)

def loss_VAE(inputs, outputs, mu, logvar):
    """
    Loss for a VAE: a reconstruction term (mse loss) and a distribution term (kl divergence)

    Parameters
    ----------
    inputs : torch.Tensor
        Original images, shape (N, C_in, H_in, W_in).
    outputs : torch.Tensor
        Reconstructed images, same shape as ``inputs``.
    mu, logvar : torch.Tensor
        Encoder outputs, shape (N, z_dim); ``logvar = log(std**2)``.

    Returns
    -------
    (mse_loss, kld_loss) : tuple of scalar tensors
        ``mse_loss`` is the mean squared error over all elements; ``kld_loss``
        is KL(N(mu, std) || N(0, 1)) summed over latent dimensions and
        averaged over the batch.
    """
    # Regular reconstruction term using the mse loss, same as what we used for the AutoEncoder
    mse_loss = F.mse_loss(outputs, inputs, reduction="mean")
    # Distribution term: force the latent space to behave like a normal distribution
    # Special case of the KL divergence when the prior is Q~N(0,1) and P~N(mu, std):
    # 0.5 * sum_i (std_i^2 + mu_i^2 - 1 - log(std_i^2)), with std^2 = exp(logvar)
    kld_per_sample = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar, dim=1)
    kld_loss = kld_per_sample.mean()
    return mse_loss, kld_loss
                

### Training your Variational AutoEncoder

Run the cell below to train your VAE.
Keep in mind that training unsupervised models is not easy and that it is okay in this assignment if you don't get good results at all.

In [None]:
z_dim = 15

vae = MyVAE(z_dim=z_dim)
vae.to(device=device)

train_loader_imgs = DataLoader(imgs_train, batch_size=512, shuffle=True)
# No need to shuffle validation data: its order does not affect evaluation
val_loader_imgs = DataLoader(imgs_val, batch_size=512, shuffle=False)

lr = 0.0001

optimizer = optim.Adam(vae.parameters(), lr=lr)
loss_fn = loss_VAE

# First run: unweighted KL-divergence term (w = 1 for the whole training)
training_vae(
    n_epochs = 300,
    optimizer = optimizer,
    model = vae,
    loss_fn = loss_fn,
    train_loader = train_loader_imgs,
    kld_weight = False,
    device=device,
)

In [None]:
# Plotting helpers work on CPU tensors, so bring the model back to the CPU first
vae.to(device=torch.device('cpu'))
for _title, _imgs in (("Training dataset", imgs_train), ("Validation dataset", imgs_val)):
    fig, axs = plot_true_VS_reconstructed(vae, _imgs)
    fig.suptitle(_title, fontsize=15)
    plt.show()
# Images decoded from random latent vectors
fig, axs = plot_generated_images(vae)
plt.show()

### Example of VAE results on the training dataset
![Example of VAE reconstruction results on the training dataset without the KLD weight term (see VAE_train_reconstruction image)](VAE_train_reconstruction.png)

### Example of VAE results on the validation dataset

![Example of VAE reconstruction results on the validation dataset without the KLD weight term (see VAE_val_reconstruction image)](VAE_val_reconstruction.png)

--------------------
## TODO
Analyse the results in the image above.

1. Comment the behavior of the VAE when reconstructing different digits.
2. Mode collapse is a very common problem when training VAEs (and other unsupervised methods). Can you explain what this problem is?

### Example of VAE generation
![Example of VAE generation results without the KLD weight term (see VAE_generation image)](VAE_generation.png)

--------------------
## TODO
Analyse the results in the image above.

1. Do the generated images look similar to the outputs of the reconstructed images?
1. Would you say that the VAE learnt the parameters of a normal distribution well?

### Training your Variational AutoEncoder: with a weighted KLD loss

In order to fix the 'mode collapse' problem and to make sure our model learns both to reconstruct images and to organize the latent space, we will add a weight on the KL divergence term. We will first set this weight extremely low so that the model first learns how to reconstruct, and once the reconstruction is considered good enough we will progressively amplify this weight so that the model re-organizes the latent space properly.

Run the cell below to train your VAE.
Keep in mind that training unsupervised models is not easy and that it is okay in this assignment if you don't get good results at all.

In [None]:
# Fresh model, trained with the adaptive KLD weight schedule of training_vae
vae = MyVAE(z_dim=z_dim)
vae.to(device=device)
epochs = 300
lr = 0.0001
optimizer = optim.Adam(vae.parameters(), lr=lr)
kld_weight = True

training_vae(
    n_epochs = epochs,
    optimizer = optimizer,
    model = vae,
    loss_fn = loss_fn,
    train_loader = train_loader_imgs,
    kld_weight = kld_weight,
    # Batches must go to the device the model was moved to above
    device=device,
)

In [None]:
# Plotting helpers work on CPU tensors, so bring the model back to the CPU first
vae.to(device=torch.device('cpu'))
for _title, _imgs in (("Training dataset", imgs_train), ("Validation dataset", imgs_val)):
    fig, axs = plot_true_VS_reconstructed(vae, _imgs)
    fig.suptitle(_title, fontsize=15)
    plt.show()
# Images decoded from random latent vectors
fig, axs = plot_generated_images(vae)
plt.show()

### Example reconstruction
![Example of VAE reconstruction results on the training dataset with the KLD weight term (see VAE_weighted_train_reconstruction image)](VAE_weighted_train_reconstruction.png)

![Example of VAE reconstruction results on the validation dataset with the KLD weight term (see VAE_weighted_val_reconstruction image)](VAE_weighted_val_reconstruction.png)

--------------------
## TODO
Analyse the results in the image above.

1. Comment the behavior of the VAE when reconstructing different digits.

### Example generation
![Example of VAE generation results with the KLD weight term (see VAE_weighted_generation image)](VAE_weighted_generation.png)

--------------------
## TODO
Analyse the results in the image above.

1. Do the generated images look similar to the outputs of the reconstructed images?
1. Would you say that the VAE learnt the parameters of a normal distribution well?