# VAE Improvements

## How to improve the variational family?

Last time we discussed the ordinary VAE. We saw that it is quite a simple model and rather naive in its assumption about the variational family.

Specifically, it approximates the posterior distribution of the latent variables with a unimodal, fully factorized Gaussian. It looks for encoder parameters $\phi$ and decoder parameters $\theta$ that maximize a lower bound on the marginal likelihood under this restriction. The restriction can lead to a worse latent representation and a worse generative model.

To address this limitation, today we will go through several improvements that make the approximate posterior more expressive.

In [47]:
## First, import some models and functions we will use
from models import Base, VAE, IWAE, VAE_with_flows, VAE_MCMC
from main import make_dataloaders, get_activations
from models.samplers import HMC

## Visualization
from utils import plot_digit_samples, plot_posterior, estimate_ll
import matplotlib.pyplot as plt
import seaborn as sns

## Math processing
import numpy as np
from scipy.stats import norm
import torch
import torch.nn
from torchvision.transforms import ToTensor
## Wrapper on top of PyTorch that eases work with neural nets
import pytorch_lightning as pl
## If there is a CUDA device on the machine, let's use it!
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [2]:
## If "matplotlib widget" is unavailable (it requires the ipympl package), use "matplotlib inline" instead
# %matplotlib inline
%matplotlib widget

In [3]:
hidden_dim = 2

First, load the data.

As before, we use the FashionMNIST dataset.

In [4]:
train_loader, val_loader = make_dataloaders(dataset='fashionmnist',
                                            gpus=1,
                                            batch_size=32,
                                            val_batch_size=64)

In [5]:
act_func = get_activations()

In [49]:
def plot_pics_manifold(model):
    # Decode a regular grid of 2-D latent points; assumes hidden_dim == 2
    n = 15  # figure with 15x15 panels
    image_size = 28
    figure = np.zeros((image_size * n, image_size * n))
    # Grid points are Gaussian quantiles, so they cover the bulk of the N(0, I) prior
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n)).astype(np.float32)
    grid_y = norm.ppf(np.linspace(0.05, 0.95, n)).astype(np.float32)

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.tensor(np.array([[xi, yi]]), dtype=torch.float32, device=model.device)
            with torch.no_grad():
                x_decoded = torch.sigmoid(model.decode(z_sample)).cpu().numpy()
            image = x_decoded[0].reshape(image_size, image_size)
            figure[i * image_size: (i + 1) * image_size,
                   j * image_size: (j + 1) * image_size] = image

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.tight_layout()
    plt.show()

In [6]:
## Function to reconstruct given images (averaged over num_samples posterior samples)
def recover_image(model, pics, num_samples=50):
    with torch.no_grad():
        pics_rec = torch.sigmoid(model.step([pics.to(device), None])[1].cpu().view(
                (num_samples, -1, 784)).mean(0).view((-1, 28, 28)))
    return pics_rec

## Let us fix the random sample, which we will use for image generation:
random_vector = torch.randn((64, hidden_dim)).to(device)

## And let us write a function to generate images:
def generate_image(model, random_vector):
    with torch.no_grad():
        generated = torch.sigmoid(model(random_vector)).cpu().view((-1, 28, 28))
    return generated

def get_posterior_samples(model, pics, num_samples=50, hidden_dim=20):
    with torch.no_grad():
        z_samples = model.step([pics.to(device), None])[2].reshape(num_samples, -1, hidden_dim).mean(0).cpu()
    return z_samples

## Sample a batch from validation dataset
pics = None
labels = None
for b in val_loader:
    pics = b[0]
    labels = b[1]
    break
    
whole_dataset = val_loader.dataset.data[:2000] * 1. / torch.max(val_loader.dataset.data)
whole_labels = val_loader.dataset.targets[:2000]

Next, let us load the already trained models. They could be trained within the notebook, but training takes some time, so we simply load the checkpoints.

In [7]:
vae = VAE.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_None.ckpt', act_func=act_func['tanh'], num_samples=50, hidden_dim=hidden_dim).to(device)
iwae = IWAE.load_from_checkpoint(checkpoint_path='./checkpoints/IWAE_None.ckpt', act_func=act_func['tanh'], num_samples=50, hidden_dim=hidden_dim).to(device)
vae_iaf = VAE_with_flows.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_with_flows_IAF.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=hidden_dim, flow_type='IAF', num_flows=5, need_permute=False).to(device)
vae_bnaf = VAE_with_flows.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_with_flows_BNAF.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=hidden_dim, flow_type='BNAF', num_flows=5, need_permute=False).to(device)
vae_rnvp = VAE_with_flows.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_with_flows_RNVP.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=hidden_dim, flow_type='RNVP', num_flows=5, need_permute=True).to(device)
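# Uncomment the next line to load the VAE_MCMC model; the cells in the "VAE with MCMC" section require `vae_mcmc`
# vae_mcmc = VAE_MCMC.load_from_checkpoint(checkpoint_path='./checkpoints/VAE_MCMC_None.ckpt', act_func=act_func['tanh'], num_samples=10, hidden_dim=hidden_dim, n_leapfrogs=5, step_size=0.1, use_barker=True).to(device)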

# Vanilla VAE model

Let us recap the vanilla VAE model, which we studied last time.

We optimized the following objective:

$$
\mathcal{L}_{\phi, \theta}(X) = \frac{1}{N}  \sum_{i=1}^{N} \mathcal{L}_{\phi, \theta}(x_i) = \frac{1}{N} \sum_{i=1}^{N} \int_{z} q_\phi (z|x_i) \log \frac{p_\theta(z, x_i)}{q_\phi(z|x_i)} dz \approx \frac{1}{N} \sum_{i=1}^{N} \left( \frac{1}{K} \sum_{k=1}^{K} \log p_\theta(x_i | z_k) - \text{KL}(q_\phi(z|x_i)\|p(z)) \right), \quad z_k \sim q_\phi(z|x_i)
$$

For a Gaussian variational family and a Gaussian prior, the KL term can be computed in closed form, which reduces the variance of the estimator.
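For a diagonal Gaussian posterior $\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and a standard normal prior, the closed form is $\text{KL} = -\frac{1}{2}\sum_j (1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2)$. A minimal sketch, assuming the encoder outputs `mu` and `logvar` (the argument names of `VAE.loss_function`), compares it with `torch.distributions` and with a Monte Carlo estimate of the same term; `gaussian_kl` is a hypothetical helper, not part of `models`:

```python
import torch
from torch.distributions import Normal, kl_divergence


def gaussian_kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)


torch.manual_seed(0)
mu = torch.randn(4, 2)       # toy batch of 4 posteriors over a 2-D latent space
logvar = torch.randn(4, 2)

closed_form = gaussian_kl(mu, logvar)

# Reference value from torch.distributions (Normal takes a standard deviation, not a variance)
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(q, p).sum(dim=-1)

# Monte Carlo estimate of the same quantity: E_q[log q(z) - log p(z)]
z = q.sample((10_000,))
mc_estimate = (q.log_prob(z) - p.log_prob(z)).sum(-1).mean(0)

print(torch.allclose(closed_form, reference, atol=1e-5))
print(closed_form)
print(mc_estimate)
```

The Monte Carlo estimate fluctuates around the closed-form value; using the exact KL removes that source of noise from the gradient.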

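The gradients can be computed using the reparameterization trick: $z = \mu_\phi(x) + \sigma_\phi(x) \odot \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, I)$, so the sampling noise does not depend on $\phi$.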
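A minimal sketch of the reparameterization trick, assuming samples are stacked along a leading `num_samples` dimension (as the `(num_samples, -1, ...)` reshapes in `recover_image` suggest; the actual layout inside `models` is an assumption). `reparameterize` is a hypothetical helper, not a function from `models`:

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Draw z = mu + sigma * eps with eps ~ N(0, I); gradients flow into mu and logvar."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn((num_samples,) + mu.shape, dtype=mu.dtype, device=mu.device)
    return mu + std * eps  # shape: (num_samples, batch, latent_dim)


torch.manual_seed(0)
mu = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)

z = reparameterize(mu, logvar, num_samples=5)
# A toy differentiable objective of the samples; backward reaches mu and logvar
loss = (z ** 2).mean()
loss.backward()

print(z.shape)  # torch.Size([5, 3, 2])
print(mu.grad, logvar.grad)
```

Because the randomness lives in `eps`, which does not depend on the encoder outputs, ordinary backpropagation through `mu + std * eps` gives an unbiased gradient of the Monte Carlo ELBO.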

Training the model takes only three lines of code:
```python
model = VAE(act_func=act_func['tanh'], num_samples=50, hidden_dim=20)
trainer = pl.Trainer(gpus=1)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

In [8]:
Base??

Init signature: Base(act_func, num_samples, hidden_dim, name='VAE')
Docstring:
Helper class that provides a standard way to create an ABC using
inheritance.
Source:
class Base(pl.LightningModule):
    def __init__(self, act_func, num_samples, hidden_dim, name="VAE"):
        super(Base, self).__init__()
        # Encoder
        self.fc1 = …

In [9]:
VAE??

Init signature: VAE(act_func, num_samples, hidden_dim, name='VAE')
Docstring:
Helper class that provides a standard way to create an ABC using
inheritance.
Source:
class VAE(Base):
    def loss_function(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy_with_logits(recon_x, x.view(-1, 784), reduction='none').view…

In [25]:
reconstructed = recover_image(vae, pics)
generated = generate_image(vae, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)

In [26]:
z_posterior = get_posterior_samples(model=vae, pics=whole_dataset.to(device), hidden_dim=hidden_dim)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()], names=val_loader.dataset.classes)

In [50]:
plot_pics_manifold(vae)

In [None]:
nll_vae = estimate_ll(vae, val_loader)

In [28]:
print(f"For VAE, NLL estimation is {-np.mean(nll_vae)} +- {np.std(nll_vae)}")

For VAE, NLL estimation is 112.03470849687127 +- 17.636927481679177


# IWAE model

```python
model = IWAE(act_func=act_func['tanh'], num_samples=50)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

Let us now move to more expressive models. The first one is the Importance Weighted Autoencoder, or IWAE.


The idea is the following: instead of maximizing the previous ELBO, we optimize a different lower bound, which corresponds to the $K$-sample importance-weighted estimate of the log-likelihood:

$$
\mathcal{L}_{\phi, \theta}^{\text{IWAE}}(X) = \frac{1}{N}  \sum_{i=1}^{N} \mathcal{L}_{\phi, \theta}^{\text{IWAE}}(x_i) = \frac{1}{N} \sum_{i=1}^{N} \mathbb{E}_{z_1, \dots, z_K \sim q_\phi(z|x_i)} \left[ \log \frac{1}{K} \sum_{k=1}^{K} \frac{p_\theta(z_k, x_i)}{q_\phi(z_k|x_i)} \right] = \frac{1}{N} \sum_{i=1}^{N} \int q_\phi (z_1|x_i) \cdots q_\phi (z_K|x_i) \log \frac{1}{K} \sum_{k=1}^{K} \frac{p_\theta(z_k, x_i)}{q_\phi(z_k|x_i)} \, dz_1 \dots dz_K
$$
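The bound above can be estimated from $K$ log importance weights $\log w_k = \log p_\theta(z_k, x_i) - \log q_\phi(z_k|x_i)$. A minimal sketch (the function names are ours, not from the `models` package) computes it with `torch.logsumexp` for numerical stability and compares it with the ordinary $K$-sample ELBO estimate on toy log-weights:

```python
import math

import torch


def iwae_bound(log_w: torch.Tensor) -> torch.Tensor:
    """K-sample IWAE estimate from log importance weights.

    log_w has shape (K, batch) with log_w[k, i] = log p(z_k, x_i) - log q(z_k | x_i).
    Returns log((1/K) * sum_k w_k) per data point, computed stably via logsumexp.
    """
    K = log_w.shape[0]
    return torch.logsumexp(log_w, dim=0) - math.log(K)


def elbo_estimate(log_w: torch.Tensor) -> torch.Tensor:
    """Standard K-sample ELBO estimate: (1/K) * sum_k log w_k."""
    return log_w.mean(dim=0)


torch.manual_seed(0)
# Toy log-weights for K = 50 samples and 8 data points; real ones come from the model
log_w = torch.randn(50, 8) * 3.0 - 100.0

print(iwae_bound(log_w))
print(elbo_estimate(log_w))
# By Jensen's inequality, log-mean-exp >= mean of logs on the same samples
print(bool((iwae_bound(log_w) >= elbo_estimate(log_w)).all()))
```

For $K = 1$ the two estimates coincide; with more samples the IWAE bound can only be looser than the true log-likelihood by less, which is why it gives the decoder a more informative objective.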

In [None]:
IWAE??

In [31]:
reconstructed = recover_image(iwae, pics)
generated = generate_image(iwae, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)

In [32]:
z_posterior = get_posterior_samples(model=iwae, pics=whole_dataset.to(device), hidden_dim=hidden_dim)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()], names=val_loader.dataset.classes)

In [51]:
plot_pics_manifold(iwae)

In [None]:
nll_iwae = estimate_ll(iwae, val_loader)

In [34]:
print(f"For IWAE, NLL estimation is {-np.mean(nll_iwae)} +- {np.std(nll_iwae)}")

For IWAE, NLL estimation is 113.07576484437202 +- 17.96286111816276


# VAE model with flows

The next idea for improving the variational family is to use normalizing flows.

A normalizing flow is an invertible transformation of the input vector whose Jacobian determinant is easy to compute.

Let $z$ denote the latent variable obtained with the reparameterization trick. We then apply an invertible transformation $f_\phi$ with a tractable Jacobian:

$$
\tilde z = f_{\phi}(z)
$$

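By the change-of-variables formula, the density of the transformed variable can be expressed through the old density and the log-determinant of the Jacobian of the transformation:
$$
\log q_\phi (\tilde z | x) = \log q_\phi(z | x) - \log \left| \det \frac{\partial f_\phi(z)}{\partial z} \right|
$$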

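We then plug this density into the original ELBO, evaluated at $\tilde z$ instead of $z$:

$$
\mathcal{L}_{\phi, \theta}(x) = \mathbb{E}_{q_\phi(z|x)} \left[ \log p_\theta(f_\phi(z), x) - \log q_\phi(z|x) + \log \left| \det \frac{\partial f_\phi(z)}{\partial z} \right| \right]
$$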
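To check the change-of-variables formula numerically, here is a minimal sketch using the simplest invertible flow, an elementwise affine map. It is chosen purely for illustration (an assumption, not one of the flows used below) because its transformed density is known in closed form. IAF, BNAF and RNVP have $z$-dependent Jacobians, but the bookkeeping is the same:

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Base variational density q(z|x): a diagonal Gaussian with toy parameters
mu, sigma = torch.tensor([0.5, -1.0]), torch.tensor([1.0, 0.3])
q0 = Normal(mu, sigma)

# A simple invertible flow: elementwise affine map f(z) = a * z + c
a, c = torch.tensor([2.0, -0.5]), torch.tensor([0.1, 0.2])


def flow(z):
    z_tilde = a * z + c
    # The Jacobian is diag(a), so log|det J| = sum(log|a|), the same for every z
    log_det = torch.log(torch.abs(a)).sum().expand(z.shape[:-1])
    return z_tilde, log_det


z = q0.rsample((1000,))          # reparameterized samples, shape (1000, 2)
z_tilde, log_det = flow(z)

# Change of variables: log q(z_tilde) = log q(z) - log|det J|
log_q_tilde = q0.log_prob(z).sum(-1) - log_det

# For an affine map the transformed density is exactly N(a * mu + c, |a| * sigma)
exact = Normal(a * mu + c, torch.abs(a) * sigma).log_prob(z_tilde).sum(-1)
print(torch.allclose(log_q_tilde, exact, atol=1e-4))
```

In a VAE with flows, `log_q_tilde` is exactly the quantity subtracted in the ELBO, and stacking several flows simply sums their log-determinants.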

In [None]:
VAE_with_flows??

## IAF

```python
model = VAE_with_flows(act_func=act_func['tanh'], num_samples=10, num_flows=5, flow_type="IAF", need_permute=False)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```
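IAF makes each output coordinate an affine function of the current coordinate, with shift and scale computed only from the preceding coordinates, so the Jacobian is triangular and $\log|\det|$ is the sum of the log-scales. The sketch below is a simplified, hypothetical `MaskedAffineIAF` step: a single masked linear layer stands in for the MADE conditioner, and it uses the plain affine update rather than the gated update proposed in the IAF paper. It is not the implementation inside `VAE_with_flows`:

```python
import torch
import torch.nn as nn


class MaskedAffineIAF(nn.Module):
    """One simplified IAF step: z'_i = m_i(z_<i) + s_i(z_<i) * z_i.

    The autoregressive conditioner is a single masked linear layer (a toy stand-in
    for a MADE network), so m_i and s_i only see z_1, ..., z_{i-1}.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.m = nn.Linear(dim, dim)
        self.log_s = nn.Linear(dim, dim)
        # Strictly lower-triangular mask: output i depends only on inputs j < i
        self.register_buffer("mask", torch.tril(torch.ones(dim, dim), diagonal=-1))

    def forward(self, z: torch.Tensor):
        m = nn.functional.linear(z, self.m.weight * self.mask, self.m.bias)
        log_s = nn.functional.linear(z, self.log_s.weight * self.mask, self.log_s.bias)
        z_new = m + torch.exp(log_s) * z
        # Jacobian is lower-triangular with diagonal exp(log_s), so log|det J| = sum(log_s)
        log_det = log_s.sum(dim=-1)
        return z_new, log_det


torch.manual_seed(0)
step = MaskedAffineIAF(dim=3)
z = torch.randn(3)
z_new, log_det = step(z)

# Check the cheap log-determinant against a brute-force Jacobian
J = torch.autograd.functional.jacobian(lambda v: step(v)[0], z)
print(torch.allclose(log_det, torch.logdet(J), atol=1e-5))
print(torch.allclose(J, torch.tril(J), atol=1e-6))  # lower-triangular
```

Transforming a sample is a single parallel pass, which suits a VAE posterior: we only ever need the density of samples we generated ourselves, never the density of an arbitrary external point.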

In [35]:
reconstructed = recover_image(vae_iaf, pics, num_samples=10)
generated = generate_image(vae_iaf, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)

In [36]:
z_posterior = get_posterior_samples(model=vae_iaf, pics=whole_dataset.to(device), num_samples=10, hidden_dim=hidden_dim)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()], names=val_loader.dataset.classes)

In [52]:
plot_pics_manifold(vae_iaf)

In [None]:
nll_vae_iaf = estimate_ll(vae_iaf, val_loader)

In [38]:
print(f"For IAF, NLL estimation is {-np.mean(nll_vae_iaf)} +- {np.std(nll_vae_iaf)}")

For IAF, NLL estimation is 111.34282509384641 +- 17.40768579630878


## BNAF

```python
model = VAE_with_flows(act_func=act_func['tanh'], num_samples=10, num_flows=5, flow_type="BNAF", need_permute=False)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```

In [39]:
reconstructed = recover_image(vae_bnaf, pics, num_samples=10)
generated = generate_image(vae_bnaf, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)

In [40]:
z_posterior = get_posterior_samples(model=vae_bnaf, pics=whole_dataset.to(device), num_samples=10, hidden_dim=hidden_dim)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()], names=val_loader.dataset.classes)

In [53]:
plot_pics_manifold(vae_bnaf)

In [None]:
nll_vae_bnaf = estimate_ll(vae_bnaf, val_loader)

In [41]:
print(f"For BNAF, NLL estimation is {-np.mean(nll_vae_bnaf)} +- {np.std(nll_vae_bnaf)}")

For BNAF, NLL estimation is 129.81535149835477 +- 27.66986290751501


## RNVP

```python
model = VAE_with_flows(act_func=act_func['tanh'], num_samples=10, num_flows=5, flow_type="RNVP", need_permute=True)
trainer = pl.Trainer(gpus=1, deterministic=True)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```
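RNVP uses affine coupling layers, which copy part of the vector unchanged; that is presumably why RNVP is configured with `need_permute=True` while IAF is not, since without a permutation the copied coordinates would never be transformed. Below is a minimal sketch assuming a half/half split, a small tanh conditioner and a coordinate flip as the permutation; the class `AffineCoupling` is hypothetical and none of it is taken from the `models` package:

```python
import torch
import torch.nn as nn


class AffineCoupling(nn.Module):
    """Simplified RealNVP affine coupling layer.

    The first half of z passes through unchanged and parameterizes a scale and
    shift for the second half, so the Jacobian is triangular and cheap to evaluate.
    """

    def __init__(self, dim: int, hidden: int = 16):
        super().__init__()
        self.half = dim // 2
        self.net = nn.Sequential(
            nn.Linear(self.half, hidden), nn.Tanh(), nn.Linear(hidden, 2 * (dim - self.half))
        )

    def forward(self, z):
        z1, z2 = z[..., :self.half], z[..., self.half:]
        log_s, t = self.net(z1).chunk(2, dim=-1)
        z2_new = z2 * torch.exp(log_s) + t
        return torch.cat([z1, z2_new], dim=-1), log_s.sum(dim=-1)

    def inverse(self, y):
        y1, y2 = y[..., :self.half], y[..., self.half:]
        log_s, t = self.net(y1).chunk(2, dim=-1)
        return torch.cat([y1, (y2 - t) * torch.exp(-log_s)], dim=-1)


torch.manual_seed(0)
layers = [AffineCoupling(dim=2) for _ in range(3)]
z = torch.randn(5, 2)

x, total_log_det = z, torch.zeros(5)
for layer in layers:
    x, log_det = layer(x)
    total_log_det = total_log_det + log_det
    # Permute (here: flip) so the next layer transforms the coordinates just copied;
    # a permutation has |det| = 1, so it adds nothing to the log-determinant
    x = x.flip(-1)

# Inverting the stack recovers the input up to floating-point error
y = x
for layer in reversed(layers):
    y = layer.inverse(y.flip(-1))
print(torch.allclose(y, z, atol=1e-5))
```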

In [42]:
reconstructed = recover_image(vae_rnvp, pics, num_samples=10)
generated = generate_image(vae_rnvp, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)

In [43]:
z_posterior = get_posterior_samples(model=vae_rnvp, pics=whole_dataset.to(device), num_samples=10, hidden_dim=hidden_dim)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()], names=val_loader.dataset.classes)

In [54]:
plot_pics_manifold(vae_rnvp)

In [None]:
nll_vae_rnvp = estimate_ll(vae_rnvp, val_loader)

In [44]:
print(f"For RNVP, NLL estimation is {-np.mean(nll_vae_rnvp)} +- {np.std(nll_vae_rnvp)}")

For RNVP, NLL estimation is 110.95768907875012 +- 16.698460047559898


# VAE with MCMC

The last idea is to use an MCMC method to improve the variational approximation.

The approach we consider decouples the training of the encoder and the decoder.

Specifically, we first perform the same optimization as in the vanilla VAE, but when updating the weights, we update only the encoder parameters.

After that, we detach the latent sample and run HMC initialized at it. After 100 HMC steps, we use the resulting sample to compute the likelihood and optimize only the decoder parameters.
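To make the HMC step concrete, here is a minimal sketch of one HMC transition: leapfrog integration followed by a Metropolis-Hastings correction. It is not the `HMC` sampler from `models.samplers`, whose interface is not shown here; `hmc_step` is a hypothetical helper, and `step_size=0.1`, `n_leapfrogs=5` merely mirror the commented-out `VAE_MCMC` arguments. We use the standard Metropolis-Hastings acceptance rule, whereas the `use_barker=True` flag suggests the notebook's sampler may use Barker's rule instead. In the VAE setting, `log_target` would be $\log p_\theta(x, z)$ for a fixed $x$, with $z$ initialized at the detached encoder sample; the usage below swaps in a toy standard normal target:

```python
import torch


def hmc_step(z, log_target, step_size=0.1, n_leapfrogs=5):
    """One HMC transition targeting exp(log_target(z)), batched over the first dim.

    log_target maps a (batch, dim) tensor to a (batch,) tensor of log-densities.
    Returns the new (detached) state and a boolean acceptance mask.
    """
    def grad_log_target(v):
        v = v.detach().requires_grad_(True)
        return torch.autograd.grad(log_target(v).sum(), v)[0]

    p0 = torch.randn_like(z)                     # resample the momentum
    z_new, p = z.clone(), p0.clone()
    # Leapfrog integration: half momentum step, alternating full steps, half momentum step
    p = p + 0.5 * step_size * grad_log_target(z_new)
    for i in range(n_leapfrogs):
        z_new = z_new + step_size * p
        if i < n_leapfrogs - 1:
            p = p + step_size * grad_log_target(z_new)
    p = p + 0.5 * step_size * grad_log_target(z_new)

    # Metropolis-Hastings correction with H(z, p) = -log_target(z) + |p|^2 / 2
    with torch.no_grad():
        h_old = -log_target(z) + 0.5 * (p0 ** 2).sum(-1)
        h_new = -log_target(z_new) + 0.5 * (p ** 2).sum(-1)
        accept = torch.rand_like(h_old).log() < (h_old - h_new)
        z_out = torch.where(accept.unsqueeze(-1), z_new, z)
    return z_out.detach(), accept


torch.manual_seed(0)
log_std_normal = lambda v: -0.5 * (v ** 2).sum(-1)  # toy unnormalized target
z = torch.full((1000, 2), 3.0)                       # 1000 chains started far from the mode
for _ in range(300):
    z, accepted = hmc_step(z, log_std_normal, step_size=0.1, n_leapfrogs=5)
print(z.mean(0), z.var(0))
```

After a few hundred transitions the chains should have forgotten their starting point, so the printed mean and variance land near 0 and 1. In the decoupled scheme, the returned sample is already detached, so a loss such as $-\log p_\theta(x, z)$ on it updates only the decoder.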

In [None]:
reconstructed = recover_image(vae_mcmc, pics, num_samples=10)
generated = generate_image(vae_mcmc, random_vector)
plot_digit_samples(original=pics.squeeze(), reconstucted=reconstructed, generated=generated)

In [None]:
z_posterior = get_posterior_samples(model=vae_mcmc, pics=whole_dataset.to(device), num_samples=10, hidden_dim=hidden_dim)
plot_posterior([whole_labels.numpy(), z_posterior.cpu().numpy()], names=val_loader.dataset.classes)