# Assignment 3: Variational Autoencoders

*Author:* Thomas Adler / Eric Volkmann

*Copyright statement:* This material, no matter whether in printed or electronic form, may be used for personal and non-commercial educational use only. Any reproduction of this manuscript, no matter whether as a whole or in parts, no matter whether in printed or in electronic form, requires explicit prior acceptance of the authors.

## Theory

### Exercise 1: Derivation of ELBO (1 pts)

Prove the evidence lower bound (ELBO), which states that 
\begin{align*}
    \mathbb E_{q(z \mid x)}[\log \frac{p(x, z)}{q(z \mid x)}] \leq \log p(x),
\end{align*}
for some distributions $q(z \mid x), p(x, z), p(x)$. 
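
Before proving the bound, it can help to see it hold numerically. The following is a minimal sketch on a hypothetical discrete toy model (3 values for $x$, 4 values for $z$, random probability tables); none of these choices come from the assignment, and a numerical check is of course no substitute for the proof.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: x takes 3 values, z takes 4 values.
joint = rng.random((3, 4))
joint /= joint.sum()              # p(x, z), all entries sum to 1
evidence = joint.sum(axis=1)      # p(x) = sum_z p(x, z)

# Arbitrary variational distribution q(z | x): each row sums to 1.
q = rng.random((3, 4))
q /= q.sum(axis=1, keepdims=True)

# ELBO(x) = E_{q(z|x)}[log p(x, z) - log q(z | x)]
elbo = (q * (np.log(joint) - np.log(q))).sum(axis=1)
gap = np.log(evidence) - elbo     # should be non-negative for every x

# With q equal to the true posterior p(z | x), the bound becomes tight.
posterior = joint / evidence[:, None]
elbo_tight = (posterior * (np.log(joint) - np.log(posterior))).sum(axis=1)

print("log p(x) - ELBO:", gap)
```

The gap shrinks to zero when $q(z \mid x)$ matches the true posterior, which hints at what the missing term in the inequality measures.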

########## YOUR SOLUTION HERE ##########

### Exercise 2: Decomposition of ELBO (1 pts)

We have the observable variable $x$ and the latent variable $z$. 
We can observe $x$ only through a dataset $\mathcal D$. 
For $z$, we choose the prior $p(z) = \mathcal N(0, 1)$ and the posterior (or encoder) distribution $q(z \mid x) = \mathcal N(\mu(x), \sigma^2(x))$, where $\mu(x)$ and $\sigma^2(x)$ are neural networks with shared parameters, i.e., $(\mu(x), \sigma^2(x)) = \text{encode}(x)$. 
This makes $q(z \mid x)$ easy to sample from. 
Subsequently, we can use the decoder $p(x \mid z)$ to reconstruct $x$. 

There are several interpretations as to why the maximization of the ELBO is a suitable objective for VAEs. 
To obtain one of them, prove the following identity 
\begin{align*}
    \mathbb E_{q(z \mid x)}[\log \frac{p(x, z)}{q(z \mid x)}] = \mathbb E_{q(z \mid x)}[\log p(x \mid z)] - D_{\text{KL}}(q(z \mid x) \mathbin{||} p(z))
\end{align*}
where $D_{\text{KL}}$ denotes the Kullback-Leibler divergence. 
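
The identity can also be checked numerically before proving it. This sketch again uses a hypothetical discrete toy model with random tables for $p(z)$, $p(x \mid z)$ and $q(z \mid x)$; the sizes and the helper `normalize` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

def normalize(a, axis):
    """Scale a non-negative array so that it sums to 1 along `axis`."""
    return a / a.sum(axis=axis, keepdims=True)

prior = normalize(rng.random(4), axis=0)             # p(z)
likelihood = normalize(rng.random((4, 3)), axis=1)   # p(x | z), rows indexed by z
q = normalize(rng.random((3, 4)), axis=1)            # q(z | x), rows indexed by x

joint = (prior[:, None] * likelihood).T              # p(x, z), shape (3, 4)

# Left-hand side: the ELBO for every value of x.
elbo = (q * (np.log(joint) - np.log(q))).sum(axis=1)

# Right-hand side: reconstruction term minus KL divergence to the prior.
reconstruction = (q * np.log(likelihood.T)).sum(axis=1)
kl = (q * (np.log(q) - np.log(prior))).sum(axis=1)

print(np.max(np.abs(elbo - (reconstruction - kl))))
```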

########## YOUR SOLUTION HERE ##########

### Exercise 3: Reconstruction Term (1 pts)

The first term on the right-hand side of the above equation is called the *reconstruction term*. 
To see why, analyze it under the expectation over $x$. 
To be clearer, let us denote the reconstruction of $x$ by the VAE as $\tilde x$. 
Prove the identity
\begin{align*}
    \mathbb{E}_{p(x)} [\mathbb{E}_{q(z \mid x)}[\log p(\tilde x \mid z)]] = - \mathcal H(p(x \mid z)) - \mathbb E_{p(z)}[D_{\text{KL}}(p(x \mid z) \mathbin{||} p(\tilde x \mid z))].
\end{align*}
Which effects does the maximization of the reconstruction term have on the different parts of the VAE?
That is, interpret the two terms on the right-hand side in the context of VAE training. 

########## YOUR SOLUTION HERE ##########

### Exercise 4: Regularization Term (1 pts)

The second term of the decomposition of the ELBO is a regularization term. 
Again, we analyze it under the expectation over $x$. 
Prove that 
\begin{align*}
    \mathbb E_{p(x)}[D_{\text{KL}}(q(z \mid x) \mathbin{||} p(z))] = I(x, z),
\end{align*}
where $I(\cdot, \cdot)$ denotes mutual information. 
For each side of this identity, give a different interpretation of the imposed regularization. 
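
A small numerical check may help build intuition. The sketch below uses a hypothetical discrete toy model and, as an explicit assumption of the sketch, takes $p(z)$ to be the marginal of $q(z \mid x)$ under $p(x)$. Whether and why that assumption is needed is part of what your proof should make clear.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical toy model: 3 values for x, 4 values for z.
p_x = rng.random(3)
p_x /= p_x.sum()                              # p(x)
q = rng.random((3, 4))
q /= q.sum(axis=1, keepdims=True)             # q(z | x), rows indexed by x

joint = p_x[:, None] * q                      # p(x) q(z | x)
marginal_z = joint.sum(axis=0)                # assumption: p(z) is this marginal

# Left-hand side: expected KL divergence between q(z | x) and p(z).
kl_per_x = (q * (np.log(q) - np.log(marginal_z))).sum(axis=1)
expected_kl = (p_x * kl_per_x).sum()

# Right-hand side: mutual information between x and z under the joint.
independent = p_x[:, None] * marginal_z[None, :]
mutual_info = (joint * (np.log(joint) - np.log(independent))).sum()

print(expected_kl, mutual_info)
```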

########## YOUR SOLUTION HERE ##########

### Exercise 5: Derivation of Decoder Loss Function (1 pts)

Our considerations of the ELBO so far were somewhat abstract. 
In this exercise we will derive a concrete loss function that is ready for implementation. 
To this end, revisit the decomposition of the ELBO from Exercise 2. 
We will use a softmax function on the decoder output to parameterize $p(x \mid z)$ as a categorical distribution with $k$ categories. 
We denote the softmax-activated decoder output as $\sigma(z)_i, i \in \{1, \dots, k\}$. 
Prove that under these assumptions 
\begin{align*}
    \log p(x \mid z) = x \log \sigma(z)_x.
\end{align*}
How do we obtain the expectation under $q(z \mid x)$?
Note that in practice $x$ will be a vector and we will use 
\begin{align*}
    \log p(x_1, \dots, x_d \mid z) = \sum_{j=1}^d x_j \log \sigma(z)_{x_j}.
\end{align*}
Which assumption is implied in this identity? 
Is it justified?

Since we maximize the ELBO, our loss function will be 
\begin{align*}
    \mathcal L_{\text{decoder}} = -\sum_{j=1}^d x_j \log \sigma(z)_{x_j}.
\end{align*}
Thus, we can just use the familiar cross-entropy loss function. 

*Note that we used $\sigma^2(x)$ as one of the outputs of the encoder. Here, $\sigma(z)$ denotes the softmax-activated decoder output. So you should read the square more as a part of the name than as an operation.*
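
To connect the loss above to PyTorch, the following sketch compares a hand-written negative log-likelihood with `F.cross_entropy`. It assumes the usual one-hot reading of the categorical log-likelihood, i.e. the loss is the summed negative log-probability of the observed categories. The sizes `k` and `d`, the random `logits`, and the `targets` are hypothetical stand-ins, not part of the assignment.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

k, d = 5, 3                          # hypothetical: k categories, d components of x
logits = torch.randn(d, k)           # stand-in for the raw (pre-softmax) decoder output
targets = torch.tensor([1, 0, 4])    # stand-in for the observed category of each component

# Manual version: pick the log-probability of the observed category and sum over components.
log_probs = torch.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(d), targets].sum()

# Built-in version: cross_entropy applies log-softmax internally.
builtin = F.cross_entropy(logits, targets, reduction="sum")

print(manual.item(), builtin.item())
```

Note that `F.cross_entropy` expects raw logits. If the network already ends in an activation, a loss that takes probabilities or log-probabilities is needed instead.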

########## YOUR SOLUTION HERE ##########

### Exercise 6: Derivation of Encoder Loss Function (1 pts)

Prove that the encoder loss function derived from the ELBO is 
\begin{align*}
    \mathcal L_{\text{encoder}} = D_{\text{KL}}(q(z \mid x) \mathbin{||} p(z)) = \frac12 (\mu^2(x) + \sigma^2(x) - \log \sigma^2(x) - 1).
\end{align*}
Interpret this result. 
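
As a sanity check for your derivation, the closed form can be compared against a Monte Carlo estimate of the KL divergence. This is a sketch under assumed values: `mu` and `var` are hypothetical encoder outputs for a single input, and the sample size is arbitrary.

```python
import torch

torch.manual_seed(0)

mu, var = torch.tensor(0.7), torch.tensor(0.3)   # hypothetical encoder outputs

# Closed form from the exercise statement.
closed_form = 0.5 * (mu**2 + var - torch.log(var) - 1)

# Monte Carlo estimate of E_q[log q(z | x) - log p(z)].
q = torch.distributions.Normal(mu, var.sqrt())   # Normal takes the standard deviation
p = torch.distributions.Normal(0.0, 1.0)
z = q.sample((200_000,))
monte_carlo = (q.log_prob(z) - p.log_prob(z)).mean()

print(closed_form.item(), monte_carlo.item())
```

The two numbers should agree up to sampling noise. Note that `torch.distributions.Normal` is parameterized by the standard deviation, not the variance.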

########## YOUR SOLUTION HERE ##########

### Exercise 7: Reparametrization Trick (1 pts)

We are almost ready to implement our VAE. 
There remains only one small problem to solve. 
That is, passing down gradients from the decoder to the encoder is not possible due to the sampling step between them, i.e., drawing $z \sim \mathcal N(\mu(x), \sigma^2(x))$. 
This sampling step introduces a discontinuity which we cannot differentiate. 
Luckily, there is a simple solution for that. 
We sample a different random variable $\varepsilon \sim \mathcal N(0, 1)$ and define $z = g(\varepsilon, \mu(x), \sigma^2(x))$ via a deterministic function $g$. 
This is known as the *reparameterization trick*. 
What form must $g$ have? 
How does it resolve the discontinuity problem? 
Argue why $g(\varepsilon, \mu(x), \sigma^2(x))$ and $\varepsilon$ follow the same type of distribution and differ only in their moments. 

########## YOUR SOLUTION HERE ##########

## Implementation and Training

### Exercise 8: VAE Training (1 pts)

Below you find a basic autoencoder architecture and an MNIST dataloader. 
Implement the forward passes and a training loop that features the reparametrization trick and maximizes the ELBO as derived in the previous exercises. 
Visualize the training progress. 

You can then try to modify the architecture (e.g. CNN-based architecture, different activation functions, residual connections, ...) or the training loop (weighting of loss terms, weight decay, gradient clipping, ...) to improve the results.

If you want to try a latent dimension > 1, you can use the fact that, for a multivariate Gaussian with diagonal covariance, the KL divergence has this simple form:

$D_{\text{KL}}(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \mathbin{||} \mathcal{N}(0, I)) = \frac{1}{2} \sum_{j=1}^{latent\_dim} (\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2)$

However, in modern PyTorch it is recommended to use `torch.distributions.kl.kl_divergence()`; see https://pytorch.org/docs/stable/distributions.html for details.

You can instantiate a distribution from `torch.distributions` and call `.rsample()` on it, which is PyTorch's built-in implementation of the reparametrization trick.

*Hint: You can use these features to also work with Gaussians that have a more complicated covariance matrix than a diagonal one.*
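
The two PyTorch features mentioned above can be sketched as follows. The batch size, latent dimension, and the tensors `mu` and `var` are hypothetical stand-ins for encoder outputs; treat this as an illustration of the API rather than a reference solution.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs for a batch of 4 inputs and latent_dim = 3.
mu = torch.randn(4, 3, requires_grad=True)
var = torch.rand(4, 3) + 0.1                     # strictly positive variances

# Independent(..., 1) treats the last dimension as one multivariate event,
# which yields a Gaussian with diagonal covariance.
q = Independent(Normal(mu, var.sqrt()), 1)       # q(z | x)
p = Independent(Normal(torch.zeros(4, 3), torch.ones(4, 3)), 1)  # prior N(0, I)

kl_builtin = kl_divergence(q, p)                 # shape (4,): one value per input
kl_manual = 0.5 * (var + mu**2 - 1 - torch.log(var)).sum(dim=-1)

# rsample() draws a sample through which gradients can flow back to mu.
z = q.rsample()
z.sum().backward()

print(kl_builtin.detach(), kl_manual.detach())
print(mu.grad)
```

In contrast to `.rsample()`, `.sample()` returns a tensor that is detached from the computation graph. For a full covariance matrix, `torch.distributions.MultivariateNormal` accepts a `covariance_matrix` or a `scale_tril` argument and also supports `.rsample()`.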

In [1]:
from IPython.display import Image
Image(url= "https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Fmlarchive.com%2Fwp-content%2Fuploads%2F2022%2F09%2FNew-Project-3.png&f=1&nofb=1&ipt=c92212efee02295a5612e1ef0639d4ceb260a7d293579a10afdacb4cd55e27ef&ipo=images")

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from itertools import chain
from tqdm import tqdm

In [2]:
batch_size = 128

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [3]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, eta):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim // 4, hidden_dim // 8),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim // 8, 2 * latent_dim),  # factor 2: one mean and one variance per latent dimension
        )
        self.eta = eta  # Hint: add eta to the variance predictions for numerical stability
        self.latent_dim = latent_dim
        self.softplus = nn.Softplus()

    def forward(self, x):
        ########## YOUR SOLUTION HERE ##########
        


class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 8),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim // 8, hidden_dim // 4),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim // 4, hidden_dim // 2),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.SiLU(),  # Swish activation function
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, z):
        ########## YOUR SOLUTION HERE ##########

In [4]:
torch.random.manual_seed(420)

learning_rate = 1e-3
weight_decay = 1e-2
num_epochs = 10
eta = 1e-9

optimizer_class = torch.optim.AdamW #You can try out different optimizers

In [1]:
########## YOUR SOLUTION HERE ##########

### Exercise 9: VAE Inference (1 pts)

- Compute the test loss.
- Visualize some reconstruction examples from both training and test set. 
- Generate new samples: Since the latent variables $z$ have the simple distribution $\mathcal N(0, I)$, we can easily generate new samples from the training distribution by sampling $z$ and feeding it to the decoder.
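
The sampling step in the last bullet could look roughly like the sketch below. `toy_decoder` is an untrained, hypothetical stand-in so that the snippet runs on its own; in your solution you would use your trained decoder and your own latent dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_dim, input_dim = 2, 784    # hypothetical sizes for flattened 28x28 images

# Untrained stand-in for the decoder; replace it with your trained model.
toy_decoder = nn.Sequential(nn.Linear(latent_dim, input_dim), nn.Sigmoid())
toy_decoder.eval()

with torch.no_grad():                                # no gradients needed for generation
    z = torch.randn(16, latent_dim)                  # z ~ N(0, I)
    samples = toy_decoder(z).reshape(-1, 28, 28)     # pixel intensities in [0, 1]

print(samples.shape)
```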

In [23]:
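def imshow(ax, img):
    # ToTensor() and the decoder's final Sigmoid both yield values in [0, 1], so no rescaling is needed
    ax.imshow(img.detach().cpu().reshape(28, 28).numpy(), cmap="gray", vmin=0, vmax=1)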

In [24]:
# compute test error
encoder.train(False)
decoder.train(False)
running_loss = 0
running_recon = 0
running_kld = 0

In [2]:
########## YOUR SOLUTION HERE ##########

### Bonus Exercise: Visualization (1 pts)

Try to visualize the distribution of the encoded latents of the test set. One possibility is to use a scatter plot and color the points according to their label.

Next, try to interpolate in the latent space. Can you generate smooth transitions between the digits? 
Discuss your results! 
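
One simple way to interpolate is to move along a straight line between two latent codes and decode every intermediate point. The sketch below assumes a 2-dimensional latent space, two hand-picked endpoints, and an untrained hypothetical `toy_decoder`; with your trained model you would instead use the encoded means of two test images as endpoints.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_dim, input_dim = 2, 784    # hypothetical sizes for flattened 28x28 images

# Untrained stand-in for the decoder; replace it with your trained model.
toy_decoder = nn.Sequential(nn.Linear(latent_dim, input_dim), nn.Sigmoid()).eval()

z_start = torch.tensor([-2.0, 0.0])                    # hypothetical latent code of image A
z_end = torch.tensor([2.0, 1.0])                       # hypothetical latent code of image B

weights = torch.linspace(0, 1, steps=8).unsqueeze(1)   # shape (8, 1)
path = (1 - weights) * z_start + weights * z_end       # shape (8, 2), straight line in latent space

with torch.no_grad():
    frames = toy_decoder(path).reshape(-1, 28, 28)     # one decoded image per step

print(frames.shape)
```

Plotting `frames` side by side with `imshow` then shows how the decoded image changes along the path.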

In [None]:
########## YOUR SOLUTION HERE ##########