# VAE Improvements

## How can we improve the variational family?

Last time we discussed the ordinary VAE. We saw that it is quite a simple model and somewhat naive in its assumptions about the variational family.

Specifically, it assumes that the posterior distribution of the latent variables is a unimodal factorized Gaussian, and it looks for decoder parameters $\theta$ that maximize a lower bound on the marginal likelihood under this restriction. This can lead to a worse latent representation and a worse generative model.

To address this limitation, today we will go through some improvements that make the approximate posterior more expressive.

In [None]:
## First, import some models and functions we will use
from models import Base, VAE, IWAE, VAE_with_flows, VAE_MCMC
from main import make_dataloaders, get_activations
from models.samplers import HMC

## Visualization
from plotting import plot_digit_samples, plot_posterior
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib widget

## Math processing
import numpy as np
import torch
import torch.nn
from torchvision.transforms import ToTensor
## Wrapper on top of PyTorch that simplifies working with neural nets
import pytorch_lightning as pl
## If there is a CUDA device on the machine, let's use it!
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

First, load the data.

As before, we will use the FashionMNIST dataset.

In [2]:
train_loader, val_loader = make_dataloaders(dataset='fashionmnist',
                                            gpus=1,
                                            batch_size=32,
                                            val_batch_size=64)

In [3]:
act_func = get_activations()

In [4]:
## Function to recover given images
def recover_image(model, pics, num_samples=50):
    with torch.no_grad():
        pics_rec = torch.sigmoid(model.step([pics.to(device), None])[1]).cpu().view(
                (num_samples, -1, 784)).mean(0).view((64, 28, 28))
    return pics_rec

## Let us fix the random sample, which we will use for image generation:
random_vector = torch.randn((64, 20)).to(device)

## And let us write a function to generate images:
def generate_image(model, random_vector):
    with torch.no_grad():
        generated = torch.sigmoid(model(random_vector)).cpu().view((64, 28, 28))
    return generated

def get_posterior_samples(model, pics, num_samples=50, hidden_dim=20):
    with torch.no_grad():
        z_samples = model.step([pics.to(device), None])[2].reshape(num_samples, -1, hidden_dim).mean(0).cpu()
    return z_samples

## Sample a batch from validation dataset
pics = None
labels = None
for b in val_loader:
    pics = b[0]
    labels = b[1]
    break
    
whole_dataset = val_loader.dataset.data[:1000] * 1. / torch.max(val_loader.dataset.data)
whole_labels = val_loader.dataset.targets[:1000]

Next, let us load pretrained models. We could train them within the notebook, but that takes some time, so we just load the checkpoints.

In [5]:
vae = VAE.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_None.ckpt', act_func=act_func['tanh'], num_samples=50, hidden_dim=20).to(device)
iwae = IWAE.load_from_checkpoint(checkpoint_path='./checkpoints/IWAE_None.ckpt', act_func=act_func['tanh'], num_samples=50, hidden_dim=20).to(device)
vae_iaf = VAE_with_flows.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_with_flows_IAF.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=20, flow_type='IAF', num_flows=5, need_permute=False).to(device)
vae_bnaf = VAE_with_flows.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_with_flows_BNAF.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=20, flow_type='BNAF', num_flows=5, need_permute=False).to(device)
vae_rnvp = VAE_with_flows.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_with_flows_RNVP.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=20, flow_type='RNVP', num_flows=5, need_permute=True).to(device)
vae_mcmc = VAE_MCMC.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_MCMC_None.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=20, n_leapfrogs=5, step_size=0.1, use_barker=True).to(device)

# Vanilla VAE model

Let us recap the vanilla VAE model which we studied last time.

We were optimizing the following objective:

$$
\mathcal{L}_{\phi, \theta}(X) = \frac{1}{N}  \sum_{i=1}^{N} \mathcal{L}_{\phi, \theta}(x_i) = \frac{1}{N} \sum_{i=1}^{N} \int_{z} q_\phi (z|x_i) \log \frac{p_\theta(z, x_i)}{q_\phi(z|x_i)} dz \approx \frac{1}{N} \sum_{i=1}^{N} \left( \frac{1}{K} \sum_{k=1}^{K} \log p_\theta(x_i | z_k) - \text{KL}(q_\phi(z|x_i)\|p(z)) \right), \quad z_k \sim q_\phi(z|x_i)
$$

When both the variational family and the prior are Gaussian, the KL term can be computed in closed form, which reduces the variance of the estimator.
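
As a minimal sketch of that closed form, assuming a factorized Gaussian $q_\phi(z|x) = \mathcal{N}(\mu, \text{diag}(\sigma^2))$ parameterized by `mu` and `logvar` and a standard normal prior, the KL term is $\frac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2)$. The function name and the random inputs below are illustrative and are not taken from the `models` package.

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1.0 - logvar, dim=-1)

torch.manual_seed(0)
mu = torch.randn(4, 20)       # hypothetical encoder means: batch of 4, latent dim 20
logvar = torch.randn(4, 20)   # hypothetical encoder log-variances

kl_closed_form = gaussian_kl_to_standard_normal(mu, logvar)

# Cross-check against PyTorch's built-in KL for factorized Gaussians.
q = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
p = Normal(loc=torch.zeros_like(mu), scale=torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum(dim=-1)

print(kl_closed_form)
print(kl_reference)
```

Because no sampling is involved, this term contributes no Monte Carlo noise to the ELBO estimate.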

The gradients can be computed using the reparameterization trick.
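
The trick can be sketched as follows, assuming a factorized Gaussian with parameters `mu` and `logvar`: the sample is written as a deterministic function of the parameters and of parameter-free noise, so autograd can differentiate through it. The toy objective $\mathbb{E}_q[\|z\|^2]$ is chosen here only because its gradients are known analytically ($2\mu$ with respect to `mu`, $\sigma^2$ with respect to `logvar`).

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([1.0, -1.0, 0.5], requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

# Draw noise that does not depend on the parameters ...
eps = torch.randn(5000, 3)
# ... and express the sample as a deterministic function of (mu, logvar, eps).
z = mu + torch.exp(0.5 * logvar) * eps

# Toy objective: Monte Carlo estimate of E_q[||z||^2].
objective = z.pow(2).sum(dim=-1).mean()
objective.backward()

print(mu.grad)      # close to 2 * mu
print(logvar.grad)  # close to sigma^2 = 1
```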

This is the code for training the model. It takes only three lines.
```python
model = VAE(act_func=act_func['tanh'], num_samples=50, hidden_dim=20)
trainer = pl.Trainer(gpus=1)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

In [6]:
Base??

Init signature: Base(act_func, num_samples, hidden_dim, name='VAE')
Docstring:
Helper class that provides a standard way to create an ABC using
inheritance.
Source:
class Base(pl.LightningModule):
    def __init__(self, act_func, num_samples, hidden_dim, name="VAE"):
        super(Base, self).__init__()
        # Encoder
        self.fc1 =

In [7]:
VAE??

Init signature: VAE(act_func, num_samples, hidden_dim, name='VAE')
Docstring:
Helper class that provides a standard way to create an ABC using
inheritance.
Source:
class VAE(Base):
    def loss_function(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy_with_logits(recon_x, x.view(-1, 784), reduction='none').view

In [8]:
reconstructed = recover_image(vae, pics)
generated = generate_image(vae, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)


In [22]:
z_posterior = get_posterior_samples(vae, whole_dataset.to(device))
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()])


# IWAE model

```python
model = IWAE(act_func=act_func['tanh'], num_samples=50, hidden_dim=20)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

Let us now move to more expressive models. The first one is the Importance Weighted Autoencoder (IWAE).

The idea is the following: instead of maximizing the previous ELBO, we optimize a different bound, which corresponds to the K-sample importance-weighted estimate of the log-likelihood:

$$
\mathcal{L}_{\phi, \theta}^{\text{IWAE}}(X) = \frac{1}{N}  \sum_{i=1}^{N} \mathcal{L}_{\phi, \theta}^{\text{IWAE}}(x_i) = \frac{1}{N} \sum_{i=1}^{N} \mathbb{E}_{z_1, \dots, z_K \sim q_\phi(z|x_i)} \left[ \log \frac{1}{K} \sum_{k=1}^{K} \frac{p_\theta(z_k, x_i)}{q_\phi(z_k|x_i)} \right] = \frac{1}{N} \sum_{i=1}^{N} \int \prod_{j=1}^{K} q_\phi (z_j|x_i) \, \log \left( \frac{1}{K} \sum_{k=1}^{K} \frac{p_\theta(z_k, x_i)}{q_\phi(z_k|x_i)} \right) dz_1 \dots dz_K
$$
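
The K-sample bound averages the importance weights $p_\theta(z_k, x)/q_\phi(z_k|x)$ inside the logarithm, which in practice is computed with `logsumexp` for numerical stability. Below is a minimal sketch on a toy model that is not part of this notebook: prior $p(z) = \mathcal{N}(0, 1)$, likelihood $p(x|z) = \mathcal{N}(z, 1)$, so that $\log p(x)$ is known exactly, and a deliberately too wide proposal $q(z|x) = \mathcal{N}(x/2, 1)$.

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
K = 5000
x = torch.tensor([0.5, -1.0, 2.0])               # three toy observations

prior = Normal(0.0, 1.0)
# Imperfect proposal: the exact posterior is N(x / 2, 0.5), here the variance is 1.
q = Normal(loc=x / 2, scale=torch.ones_like(x))

z = q.sample((K,))                               # shape [K, 3]
# Log importance weights: log p(z) + log p(x | z) - log q(z | x)
log_w = prior.log_prob(z) + Normal(z, 1.0).log_prob(x) - q.log_prob(z)

elbo = log_w.mean(dim=0)                              # average of log-weights
iwae = torch.logsumexp(log_w, dim=0) - math.log(K)    # log of average weight
log_px = Normal(0.0, math.sqrt(2.0)).log_prob(x)      # exact marginal likelihood

print(elbo, iwae, log_px)
```

By Jensen's inequality the IWAE estimate is never below the ELBO estimate computed from the same samples, and for large K it approaches $\log p(x)$ even though the proposal is imperfect.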

In [10]:
IWAE??

Init signature: IWAE(act_func, num_samples, hidden_dim, name='VAE')
Docstring:
Helper class that provides a standard way to create an ABC using
inheritance.
Source:
class IWAE(Base):
    def loss_function(self, recon_x, x, mu, logvar, z):
        log_Q = torch.distributions.Normal(loc=mu,
                                           scale=torch.exp(0.5

In [11]:
reconstructed = recover_image(iwae, pics)
generated = generate_image(iwae, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)


In [12]:
z_posterior = get_posterior_samples(iwae, whole_dataset.to(device))
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()])


# VAE model with flows

The next idea for improving the variational family is to use normalizing flows.

A normalizing flow works as follows: it takes a sample from a simple distribution and passes it through an invertible transformation chosen so that the determinant of its Jacobian is easy to compute. The density of the transformed sample then follows from the change-of-variables formula.

TODO: Write the formula, probably with an example of NFs on simple densities.
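
For a single invertible map $z_1 = f(z_0)$ with $z_0 \sim q_0$, the change-of-variables formula reads $\log q_1(z_1) = \log q_0(z_0) - \log \left| \det \frac{\partial f}{\partial z_0} \right|$. A minimal sketch with the simplest possible flow, an elementwise affine map, is given below. The parameters `log_scale` and `shift` are hypothetical placeholders; the IAF, BNAF and RNVP flows used in this notebook are more elaborate and are not reproduced here.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
dim = 20
base = Normal(torch.zeros(dim), torch.ones(dim))   # simple base density q0

# Hypothetical flow parameters; in a VAE they would be produced by a network.
log_scale = 0.3 * torch.randn(dim)
shift = torch.randn(dim)

z0 = base.sample((8,))                       # samples from the base density
z1 = z0 * torch.exp(log_scale) + shift       # invertible transformation f(z0)
log_det_jacobian = log_scale.sum()           # the Jacobian is diagonal

# Change of variables: log q1(z1) = log q0(z0) - log|det df/dz0|
log_q1 = base.log_prob(z0).sum(dim=-1) - log_det_jacobian

# An affine map of a Gaussian is again Gaussian, which allows an exact check.
log_q1_exact = Normal(shift, torch.exp(log_scale)).log_prob(z1).sum(dim=-1)

print(log_q1)
print(log_q1_exact)
```

Stacking several such maps adds up their log-determinants, which is why flows with cheap Jacobian determinants are attractive.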

In [13]:
VAE_with_flows??

Init signature:
VAE_with_flows(
    act_func,
    num_samples,
    hidden_dim,
    flow_type,
    num_flows,
    need_permute,
)
Docstring:
Helper class that provides a standard way to create an ABC using
inheritance.
Source:
class VAE_with_flows(Base):
    def __init__(self, act_func, num_samples, hidden_dim, flow_type, num_flows, need_permute):

## IAF

```python
model = VAE_with_flows(act_func=act_func['tanh'], num_samples=10, hidden_dim=20, num_flows=5, flow_type="IAF", need_permute=False)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

In [14]:
reconstructed = recover_image(vae_iaf, pics, num_samples=10)
generated = generate_image(vae_iaf, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)


In [15]:
z_posterior = get_posterior_samples(vae_iaf, whole_dataset.to(device), num_samples=10)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()])


## BNAF

```python
model = VAE_with_flows(act_func=act_func['tanh'], num_samples=10, hidden_dim=20, num_flows=5, flow_type="BNAF", need_permute=False)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

In [16]:
reconstructed = recover_image(vae_bnaf, pics, num_samples=10)
generated = generate_image(vae_bnaf, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)


In [17]:
z_posterior = get_posterior_samples(vae_bnaf, whole_dataset.to(device), num_samples=10)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()])


## RNVP

```python
model = VAE_with_flows(act_func=act_func['tanh'], num_samples=10, hidden_dim=20, num_flows=5, flow_type="RNVP", need_permute=True)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

In [18]:
reconstructed = recover_image(vae_rnvp, pics, num_samples=10)
generated = generate_image(vae_rnvp, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)


In [19]:
z_posterior = get_posterior_samples(vae_rnvp, whole_dataset.to(device), num_samples=10)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()])


# VAE with MCMC

The last idea is to add MCMC steps to improve the variational approximation.

The approach we consider decouples the training procedures of the encoder and the decoder.
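
To give a feeling for how MCMC can refine samples, here is a minimal sketch of a Hamiltonian Monte Carlo transition on a toy target. It is not the `HMC` sampler from `models.samplers`: the function names are hypothetical, the target is a standard normal rather than a VAE posterior, and the standard Metropolis-Hastings acceptance rule is used instead of the Barker rule suggested by `use_barker=True`. The step size and number of leapfrog steps mirror the values passed when loading the checkpoint.

```python
import torch

def leapfrog(z, p, log_target, step_size, n_leapfrogs):
    """Simulate Hamiltonian dynamics with the leapfrog integrator."""
    def grad(point):
        point = point.detach().requires_grad_(True)
        return torch.autograd.grad(log_target(point).sum(), point)[0]

    p = p + 0.5 * step_size * grad(z)          # initial half step for momentum
    for i in range(n_leapfrogs):
        z = z + step_size * p                  # full step for position
        if i < n_leapfrogs - 1:
            p = p + step_size * grad(z)        # full step for momentum
    p = p + 0.5 * step_size * grad(z)          # final half step for momentum
    return z, p

def hmc_step(z, log_target, step_size=0.1, n_leapfrogs=5):
    """One HMC transition with Metropolis-Hastings accept/reject per chain."""
    p = torch.randn_like(z)
    z_new, p_new = leapfrog(z, p, log_target, step_size, n_leapfrogs)
    log_joint_old = log_target(z) - 0.5 * p.pow(2).sum(dim=-1)
    log_joint_new = log_target(z_new) - 0.5 * p_new.pow(2).sum(dim=-1)
    accept = torch.rand(z.shape[0]).log() < (log_joint_new - log_joint_old)
    return torch.where(accept.unsqueeze(-1), z_new, z)

def log_target(z):
    """Toy unnormalized log-density: a standard normal in 2 dimensions."""
    return -0.5 * z.pow(2).sum(dim=-1)

torch.manual_seed(0)
# Start from a poor initial approximation, centered far from the target.
z = torch.randn(2000, 2) + 3.0
for _ in range(100):
    z = hmc_step(z, log_target)

print(z.mean(dim=0), z.std(dim=0))
```

In a VAE setting, the initial samples would come from the encoder and the target would be the unnormalized posterior $p_\theta(x, z)$ viewed as a function of $z$.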



In [20]:
reconstructed = recover_image(vae_mcmc, pics, num_samples=10)
generated = generate_image(vae_mcmc, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)


In [21]:
z_posterior = get_posterior_samples(vae_mcmc, whole_dataset.to(device), num_samples=10)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()])
