# Introduction
<hr style="border:2px solid black"> </hr>

<div class="alert alert-warning">
<font color=black>

**What?** Variational autoencoder on CIFAR-10

</font>
</div>

# Import modules
<hr style="border:2px solid black"> </hr>

<div class="alert alert-info">
<font color=black>

- Note on module installation (tested on macOS)
    - `pip install pytorch-lightning`
    - `pip install pytorch-lightning-bolts`
    - `pip install wandb`

</font>
</div>

In [2]:
import pytorch_lightning as pl
from torch import nn
import torch
from torch.nn import functional as F
from pl_bolts.models.autoencoders.components import (
    resnet18_decoder,
    resnet18_encoder,
)
from pl_bolts.datamodules import CIFAR10DataModule
from matplotlib.pyplot import imshow, figure
import numpy as np
from torchvision.utils import make_grid
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

# Load the dataset
<hr style="border:2px solid black"> </hr>

In [None]:
datamodule = CIFAR10DataModule('.')

# VAE Loss
<hr style="border:2px solid black"> </hr>

<div class="alert alert-info">
<font color=black>

- The VAE is trained by minimizing the negative of a quantity called the ELBO (the implementation below stores this loss in a variable named `elbo`).
- **ELBO** stands for Evidence Lower Bound.
- The loss is made of two parts:

$$
-ELBO = \color{red}{KL} - \color{blue}{reconstruction} \\ 
= \color{red}{\mathbb{E}_{q}[ \log q(z|x) - \log p(z)]} - \color{blue}{\mathbb{E}_{q}[\log p(x|z)]}
$$

</font>
</div>
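
As a reminder of where the name comes from, the standard identity below (not part of the original notebook) shows why this quantity is a lower bound on the log evidence. The last KL term compares the approximate posterior with the true posterior $p(z|x)$ and is never negative, so dropping it can only make the right-hand side smaller:

$$
\log p(x) = \underbrace{\mathbb{E}_{q}[\log p(x|z)] - KL\big(q(z|x)\,\|\,p(z)\big)}_{ELBO} + KL\big(q(z|x)\,\|\,p(z|x)\big) \ge ELBO
$$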

## ELBO and Monte Carlo Sampling

<div class="alert alert-info">
<font color=black>

- Most tutorials assume that the prior `p(z)` is Gaussian and that the approximate posterior `q(z|x)` is Gaussian too. This comes in handy because the KL divergence then has a closed-form solution. So most tutorials end up with a KL term that looks like:

$$
-0.5 \sum{\left(1 + \log{(\sigma^2)} - \mu^2 - e^{\log{(\sigma^2)}}\right)}
$$


- But if we want to retain the flexibility to modify distributions as needed, we then need another strategy. Since we don't know the KL between all possible pairs of distributions, we'll actually just use **Monte-Carlo sampling** for this.

</font>
</div>
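
To make the comparison concrete, here is a minimal sketch that evaluates the closed-form Gaussian KL and a Monte Carlo estimate of the same quantity side by side. The values of `mu` and `log_var` and the sample count are arbitrary choices made for illustration; they are not taken from the notebook.

```python
import torch

torch.manual_seed(0)

# Hypothetical posterior parameters for a 3-dimensional latent
mu = torch.tensor([0.5, -1.0, 2.0])
log_var = torch.tensor([0.0, -0.5, 1.0])
std = torch.exp(log_var / 2)

# Closed-form KL(q || p) for diagonal Gaussians with p = N(0, I)
kl_closed = -0.5 * torch.sum(1 + log_var - mu ** 2 - torch.exp(log_var))

# Monte Carlo estimate: average log q(z) - log p(z) over samples z ~ q
q = torch.distributions.Normal(mu, std)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
z = q.rsample((200_000,))                      # shape: (200000, 3)
kl_mc = (q.log_prob(z) - p.log_prob(z)).sum(-1).mean()

print("closed form:", kl_closed.item())
print("monte carlo:", kl_mc.item())
```

The two numbers should agree closely. The Monte Carlo version only needs `log_prob` and sampling, which is what makes it usable when `p` or `q` is swapped for a distribution without a known closed-form KL.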

## ELBO intuition

<div class="alert alert-info">
<font color=black>

- KL is used to measure how different two distributions are.
- You can also see it another way: the KL divergence measures how much information we lose when we choose an approximation, so we can use it as an objective function to pick the approximation that works best for the problem at hand.
- In the cells below we create two normal distributions and compute a single-sample estimate of the KL between them. Then we replace the second one with a distribution that is more similar to the first. Upon recomputing, the KL value is lower, which means the two distributions are becoming more similar.

</font>
</div>

In [41]:
# Create a normal distribution for p, the prior (mean 0, std 1)
p = torch.distributions.Normal(0, 1)
# Create a normal distribution for q, the approximate posterior (mean 2, std 4)
q = torch.distributions.Normal(2, 4)

# Let's draw a sample (z) from the q distribution
z = q.rsample()
print(z)

# Log probability of z under the p distribution
log_pz = p.log_prob(z)
# Log probability of z under the q distribution
log_qzx = q.log_prob(z)

print('log prob pz: ', log_pz, 'prob:', torch.exp(log_pz))
print('log prob qzx: ', log_qzx, 'prob:', torch.exp(log_qzx))

tensor(3.5350)
log prob pz:  tensor(-7.1670) prob: tensor(0.0008)
log prob qzx:  tensor(-2.3789) prob: tensor(0.0927)


In [42]:
kl_divergence = log_qzx - log_pz
kl_divergence

tensor(4.7881)

In [43]:
# Now, if we manually move q closer to p, we see that this distance shrinks.
q_new = torch.distributions.Normal(0.1, 1.1)

log_qzx_new = q_new.log_prob(z)
print('log prob qzx: ', log_qzx_new, 'prob:', torch.exp(log_qzx_new))

log prob qzx:  tensor(-5.8899) prob: tensor(0.0028)


In [44]:
kl_divergence_new = log_qzx_new - log_pz
kl_divergence_new

tensor(1.2771)
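
The values above come from a single sample, and the second estimate reuses a `z` drawn from the original `q`, so they are noisy. As a cross-check, the sketch below uses `torch.distributions.kl_divergence`, which provides an exact KL for pairs of Normal distributions, and compares it with a Monte Carlo estimate whose samples are drawn from `q_new` itself. The sample count is an arbitrary choice.

```python
import torch
from torch.distributions import Normal, kl_divergence

p = Normal(0.0, 1.0)        # prior
q = Normal(2.0, 4.0)        # original posterior (far from p)
q_new = Normal(0.1, 1.1)    # posterior moved closer to p

# Exact KL divergences (closed form for Normal/Normal pairs)
kl_far = kl_divergence(q, p)        # roughly 8.11
kl_near = kl_divergence(q_new, p)   # roughly 0.015
print("KL(q || p):    ", kl_far.item())
print("KL(q_new || p):", kl_near.item())

# Monte Carlo estimate of KL(q_new || p), sampling from q_new
torch.manual_seed(0)
z_new = q_new.rsample((100_000,))
kl_near_mc = (q_new.log_prob(z_new) - p.log_prob(z_new)).mean()
print("MC estimate:   ", kl_near_mc.item())
```

The ordering matches the single-sample result: moving `q` towards `p` lowers the divergence.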

<div class="alert alert-info">
<font color=black>

- This z has a single dimension, but in practice we care about n-dimensional zs. To handle this in the implementation, we simply sum the per-dimension terms over the last dimension.
- The **trick here** is that the dimensions are treated as independent univariate distributions (Normal in this case), so summing their log-probabilities is equivalent to evaluating the log-probability of a single n-dimensional distribution with a diagonal covariance (an n-dimensional Normal in this case).
- This generic form of the KL is called the **Monte Carlo approximation**: we sample z many times and average to estimate the KL divergence. (In practice, these estimates are good, and with a batch size of 128 or more the estimate is very accurate.)

</font>
</div>
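
The equivalence between summing univariate log-probabilities and using one n-dimensional Normal can be checked numerically. This is a small sketch with made-up values for `mu` and `std`; it assumes the n-dimensional Normal has a diagonal covariance, which is the case the sum corresponds to.

```python
import torch
from torch.distributions import Normal, MultivariateNormal

torch.manual_seed(0)

# Hypothetical parameters of a 3-dimensional diagonal Gaussian
mu = torch.tensor([0.0, 1.0, -2.0])
std = torch.tensor([1.0, 0.5, 2.0])

# One sample z with three independent coordinates
z = Normal(mu, std).rsample()

# Sum of the three univariate log-probabilities
log_prob_sum = Normal(mu, std).log_prob(z).sum(-1)

# Log-probability under a single 3-dimensional Normal with diagonal covariance
joint = MultivariateNormal(mu, covariance_matrix=torch.diag(std ** 2))
log_prob_joint = joint.log_prob(z)

print(log_prob_sum.item(), log_prob_joint.item())
```

Both numbers agree up to floating-point error, which is why the implementation can keep using `Normal` and call `.sum(-1)`.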

# Implementation
<hr style="border:2px solid black"> </hr>

In [None]:
class VAE(pl.LightningModule):
    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32):
        super().__init__()

        self.save_hyperparameters()

        # encoder, decoder
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )

        # distribution parameters
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # for the gaussian likelihood
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def gaussian_likelihood(self, mean, logscale, sample):
        scale = torch.exp(logscale)
        dist = torch.distributions.Normal(mean, scale)
        log_pxz = dist.log_prob(sample)
        return log_pxz.sum(dim=(1, 2, 3))

    def kl_divergence(self, z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
        p = torch.distributions.Normal(
            torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        # 2. get the probabilities from the equation
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        # kl
        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)
        return kl

    def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode x to get the mu and variance parameters
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)

        # sample z from q
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()

        # decoded
        x_hat = self.decoder(z)

        # reconstruction loss
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)

        # kl
        kl = self.kl_divergence(z, mu, std)

        # elbo
        elbo = (kl - recon_loss)
        elbo = elbo.mean()

        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean(),
            'reconstruction': recon_loss.mean(),  # same value as 'recon_loss', logged under both names
        })

        return elbo
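
To check the shapes in `training_step` without downloading CIFAR-10 or installing `pl_bolts`, the sketch below replays the same steps on random data. The small linear `encoder` and `decoder` are hypothetical stand-ins for the ResNet-18 components, and the batch and layer sizes are chosen only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, enc_out_dim, latent_dim = 8, 32, 4

# Hypothetical stand-ins for resnet18_encoder / resnet18_decoder
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, enc_out_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 3 * 32 * 32),
                        nn.Unflatten(1, (3, 32, 32)))
fc_mu = nn.Linear(enc_out_dim, latent_dim)
fc_var = nn.Linear(enc_out_dim, latent_dim)
log_scale = nn.Parameter(torch.zeros(1))

x = torch.randn(batch, 3, 32, 32)              # fake image batch

# Encode, then sample z with the reparameterization trick (rsample)
h = encoder(x)
mu, log_var = fc_mu(h), fc_var(h)
std = torch.exp(log_var / 2)
q = torch.distributions.Normal(mu, std)
z = q.rsample()                                # shape: (batch, latent_dim)

# Reconstruction term: Gaussian log-likelihood summed over C, H, W
x_hat = decoder(z)
recon = torch.distributions.Normal(x_hat, torch.exp(log_scale))
recon_loss = recon.log_prob(x).sum(dim=(1, 2, 3))   # shape: (batch,)

# Monte Carlo KL term summed over the latent dimension
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
kl = (q.log_prob(z) - p.log_prob(z)).sum(-1)        # shape: (batch,)

# Negative ELBO averaged over the batch, then backpropagate
loss = (kl - recon_loss).mean()
loss.backward()
print(z.shape, recon_loss.shape, kl.shape, loss.item())
```

Because `rsample` is used instead of `sample`, gradients flow from the loss back through `z` into `fc_mu` and `fc_var`.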

In [None]:
pl.seed_everything(1234)

vae = VAE()
trainer = pl.Trainer(gpus=0, max_epochs=30, progress_bar_refresh_rate=10)
trainer.fit(vae, datamodule)

# Post-processing
<hr style="border:2px solid black"> </hr>

In [None]:
figure(figsize=(8, 3), dpi=300)

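# Z COMES FROM NORMAL(0, 1)
# (mu and std only exist inside training_step, so build the prior from latent_dim)
num_preds = 16
latent_dim = vae.hparams.latent_dim
p = torch.distributions.Normal(torch.zeros(latent_dim), torch.ones(latent_dim))
z = p.rsample((num_preds,))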

# SAMPLE IMAGES
with torch.no_grad():
    pred = vae.decoder(z.to(vae.device)).cpu()

# UNDO DATA NORMALIZATION
normalize = cifar10_normalization()
mean, std = np.array(normalize.mean), np.array(normalize.std)
img = make_grid(pred).permute(1, 2, 0).numpy() * std + mean

# PLOT IMAGES
imshow(img);

# Clean-up the repository
<hr style="border:2px solid black"> </hr>

In [None]:
import os

# Remove the downloaded CIFAR-10 archive if it is present
try:
    os.remove("cifar-10-python.tar.gz")
except FileNotFoundError:
    pass

# Conclusions
<hr style="border:2px solid black"> </hr>

<div class="alert alert-danger">
<font color=black>

- Some things may still not be obvious from this explanation. First, each image ends up with its own q. The KL term pushes all the qs towards the same p (called the prior). But if all the qs collapse to p, the network can cheat by just mapping everything to zero, and the VAE will collapse.

- The reconstruction term forces each `q` to be unique and spread out so that the image can be reconstructed correctly. This keeps all the qs from collapsing onto each other.

- The two terms are traded off against each other to seek a balance. This is also why you may experience **instability** when training VAEs!

</font>
</div>

# References
<hr style="border:2px solid black"> </hr>

<div class="alert alert-warning">
<font color=black>

- https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed
- https://github.com/williamFalcon/pytorch-lightning-vae  
- https://colab.research.google.com/drive/1_yGmk8ahWhDs23U4mpplBFa-39fsEJoT?usp=sharing#scrollTo=EYDKIsTtk3hJ
- https://github.com/ethen8181/machine-learning/blob/master/model_selection/kl_divergence.ipynb 

</font>
</div>