<a href="https://colab.research.google.com/github/juampamuc/notebooks/blob/main/Geometric_Latent_Spaces_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Geometric Latent Spaces Tutorial for the Deep Learning 2 Course



In this tutorial we study the latent spaces of different Variational Autoencoders (VAEs), starting with Euclidean VAEs and moving on to geometric VAEs. We cover the following topics:

1. We start with vanilla (Euclidean) VAEs and interpolate in latent space using geodesics.

2. Next, we train a hyperspherical VAE (S-VAE), whose latent space is the surface of a hypersphere, and interpolate in its latent space.

3. We then motivate using other manifolds as latent spaces. Why manifolds? Can a VAE be trained so that its latent variables lie on more general manifolds?

4. Next, we ask whether the manifold itself can be learned from the data, and a VAE trained on such a learned manifold. Here we discuss the RHVAE, whose latent space is a Riemannian manifold.


Equivariant VAEs
5. We train Spatial VAEs, which explicitly disentangle image rotation and translation, on MNIST and test them on rotated MNIST.



In [None]:
## Standard libraries
from __future__ import print_function,division
import os
import json
import math
import numpy as np
import sys

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
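# IPython.display.set_matplotlib_formats is deprecated; the matplotlib_inline version replaces it
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export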
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import plotly.graph_objects as go
## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision.utils import make_grid
import torch.optim as optim
from torch.autograd import Variable


import torch.utils.data
from torchvision.datasets import MNIST
from torchvision import transforms, datasets
from collections import defaultdict
import torchvision
from PIL import Image

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    # quote the requirement, otherwise the shell treats `>=1.4` as an output redirect
    ! pip install --quiet "pytorch-lightning>=1.4"
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Tensorboard extension (for visualization purposes later)
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

# # Path to the folder where the datasets are/should be downloaded
# DATASET_PATH = "../data"
# # Path to the folder where the pretrained models are saved
# CHECKPOINT_PATH = "../saved_models/tutorial4"

# Setting the seed
pl.seed_everything(1221)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

INFO:lightning_lite.utilities.seed:Global seed set to 1221


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
Device: cuda:0


### 1. Euclidean VAEs

We start with vanilla [VAEs](https://arxiv.org/pdf/1312.6114.pdf).







Suppose we have high-dimensional objects, say images, $x \in \mathcal{X}^d$, and low-dimensional latent variables $z \in \mathbb{R}^M$, which we can think of as the hidden features of the data.

To generate data, we first sample $z \sim p(z)$ and then sample an image $x$, with all its details, from a conditional distribution $p_\theta(x|z)$. The objective is to maximize the log-likelihood of the data, $\log p_\theta(x) = \log \int p_{\theta}(x, z)\,dz$.
When this joint distribution is parameterized by a neural network, the integral over the latent variables $z$ is intractable. Instead, we maximize the Evidence Lower Bound (ELBO), which uses an approximate posterior distribution $q(z)$. Optimizing a separate approximate posterior for every data point does not scale, so we use an inference network $q_{\phi}(z|x)$, parameterized by a neural network, that outputs a distribution over $z$ for each data point.
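For reference, the bound can be written out explicitly. By Jensen's inequality,

$$\log p_\theta(x) = \log \mathbb{E}_{q_\phi(z|x)}\left[\frac{p_\theta(x, z)}{q_\phi(z|x)}\right] \ge \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \mathrm{KL}\left(q_\phi(z|x) \,\|\, p(z)\right) = \mathrm{ELBO}(x).$$

The first term rewards accurate reconstructions and the KL term keeps the approximate posterior close to the prior. With a Bernoulli decoder, $-\log p_\theta(x|z)$ is the binary cross-entropy, so the `BCE + KLD` loss minimized in the code is the negative ELBO.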


Simply put, a VAE consists of two parts, an encoder and a decoder:
- The encoder acts as a variational inference network, mapping each observed input $x$ to an approximate posterior distribution $q_{\phi}(z|x)$ over the latent variables $z$.
- The decoder acts as a generative network, mapping latent variables to new samples that resemble the input images.


Image from [DLAI tutorials](https://github.com/lucmos/DLAI-s2-2020-tutorials)
![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/vae1.png)


Did we miss something?

The ELBO contains an expectation over $q_\phi(z|x)$, i.e. another integral. We could approximate it with Monte Carlo samples from the variational posterior, but we also need gradients of this expectation with respect to $\phi$, the parameters of the very distribution we sample from. This is where the reparameterization trick comes in: we express a random variable as a deterministic transformation of an independent random variable. In the encoder, we write the latent variable $z$ as a transformation of a noise variable $\epsilon$ whose distribution does not depend on $\phi$, so gradients can flow through the sampling step. In practice, this typically gives gradient estimates with much lower variance than the score-function (REINFORCE) estimator.
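A quick NumPy experiment illustrates the variance claim on a toy objective (an illustrative choice, not part of the VAE). For $z \sim \mathcal{N}(\mu, \sigma^2)$ and $f(z) = z^2$, the exact gradient is $\frac{\partial}{\partial \mu}\mathbb{E}[z^2] = 2\mu$. Writing $z = \mu + \sigma\epsilon$ gives the pathwise (reparameterized) estimator $2z$, while the score-function estimator is $f(z)\,\frac{z-\mu}{\sigma^2}$. Both are unbiased, but for $\mu = 1.5$, $\sigma = 0.5$ their variances are $4\sigma^2 = 1$ and, by a short moment calculation, $55.5$.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, n = 1.5, 0.5, 100_000

# Reparameterize: z = mu + sigma * eps, where eps ~ N(0, 1) does not depend on mu.
eps = rng.standard_normal(n)
z = mu + sigma * eps

# Pathwise gradient of f(z) = z^2 w.r.t. mu, differentiating through z = mu + sigma * eps.
grad_path = 2.0 * z

# Score-function (REINFORCE) gradient: f(z) * d/dmu log N(z; mu, sigma^2).
grad_score = z**2 * (z - mu) / sigma**2

print("pathwise: mean %.3f, var %.3f" % (grad_path.mean(), grad_path.var()))
print("score fn: mean %.3f, var %.3f" % (grad_score.mean(), grad_score.var()))
```

Both sample means should land close to $2\mu = 3$, with the score-function estimate visibly noisier.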


So what actually happens in a VAE?
1. We apply the encoder network to $x$ to get $\mu_\phi(x)$ and $\sigma^2_{\phi}(x)$.
2. We compute $z$ with the reparameterization trick, $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$.
3. We apply the decoder to $z$ to get a reconstruction of $x$.
4. We compute the ELBO.
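The four steps can be traced in a few lines of NumPy. This is a toy sketch, not the notebook's model: the random linear maps `W_mu`, `W_logvar` and `W_dec` are hypothetical stand-ins for the encoder and decoder networks, and, unlike the `Encoder` class below (which outputs a softplus scale), the sketch predicts the log-variance, for which the Gaussian KL has the closed form $\frac{1}{2}\sum_j(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2)$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes: 6-pixel binary "images" and a 2-dimensional latent space.
D, M = 6, 2
W_mu = rng.normal(scale=0.1, size=(D, M))      # stand-in encoder head for mu_phi(x)
W_logvar = rng.normal(scale=0.1, size=(D, M))  # stand-in encoder head for log sigma^2_phi(x)
W_dec = rng.normal(scale=0.1, size=(M, D))     # stand-in decoder producing Bernoulli logits


def gaussian_kl(mu, logvar):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)


def elbo(x, rng):
    # 1. Encoder: parameters of q(z|x).
    mu, logvar = x @ W_mu, x @ W_logvar
    # 2. Reparameterization: z = mu + sigma * eps with eps ~ N(0, I).
    eps = rng.standard_normal(mu.shape)
    z = mu + np.exp(0.5 * logvar) * eps
    # 3. Decoder: Bernoulli logits for every pixel.
    logits = z @ W_dec
    # 4. Single-sample ELBO estimate: log p(x|z) - KL(q(z|x) || p(z)).
    #    log p(x|z) = sum_i [x_i * logit_i - log(1 + exp(logit_i))], i.e. the negative BCE.
    log_px_given_z = np.sum(x * logits - np.logaddexp(0.0, logits), axis=-1)
    return log_px_given_z - gaussian_kl(mu, logvar)


x = (rng.uniform(size=(4, D)) > 0.5).astype(float)  # batch of 4 binary "images"
per_example_elbo = elbo(x, rng)
loss = -per_example_elbo.mean()  # the quantity a training loop would minimize
```

Since both the reconstruction log-likelihood and the negative KL are non-positive here, every per-example ELBO value is at most zero.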


Parts of this text are inspired by a blog post with a more detailed treatment of VAEs. Check [this](https://jmtomczak.github.io/blog/4/4_VAE.html) out!

**Code begins**



In [None]:
## Loading MNIST data using pytorch dataloader.

#train_loader
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
    transform=transforms.ToTensor()), batch_size=128, shuffle=True)

#test_loader
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=False, download=True,
    transform=transforms.ToTensor()), batch_size=128)


**Note:** For consistency, each class takes a `distribution` argument (`'normal'` or `'vmf'`), so you can switch the prior and posterior family of the same model and sample from it!

In [None]:
# Below we have a very simple Encoder and Decoder class. It is common to have Conv layers with dropout, but here we stick to simple MLPs!


# Encoder
class Encoder(nn.Module):

    def __init__(self,
                 hidden_dim: int,
                 latent_dim : int,
                 distribution : str,
                 act_fn : object = nn.ReLU):
        """
        Inputs:
            - hidden_dim : Dimensionality of hidden layers
            - latent_dim : Dimensionality of latent representation z
            - distribution : 'normal' or 'vmf'
            - act_fn : Activation function used throughout the encoder network

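        Outputs (of forward):
            - z_mean : mean of the normal, or unit-norm mean direction of the vMF
            - z_var : standard deviation (scale) of the normal, or concentration kappa of the vMF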
        """

        super().__init__()
        self.distribution = distribution
        self.net = nn.Sequential(
            nn.Linear(784, 2 * hidden_dim ),
            act_fn(),
            # nn.Linear(4 * hidden_dim, 2*hidden_dim),
            # act_fn(),
            nn.Linear(2 * hidden_dim, hidden_dim),
            act_fn()
            )

        if self.distribution == 'normal':
            # compute mean and std of the normal distribution
            self.fc_mean = nn.Linear(hidden_dim, latent_dim)
            self.fc_var =  nn.Linear(hidden_dim, latent_dim)

        elif self.distribution == 'vmf':
            # compute mean and concentration of the von Mises-Fisher
            self.fc_mean = nn.Linear(hidden_dim, latent_dim)
            self.fc_var = nn.Linear(hidden_dim, 1)

        else:
            raise NotImplementedError

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        h = self.net(x)

        if self.distribution == 'normal':
            # compute mean and standard deviation of the normal distribution
            z_mean = self.fc_mean(h)
            # softplus keeps the output positive; it is used as the std (scale) of q(z|x)
            z_var = F.softplus(self.fc_var(h))

        elif self.distribution == 'vmf':
            # compute mean direction and concentration of the von Mises-Fisher
            z_mean = self.fc_mean(h)
            z_mean = z_mean / z_mean.norm(dim=-1, keepdim=True)
            # the `+ 1` prevents collapsing behaviors
            z_var = F.softplus(self.fc_var(h)) + 1
        else:
            raise NotImplementedError

        return z_mean, z_var

We start with the classical Euclidean Variational Autoencoder that you are likely familiar with from the Deep Learning 1 course. The code follows the same format as the Deep Autoencoders from Deep Learning 1, with separate encoder and decoder classes and a model class. Remember that, unlike autoencoders, Variational Autoencoders have a stochastic encoder and decoder: the encoder outputs a distribution over $z$, and the reparameterization trick lets us sample from this variational posterior while estimating the ELBO by Monte Carlo (see the short revision above).

Next, we define the decoder.

In [None]:
# Decoder class
class Decoder(nn.Module):

    def __init__(self,
                 hidden_dim: int,
                 latent_dim : int,
                 distribution : str,
                 act_fn : object = nn.ReLU):
        """
        Inputs:
            - hidden_dim : Dimensionality of hidden layers
            - latent_dim : Dimensionality of latent representation z
            - distribution : 'normal' or 'vmf' (not used by the decoder; kept for a consistent interface)
            - act_fn : Activation function used throughout the decoder network
        """

        super().__init__()
        self.distribution = distribution
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            act_fn(),
            nn.Linear(hidden_dim, 2 * hidden_dim ),
            act_fn(),
            # nn.Linear(2*hidden_dim, 4 * hidden_dim ),
            # act_fn(),
            nn.Linear(2 * hidden_dim, 784)
            )


    def forward(self, z):
        # returns Bernoulli logits (no sigmoid); the loss applies BCEWithLogitsLoss
        x_ = self.net(z)
        return x_

In [None]:
# The VAE model: calls the encoder and decoder classes, reparameterizes, and computes the loss.

#VAE
class VariationalAutoencoder(pl.LightningModule):

    def __init__(self,
                 distribution: str,
                 latent_dim: int,
                 hidden_dim: int,
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder):

        super().__init__()

        # Saving hyperparameters of VAE
        self.save_hyperparameters()

        # Creating encoder and decoder
        self.encoder = encoder_class(hidden_dim, latent_dim, distribution)
        self.decoder = decoder_class(hidden_dim, latent_dim, distribution)
        self.distribution = distribution
        self.latent_dim = latent_dim
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

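    def reparameterize(self, z_mean, z_var):
        """
        Builds the approximate posterior q(z|x) and the prior p(z).
        Reparameterized samples are drawn from q(z|x) with rsample() in forward().
        """

        if self.distribution == 'normal':
            # z_var is used as the standard deviation (scale) of the normal
            q_z = torch.distributions.normal.Normal(z_mean, z_var)
            p_z = torch.distributions.normal.Normal(torch.zeros_like(z_mean), torch.ones_like(z_var))

        elif self.distribution == 'vmf':
            q_z = VonMisesFisher(z_mean, z_var)
            # uniform prior on S^{latent_dim - 1}, on the same device as the encoder outputs
            p_z = HypersphericalUniform(self.latent_dim - 1, device=z_mean.device)

        else:
            raise NotImplementedError

        return q_z, p_z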

    def forward(self, x):
        """
        It takes an image and returns the mean and variance, prior and posterior distributions and reconstructions
        """

        z_mean, z_var = self.encoder(x)
        q_z, p_z = self.reparameterize(z_mean, z_var)
        # reparameterized sample, so gradients flow back to the encoder
        z = q_z.rsample()
        x_ = self.decoder(z)

        return (z_mean, z_var), (q_z, p_z), z, x_

    def _get_reconstruction_loss(self, batch):
        """
        It returns loss (combined Binary Cross Entropy with KLD) and reconstructions
        """

        x, _ = batch

        # dynamic binarization: each pixel is set to 1 with probability equal to its grey-level intensity (resampled every step)
        x = (x > torch.distributions.Uniform(0, 1).sample(x.shape).to(self.device)).float()

        (z_mean, z_var), (q_z, p_z), z, x_ = self.forward(x)

        # Binary cross entropy loss
        BCE = self.criterion(x_, x.reshape(-1, 784)).sum(-1).mean()

        #KL divergence
        if self.distribution == 'normal':
            KLD = torch.distributions.kl.kl_divergence(q_z, p_z).sum(-1).mean()

        elif self.distribution == 'vmf':
            KLD = torch.distributions.kl.kl_divergence(q_z, p_z).mean()

        else:
            raise NotImplementedError

        return BCE + KLD, x_

    def configure_optimizers(self):
        """
        Set optimizer
        """
        optimizer = optim.Adam(self.parameters(), lr=1e-2)
        return {"optimizer": optimizer}

    def training_step(self, batch, batch_idx):
        loss, x_ = self._get_reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss, x_ = self._get_reconstruction_loss(batch)

        self.log('test_loss', loss)
        return loss, x_

In [None]:
#This function is a trainer function for MNIST, returns a model and test loss

def train_mnist(latent_dim, hidden_dim, distribution):

    # Create a PyTorch Lightning trainer (`gpus=` is deprecated; use `accelerator`/`devices`)
    trainer_vae = pl.Trainer(accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                             devices=1, max_epochs=20, enable_progress_bar=True)
    model = VariationalAutoencoder(distribution, latent_dim, hidden_dim)
    print(model)
    trainer_vae.fit(model, train_loader)
    trainer_vae.save_checkpoint("./vanilla_vae.ckpt")
    test_result = trainer_vae.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result}

    return model, result

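# # To resume training from a checkpoint:
# # trainer_vae.fit(model, train_dataloaders=train_loader, ckpt_path="./vanilla_vae.ckpt")
# # To load the trained model instead:
# # model = VariationalAutoencoder.load_from_checkpoint("./vanilla_vae.ckpt")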


We now train our VAE. The run below uses a 4-dimensional latent space (`latent_dim=4`); setting `latent_dim=2` makes the latent space easy to visualize directly.


In [None]:
# Set the parameters
latent_dim=4
hidden_dim=128
distribution='normal'

#VAE with normal distribution


In [None]:


model_vae, result = train_mnist(latent_dim, hidden_dim, distribution=distribution)
print(model_vae)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type              | Params
------------------------------------------------
0 | encoder   | Encoder           | 234 K 
1 | decoder   | Decoder           | 235 K 
2 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
470 K     Trainable params
0         Non-trainable params
470 K     Total params
1.880     Total estimated model params size (MB)


VariationalAutoencoder(
  (encoder): Encoder(
    (net): Sequential(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=128, bias=True)
      (3): ReLU()
    )
    (fc_mean): Linear(in_features=128, out_features=4, bias=True)
    (fc_var): Linear(in_features=128, out_features=4, bias=True)
  )
  (decoder): Decoder(
    (net): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=784, bias=True)
    )
  )
  (criterion): BCEWithLogitsLoss()
)


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.




INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



In [None]:
result

{'test': [{'test_loss': 130.3579864501953}]}

### 2. Hyperspherical VAEs

Now we move on to Hyperspherical VAEs (S-VAEs). Instead of a Gaussian distribution, what if we use the von Mises-Fisher (vMF) distribution?

The vMF distribution is often described as the analogue of the Gaussian distribution on the hypersphere, so it builds the spherical geometry of the latent space directly into the model. Its density is

$$q(z \mid \mu, \kappa) = \mathcal{C}_m(\kappa) \exp\left(\kappa \mu^T z\right)$$

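where $z$ and $\mu$ lie on the unit hypersphere $S^{m-1} \subset \mathbb{R}^m$ (so $\|\mu\| = 1$), $\kappa \ge 0$ is the concentration parameter, $\mathcal{C}_m (\kappa)=\frac{{\kappa}^{m/2-1}}{(2 \pi)^{m/2} \mathcal{I}_{m/2-1}(\kappa)}$ is the normalizing constant, and $\mathcal{I}_\nu$ denotes the modified Bessel function of the first kind of order $\nu$.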

The KL term to be optimized is

$$\mathrm{KL}\left(\mathrm{vMF}(\mu, \kappa) \,\|\, U(S^{m-1})\right) = \kappa \frac{\mathcal{I}_{m/2}(\kappa)}{\mathcal{I}_{m/2-1}(\kappa)} + \log \mathcal{C}_m(\kappa) - \log \left(\frac{2\pi^{m/2}}{\Gamma(m/2)}\right)^{-1}$$

For the derivation of the KL term, you can refer to [Hyperspherical VAE](https://arxiv.org/abs/1804.00891)
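The normalizing constant and the KL term can be sanity-checked numerically. The sketch below is pure Python; `bessel_iv` is a hypothetical helper that sums the power series $\mathcal{I}_\nu(x) = \sum_{k \ge 0} \frac{(x/2)^{2k+\nu}}{k!\,\Gamma(k+\nu+1)}$, which is adequate for moderate $\kappa$ (the notebook itself uses SciPy's exponentially scaled `ive`). For $m = 3$ both quantities have closed forms, $\mathcal{C}_3(\kappa) = \frac{\kappa}{4\pi \sinh\kappa}$ and $\mathrm{KL} = \kappa\coth\kappa - 1 + \log\kappa - \log\sinh\kappa$, which the code compares against.

```python
import math


def bessel_iv(nu, x, terms=200):
    """I_nu(x) via its power series; adequate for x > 0 of moderate size."""
    log_half_x = math.log(x / 2.0)
    return sum(
        math.exp((2 * k + nu) * log_half_x - math.lgamma(k + 1) - math.lgamma(k + nu + 1))
        for k in range(terms)
    )


def log_vmf_normalizer(m, kappa):
    """log C_m(kappa) for the vMF density on S^{m-1}."""
    return ((m / 2 - 1) * math.log(kappa)
            - (m / 2) * math.log(2 * math.pi)
            - math.log(bessel_iv(m / 2 - 1, kappa)))


def log_sphere_area(m):
    """log of the surface area 2 pi^{m/2} / Gamma(m/2) of S^{m-1}."""
    return math.log(2.0) + (m / 2) * math.log(math.pi) - math.lgamma(m / 2)


def kl_vmf_uniform(m, kappa):
    """KL(vMF(mu, kappa) || U(S^{m-1})); note that it does not depend on mu."""
    ratio = bessel_iv(m / 2, kappa) / bessel_iv(m / 2 - 1, kappa)
    return kappa * ratio + log_vmf_normalizer(m, kappa) + log_sphere_area(m)


kappa = 2.0
print(log_vmf_normalizer(3, kappa), math.log(kappa / (4 * math.pi * math.sinh(kappa))))
print(kl_vmf_uniform(3, kappa),
      kappa / math.tanh(kappa) - 1 + math.log(kappa) - math.log(math.sinh(kappa)))
print(kl_vmf_uniform(10, 1e-3))  # close to 0: as kappa -> 0 the vMF approaches the uniform distribution
```

The KL also grows with $\kappa$, which is why the encoder is discouraged from producing arbitrarily concentrated posteriors.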

**Note**

- The special case $\kappa = 0$ gives the uniform distribution on the hypersphere, $U(S^{m-1})$, which we use as the prior; as $\kappa$ grows, the distribution concentrates around $\mu$.
- The KL term does not depend on $\mu$, only on $\kappa$. Because $\mathcal{C}_m(\kappa)$ is defined through a Bessel function, which is evaluated with SciPy outside PyTorch's automatic differentiation, we have to derive the gradient w.r.t. $\kappa$ manually.
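The manual gradient follows from a standard Bessel recurrence. Writing the exponentially scaled function used in the code as $\mathrm{ive}(\nu, z) = \mathcal{I}_\nu(z)\, e^{-z}$ and using $\mathcal{I}_\nu'(z) = \mathcal{I}_{\nu-1}(z) - \frac{\nu}{z}\mathcal{I}_\nu(z)$:

$$\frac{\partial}{\partial z}\mathrm{ive}(\nu, z) = e^{-z}\left(\mathcal{I}_\nu'(z) - \mathcal{I}_\nu(z)\right) = \mathrm{ive}(\nu - 1, z) - \mathrm{ive}(\nu, z)\,\frac{\nu + z}{z},$$

which is the expression returned by `IveFunction.backward` (no gradient is returned for the order $\nu$).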

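For vMF sampling, we use the procedure by [Ulrich](https://www.jstor.org/stable/2347441): first sample the scalar component $w = \mu^T z$ (by rejection sampling in general), then sample a direction $v$ uniformly on the unit sphere orthogonal to a fixed pole $e_1$, combine them as $(w, \sqrt{1-w^2}\, v)$, and finally apply a Householder reflection that maps $e_1$ onto $\mu$.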
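To make the steps concrete, here is a small NumPy sketch of the sampler for the special case $m = 3$ (the sphere $S^2$). There, $w = \mu^T z$ has density proportional to $e^{\kappa w}$ on $[-1, 1]$, so it can be drawn exactly by inverting its CDF; this mirrors `__sample_w3` in the `VonMisesFisher` class below, while general $m$ needs the rejection step in `__sample_w_rej`. The function name `sample_vmf_s2` is hypothetical, and the sketch is not the notebook's implementation.

```python
import numpy as np


def sample_vmf_s2(mu, kappa, n, rng):
    """Draw n samples from vMF(mu, kappa) on the unit sphere S^2 (m = 3 only)."""
    mu = np.asarray(mu, dtype=float)
    mu = mu / np.linalg.norm(mu)

    # 1) Sample w = mu^T z. For m = 3 its density is proportional to exp(kappa * w)
    #    on [-1, 1], so the inverse CDF has a closed form.
    u = rng.uniform(size=n)
    w = 1.0 + np.log(u + (1.0 - u) * np.exp(-2.0 * kappa)) / kappa

    # 2) Sample a direction v uniformly on the unit circle orthogonal to e1 = (1, 0, 0).
    v = rng.standard_normal((n, 2))
    v /= np.linalg.norm(v, axis=1, keepdims=True)

    # 3) Build samples concentrated around e1: x = (w, sqrt(1 - w^2) * v).
    radial = np.sqrt(np.clip(1.0 - w**2, 0.0, None))
    x = np.concatenate([w[:, None], radial[:, None] * v], axis=1)

    # 4) Householder reflection that maps e1 onto mu.
    e1 = np.array([1.0, 0.0, 0.0])
    h = e1 - mu
    h_norm = np.linalg.norm(h)
    if h_norm < 1e-12:  # mu is already e1, nothing to reflect
        return x
    h /= h_norm
    return x - 2.0 * (x @ h)[:, None] * h[None, :]


rng = np.random.default_rng(0)
mu = np.array([0.0, 0.0, 1.0])
kappa = 5.0
z = sample_vmf_s2(mu, kappa, 100_000, rng)

# For m = 3 the mean resultant length is E[mu^T z] = coth(kappa) - 1/kappa.
print(np.mean(z @ mu), 1.0 / np.tanh(kappa) - 1.0 / kappa)
```

Because the reflection is orthogonal and maps $e_1$ to $\mu$, every sample stays on the unit sphere and $\mu^T z$ equals the sampled $w$.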

<!-- ![vMF sampling](https://drive.google.com/uc?export=view&id=11alYSKCZgIzpmG2LBG_i1fW50PCNII1D)  -->



Below is the code for the functions the S-VAE needs. It is mainly taken from this [code](https://github.com/nicola-decao/s-vae-pytorch), with a modified Bessel function computation.

In [None]:
#@title HypersphericalUniform class: the uniform prior on the hypersphere { display-mode: "form" }
import math
import torch


class HypersphericalUniform(torch.distributions.Distribution):

    support = torch.distributions.constraints.real
    has_rsample = False
    _mean_carrier_measure = 0

    @property
    def dim(self):
        return self._dim

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, val):
        self._device = val if isinstance(val, torch.device) else torch.device(val)

    def __init__(self, dim, validate_args=None, device="cuda"):
        super(HypersphericalUniform, self).__init__(
            torch.Size([dim]), validate_args=validate_args
        )
        self._dim = dim
        self.device = device

    def sample(self, shape=torch.Size()):
        output = (
            torch.distributions.Normal(0, 1)
            .sample(
                (shape if isinstance(shape, torch.Size) else torch.Size([shape]))
                + torch.Size([self._dim + 1])
            )
            .to(self.device)
        )

        return output / output.norm(dim=-1, keepdim=True)

    def entropy(self):
        return self.__log_surface_area()

    def log_prob(self, x):
        return -torch.ones(x.shape[:-1], device=self.device) * self.__log_surface_area()

    def __log_surface_area(self):
        if torch.__version__ >= "1.0.0":
            lgamma = torch.lgamma(torch.tensor([(self._dim + 1) / 2]).to(self.device))
        else:
            lgamma = torch.lgamma(
                torch.Tensor([(self._dim + 1) / 2], device=self.device)
            )
        return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - lgamma


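# set as a class attribute (a module-level variable has no effect): the distribution has no parameters to validate
HypersphericalUniform.arg_constraints = {}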

In [None]:
#@title VonMisesFisher distribution class (sampling, entropy, log-probability)
from torch.distributions.kl import register_kl

class VonMisesFisher(torch.distributions.Distribution):

    arg_constraints = {
        "loc": torch.distributions.constraints.real,
        "scale": torch.distributions.constraints.positive,
    }
    support = torch.distributions.constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        # option 1:
        #return self.loc * (
            #ive(self.__m / 2, self.scale) / ive(self.__m / 2 - 1, self.scale)
        #)
        # option 2:
        return self.loc * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale)
        # options 3:
        # return self.loc * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale)

    @property
    def stddev(self):
        return self.scale

    def __init__(self, loc, scale, validate_args=None, k=1):
        self.dtype = loc.dtype
        self.loc = loc
        self.scale = scale
        self.device = loc.device
        self.__m = loc.shape[-1]
        self.__e1 = (torch.Tensor([1.0] + [0] * (loc.shape[-1] - 1))).to(self.device)
        self.k = k

        super().__init__(self.loc.size(), validate_args=validate_args)

    def sample(self, shape=torch.Size()):
        with torch.no_grad():
            return self.rsample(shape)

    def rsample(self, shape=torch.Size()):
        shape = shape if isinstance(shape, torch.Size) else torch.Size([shape])

        w = (
            self.__sample_w3(shape=shape)
            if self.__m == 3
            else self.__sample_w_rej(shape=shape)
        )

        v = (
            torch.distributions.Normal(0, 1)
            .sample(shape + torch.Size(self.loc.shape))
            .to(self.device)
            .transpose(0, -1)[1:]
        ).transpose(0, -1)
        v = v / v.norm(dim=-1, keepdim=True)

        w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10))
        x = torch.cat((w, w_ * v), -1)
        z = self.__householder_rotation(x)

        return z.type(self.dtype)

    def __sample_w3(self, shape):
        shape = shape + torch.Size(self.scale.shape)
        u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device)
        self.__w = (
            1
            + torch.stack(
                [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0
            ).logsumexp(0)
            / self.scale
        )
        return self.__w

    def __sample_w_rej(self, shape):
        c = torch.sqrt((4 * (self.scale ** 2)) + (self.__m - 1) ** 2)
        b_true = (-2 * self.scale + c) / (self.__m - 1)

        # using Taylor approximation with a smooth swift from 10 < scale < 11
        # to avoid numerical errors for large scale
        b_app = (self.__m - 1) / (4 * self.scale)
        s = torch.min(
            torch.max(
                torch.tensor([0.0], dtype=self.dtype, device=self.device),
                self.scale - 10,
            ),
            torch.tensor([1.0], dtype=self.dtype, device=self.device),
        )
        b = b_app * s + b_true * (1 - s)

        a = (self.__m - 1 + 2 * self.scale + c) / 4
        d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1)

        self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, shape, k=self.k)
        return self.__w

    @staticmethod
    def first_nonzero(x, dim, invalid_val=-1):
        mask = x > 0
        idx = torch.where(
            mask.any(dim=dim),
            mask.float().argmax(dim=1).squeeze(),
            torch.tensor(invalid_val, device=x.device),
        )
        return idx

    def __while_loop(self, b, a, d, shape, k=20, eps=1e-20):
        #  matrix while loop: samples a matrix of [A, k] samples, to avoid looping all together
        b, a, d = [
            e.repeat(*shape, *([1] * len(self.scale.shape))).reshape(-1, 1)
            for e in (b, a, d)
        ]
        w, e, bool_mask = (
            torch.zeros_like(b).to(self.device),
            torch.zeros_like(b).to(self.device),
            (torch.ones_like(b) == 1).to(self.device),
        )

        sample_shape = torch.Size([b.shape[0], k])
        shape = shape + torch.Size(self.scale.shape)

        while bool_mask.sum() != 0:
            con1 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
            con2 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
            e_ = (
                torch.distributions.Beta(con1, con2)
                .sample(sample_shape)
                .to(self.device)
                .type(self.dtype)
            )

            u = (
                torch.distributions.Uniform(0 + eps, 1 - eps)
                .sample(sample_shape)
                .to(self.device)
                .type(self.dtype)
            )

            w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
            t = (2 * a * b) / (1 - (1 - b) * e_)

            accept = ((self.__m - 1.0) * t.log() - t + d) > torch.log(u)
            accept_idx = self.first_nonzero(accept, dim=-1, invalid_val=-1).unsqueeze(1)
            accept_idx_clamped = accept_idx.clamp(0)
            # clamp(0) avoids indexing with -1; the -1 in accept_idx still marks rejected rows below
            w_ = w_.gather(1, accept_idx_clamped.view(-1, 1))
            e_ = e_.gather(1, accept_idx_clamped.view(-1, 1))

            reject = accept_idx < 0
            accept = ~reject if torch.__version__ >= "1.2.0" else 1 - reject

            w[bool_mask * accept] = w_[bool_mask * accept]
            e[bool_mask * accept] = e_[bool_mask * accept]

            bool_mask[bool_mask * accept] = reject[bool_mask * accept]

        return e.reshape(shape), w.reshape(shape)

    def __householder_rotation(self, x):
        u = self.__e1 - self.loc
        u = u / (u.norm(dim=-1, keepdim=True) + 1e-5)
        z = x - 2 * (x * u).sum(-1, keepdim=True) * u
        return z

    def entropy(self):
        # option 1:
        # output = (
        #     -self.scale
        #     * ive(self.__m / 2, self.scale)
        #     / ive((self.__m / 2) - 1, self.scale)
        # )
        # option 2:
        output = - self.scale * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale)
        # option 3:
        # output = - self.scale * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale)

        return output.view(*(output.shape[:-1])) + self._log_normalization()

    def log_prob(self, x):
        return self._log_unnormalized_prob(x) - self._log_normalization()

    def _log_unnormalized_prob(self, x):
        output = self.scale * (self.loc * x).sum(-1, keepdim=True)

        return output.view(*(output.shape[:-1]))

    def _log_normalization(self):
        output = -(
            (self.__m / 2 - 1) * torch.log(self.scale)
            - (self.__m / 2) * math.log(2 * math.pi)
            - (self.scale + torch.log(ive(self.__m / 2 - 1, self.scale)))
        )

        return output.view(*(output.shape[:-1]))


@register_kl(VonMisesFisher, HypersphericalUniform)
def _kl_vmf_uniform(vmf, hyu):
    return -vmf.entropy() + hyu.entropy()

In [None]:
#@title Ive class for Bessel functions and some numerically stable approximations
import scipy.special
from numbers import Number


class IveFunction(torch.autograd.Function):
    @staticmethod
    def forward(self, v, z):

        assert isinstance(v, Number), "v must be a scalar"

        self.save_for_backward(z)
        self.v = v
        z_cpu = z.data.cpu().numpy()

        if np.isclose(v, 0):
            output = scipy.special.i0e(z_cpu, dtype=z_cpu.dtype)
        elif np.isclose(v, 1):
            output = scipy.special.i1e(z_cpu, dtype=z_cpu.dtype)
        else:  #  v > 0
            output = scipy.special.ive(v, z_cpu, dtype=z_cpu.dtype)
        #         else:
        #             print(v, type(v), np.isclose(v, 0))
        #             raise RuntimeError('v must be >= 0, it is {}'.format(v))

        return torch.Tensor(output).to(z.device)

    @staticmethod
    def backward(self, grad_output):
        z = self.saved_tensors[-1]
        return (
            None,
            grad_output * (ive(self.v - 1, z) - ive(self.v, z) * (self.v + z) / z),
        )


class Ive(torch.nn.Module):
    def __init__(self, v):
        super(Ive, self).__init__()
        self.v = v

    def forward(self, z):
        return ive(self.v, z)


ive = IveFunction.apply


##########
# The approximations below are taken from the respective
# source papers; they improve the numerical stability of the
# Bessel function ratio I_(m/2)(k) / I_(m/2 - 1)(k).
# Both functions take v = m/2 and z = k.

# source: https://arxiv.org/pdf/1606.02008.pdf
def ive_fraction_approx(v, z):
    # I_v(z) / I_(v-1)(z) ≈ z / (v - 1 + ((v + 1)^2 + z^2)^0.5)
    return z / (v - 1 + torch.pow(torch.pow(v + 1, 2) + torch.pow(z, 2), 0.5))


# source: https://arxiv.org/pdf/1902.02603.pdf
def ive_fraction_approx2(v, z, eps=1e-20):
    def delta_a(a):
        lamb = v + (a - 1.0) / 2.0
        return (v - 0.5) + lamb / (
            2 * torch.sqrt((torch.pow(lamb, 2) + torch.pow(z, 2)).clamp(eps))
        )

    delta_0 = delta_a(0.0)
    delta_2 = delta_a(2.0)
    B_0 = z / (
        delta_0 + torch.sqrt((torch.pow(delta_0, 2) + torch.pow(z, 2))).clamp(eps)
    )
    B_2 = z / (
        delta_2 + torch.sqrt((torch.pow(delta_2, 2) + torch.pow(z, 2))).clamp(eps)
    )

    return (B_0 + B_2) / 2.0
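As a quick sanity check of the gradient used in `IveFunction.backward`, the sketch below compares the analytic derivative $\frac{d}{dz}\,\mathrm{ive}(v,z) = \mathrm{ive}(v-1,z) - \mathrm{ive}(v,z)\,\frac{v+z}{z}$ (which follows from $I_v'(z) = I_{v-1}(z) - \frac{v}{z} I_v(z)$) with a central finite difference, calling SciPy directly. The helpers `ive_grad` and `ive_grad_fd` and the test points are illustrative choices, not part of the original code; `scipy.special.ive` is accessed through the module so that the notebook's own `ive` is not shadowed.

```python
import scipy.special

def ive_grad(v, z):
    # analytic d/dz of ive(v, z) = exp(-z) * I_v(z); same formula as IveFunction.backward
    return scipy.special.ive(v - 1, z) - scipy.special.ive(v, z) * (v + z) / z

def ive_grad_fd(v, z, h=1e-5):
    # central finite-difference approximation of the same derivative
    return (scipy.special.ive(v, z + h) - scipy.special.ive(v, z - h)) / (2 * h)

for v_test, z_test in [(2.5, 3.0), (1.0, 0.5), (4.0, 10.0)]:
    print(v_test, z_test, ive_grad(v_test, z_test), ive_grad_fd(v_test, z_test))
```

The two columns should agree to many decimal places; a mismatch would point to an error in the backward formula.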

Now let's train an S-VAE on the MNIST dataset!

In [None]:
#@title Log likelihood computation for sanity check!
# def log_likelihood(model, x, n=10):
#     """
#     :param model: model object
#     :param x: batch of input images
#     :param n: number of MC samples
#     :return: MC estimate of log-likelihood
#     """

#     z_mean, z_var = model.encoder(x.reshape(-1, 784))
#     q_z, p_z = model.reparameterize(z_mean, z_var)
#     z = q_z.rsample(torch.Size([n]))
#     x_mb_ = model.decoder(z)

#     log_p_z = p_z.log_prob(z)

#     if model.distribution == 'normal':
#         log_p_z = log_p_z.sum(-1)

#     log_p_x_z = -nn.BCEWithLogitsLoss(reduction='none')(x_mb_, x.reshape(-1, 784).repeat((n, 1, 1))).sum(-1)

#     log_q_z_x = q_z.log_prob(z)

#     if model.distribution == 'normal':
#         log_q_z_x = log_q_z_x.sum(-1)

#     return ((log_p_x_z + log_p_z - log_q_z_x).t().logsumexp(-1) - np.log(n)).sum()

# def test_ll(model, test_loader, batch_size):
#   LL=torch.tensor(0.0).to(device)
#   bs = 0
#   model = model.to(device)
#   model.eval()
#   with torch.no_grad():
#     for x, y in test_loader:
#       LL+= log_likelihood(model,x.to(device), n=10)
#       bs += batch_size
#   return LL/bs
# print(-test_ll(model_svae, test_loader, 64))
# print(-test_ll(model_vae, test_loader, 64))

In [None]:
def train_mnist_svae(latent_dim, hidden_dim, distribution):

    # Create a PyTorch Lightning trainer (`gpus=` is deprecated since v1.7; use accelerator/devices)
    trainer_svae = pl.Trainer(accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                              devices=1, max_epochs=20, enable_progress_bar=True)

    model = VariationalAutoencoder(distribution, latent_dim, hidden_dim)
    print(model)
    trainer_svae.fit(model, train_loader)
    # Uncomment to save a checkpoint after training
    # trainer_svae.save_checkpoint("./svae.ckpt")
    test_result = trainer_svae.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result}

    return model, result

# # Use this for loading from a checkpoint
# trainer_svae.fit(model, train_dataloaders=train_loader, ckpt_path="./svae.ckpt")

In [None]:
# Set the parameters

latent_dim= 3
hidden_dim=128
distribution='vmf'

model_svae, result = train_mnist_svae(latent_dim, hidden_dim, distribution=distribution)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type              | Params
------------------------------------------------
0 | encoder   | Encoder           | 234 K 
1 | decoder   | Decoder           | 235 K 
2 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
469 K     Trainable params
0         Non-trainable params
469 K     Total params
1.878     Total estimated model params size (MB)


VariationalAutoencoder(
  (encoder): Encoder(
    (net): Sequential(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=128, bias=True)
      (3): ReLU()
    )
    (fc_mean): Linear(in_features=128, out_features=3, bias=True)
    (fc_var): Linear(in_features=128, out_features=1, bias=True)
  )
  (decoder): Decoder(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=784, bias=True)
    )
  )
  (criterion): BCEWithLogitsLoss()
)




torch.Size([128, 3])
torch.Size([128, 1])
torch.Size([128, 3])
torch.Size([128, 1])
q_z VonMisesFisher(loc: torch.Size([128, 3]), scale: torch.Size([128, 1]))
p_z 

NotImplementedError: ignored

In [None]:
print(result)

Now we plot the latent space of the S-VAE, i.e. the 3D latent representations of the MNIST digits, colored by their labels. We observe neat, well-spread-out clusters.

In [None]:
# Plot latent space of the SVAE model

def plot_hyperspherical_latent(model, data, num_batches=10):
    """Scatter the 3D latent means of the first `num_batches` batches, colored by digit label."""
    model = model.to('cpu')
    model.eval()
    z_x, z_y, z_z, labels = [], [], [], []
    with torch.no_grad():
        for i, (x, y) in enumerate(data):
            if i >= num_batches:
                break
            z_coords, _ = model.encoder(x.to('cpu'))
            z_coords = z_coords.numpy()
            z_x.append(z_coords[:, 0])
            z_y.append(z_coords[:, 1])
            z_z.append(z_coords[:, 2])
            labels.append(y.numpy())
    z_x = np.concatenate(z_x)
    z_y = np.concatenate(z_y)
    z_z = np.concatenate(z_z)
    labels = np.concatenate(labels)
    fig = go.Figure(data=[go.Scatter3d(x=z_x, y=z_y, z=z_z,
                                       mode='markers',
                                       marker=dict(size=5,
                                                   color=labels,
                                                   colorscale='hsv',
                                                   opacity=1.0))])
    fig.show()

plot_hyperspherical_latent(model_svae, train_loader, num_batches=10)

Let's look at some reconstructions of MNIST digits.

In [None]:
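x, y = next(iter(train_loader))
x_1 = x[y == 0][1]  # a digit "0" (x_1 is otherwise only defined in the next cell)
visualize_reconstructions(model_svae, x_1.unsqueeze(0))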

Let's see if we can interpolate in the latent space using straight lines.


In [None]:
x, y = next(iter(train_loader))
x_1 = x[y == 0][1]
x_2 = x[y == 3][1]
interpolate(model_svae, x_1, x_2, n=10)

Is a straight line the shortest path on a hypersphere?

It isn't! Some images in the interpolation look fuzzy. A likely reason is that the straight line between two points on the sphere cuts through its interior, so the intermediate latent codes lie off the manifold (the same problem arises for any non-Euclidean latent space). To interpolate in the hyperspherical latent space, we need to follow a geodesic of the hypersphere instead.
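For the unit sphere the geodesic between two points has a closed form, the great circle, so the difference is easy to see in a few lines of NumPy. The sketch below uses spherical linear interpolation (slerp), a standard formula that is not part of this notebook's code; `slerp` and `lerp` are hypothetical helper names. It assumes the latent codes are (or are normalized to) unit vectors, and the formula is undefined for antipodal endpoints, where the great circle is not unique.

```python
import numpy as np

def lerp(z1, z2, n=10):
    # straight-line interpolation in the ambient space R^d
    t = np.linspace(0.0, 1.0, n)[:, None]
    return (1 - t) * z1 + t * z2

def slerp(z1, z2, n=10):
    # points along the great-circle geodesic between two unit vectors
    z1 = z1 / np.linalg.norm(z1)
    z2 = z2 / np.linalg.norm(z2)
    omega = np.arccos(np.clip(np.dot(z1, z2), -1.0, 1.0))  # geodesic distance on the sphere
    if np.isclose(omega, 0.0):
        return np.repeat(z1[None, :], n, axis=0)
    t = np.linspace(0.0, 1.0, n)[:, None]
    return (np.sin((1 - t) * omega) * z1 + np.sin(t * omega) * z2) / np.sin(omega)

z1 = np.array([1.0, 0.0, 0.0])
z2 = np.array([0.0, 1.0, 0.0])
print(np.linalg.norm(lerp(z1, z2, 3), axis=1))   # middle point has norm ~0.707: it leaves the sphere
print(np.linalg.norm(slerp(z1, z2, 3), axis=1))  # every point has norm 1: it stays on the sphere
```

To interpolate between two digits, one could decode the points of `slerp(z_1, z_2)` instead of the straight-line path; Geomstats provides the same kind of computation for general manifolds.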

For geodesic computation and visualization, we use the [Geomstats](https://geomstats.github.io) library!


In [None]:
!pip install geomstats

In [None]:
import geomstats.backend as gs
import geomstats.visualization as visualization

from geomstats.geometry.hypersphere import Hypersphere

SPHERE2 = Hypersphere(dim=2)
METRIC = SPHERE2.metric

# Pick a sample and get its latent code
model_svae = model_svae.to(device).float()
model_svae.eval()
with torch.no_grad():
    z_1, _ = model_svae.encoder(x_1.to(device))


# Set z_1 as the initial point
initial_point = gs.array(z_1.cpu().numpy())

# Pick an arbitrary (fixed) vector and project it onto the tangent space at z_1
initial_tangent_vec = SPHERE2.to_tangent(
    vector=gs.array([1, 2.0, 0.8]), base_point=initial_point
)
geodesic = METRIC.geodesic(
    initial_point=initial_point, initial_tangent_vec=initial_tangent_vec
)

n_steps = 10
t = gs.linspace(0.0, 1.0, n_steps)

points = geodesic(t)

visualization.plot(points, space="S2")
plt.show()


In [None]:
# Let's decode the end point of the geodesic
end_point = torch.from_numpy(points[-1]).float().to(device)
model_svae = model_svae.to(device).float()
model_svae.eval()
with torch.no_grad():
    logits = model_svae.decoder(end_point)
    # the decoder outputs logits (it is trained with BCEWithLogitsLoss), so apply a sigmoid
    recon = torch.sigmoid(logits).cpu().reshape(28, 28)
    plt.imshow(recon, cmap='gray')


Now let's dive a bit deeper into geometries other than the hypersphere!
Is it possible to define a latent space on an arbitrary manifold?

But first, let's look at some definitions:

**Manifold**: A manifold is a topological space that is locally homeomorphic to $\mathbb{R}^n$; for a smooth manifold, the local charts are additionally related by diffeomorphisms.

**Riemannian manifold**: A Riemannian manifold is a pair $(M,g)$, where $M$ is a smooth manifold and $g$ is a Riemannian metric.

Equivalently, a Riemannian manifold is a smooth manifold $M$ such that at each point $p \in M$ an inner product (the metric) $G|_p$ is defined on the tangent space $T_p M$, varying smoothly with $p$.

**Metric**: The Riemannian metric is a smooth 2-tensor field $g$; at each point $p$ it takes two tangent vectors and returns a scalar, $g_p : T_pM \times T_pM \to \mathbb{R}$. It is symmetric, i.e. $g(X,Y) = g(Y,X)$ for all $X,Y \in T_pM$, and positive definite, i.e. $g(X,X) > 0$ for all $X \neq 0$.

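**Geodesic**: A (length-minimizing) geodesic between two points $m_1, m_2 \in M$ is a curve $\gamma:[0,1]\to M$ with $\gamma(0)=m_1$ and $\gamma(1)=m_2$ that minimizes the curve length measured with the metric, $L(\gamma)=\int_0^1 \sqrt{\dot\gamma(t)^\top G|_{\gamma(t)}\,\dot\gamma(t)}\,dt$. Its length is the Riemannian distance $d(m_1,m_2)$.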
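To make the length definition concrete, the integral can be approximated by discretizing the curve into short segments $\Delta z_k$ and summing $\sqrt{\Delta z_k^\top G\,\Delta z_k}$, with the metric evaluated at each segment's midpoint. The NumPy sketch below does exactly that; `curve_length` is a hypothetical helper, and the metric is any function you pass in that returns a symmetric positive definite matrix.

```python
import numpy as np

def curve_length(curve, metric):
    """Approximate L(gamma) = sum_k sqrt(dz_k^T G(mid_k) dz_k) for a discretized curve.

    curve:  (n, d) array of points gamma(t_0), ..., gamma(t_{n-1})
    metric: function z -> (d, d) symmetric positive definite matrix G(z)
    """
    dz = np.diff(curve, axis=0)                 # segment vectors
    mid = 0.5 * (curve[1:] + curve[:-1])        # evaluate G at segment midpoints
    return sum(np.sqrt(d @ metric(m) @ d) for d, m in zip(dz, mid))

t = np.linspace(0.0, 1.0, 101)[:, None]
line = t * np.array([3.0, 4.0])                 # straight line from (0, 0) to (3, 4)
print(curve_length(line, lambda z: np.eye(2)))      # Euclidean metric: ~5.0
print(curve_length(line, lambda z: 4 * np.eye(2)))  # G = 4I doubles every length: ~10.0
```

With a position-dependent metric, the straight line is generally no longer the shortest curve, which is exactly why geodesics have to be computed.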



In the next part of the tutorial, we focus on latent spaces that lie on a Riemannian manifold with a metric $g$. To obtain a manifold that reflects the data, we learn the metric (and hence the manifold) defined on the latent space, which improves the posterior distribution and leads to better samples.

Posterior sampling can be improved by using a Hamiltonian Monte Carlo (HMC) sampler on Riemannian manifolds. However, without a metric for the manifold, it is difficult to interpolate in the latent space.

**Hamiltonian Monte Carlo (HMC)**
HMC augments the random variable $z$ with an auxiliary momentum variable $\rho \sim \mathcal{N}(0, M)$, where $M$ is the mass matrix, and targets the joint density $\pi(z, \rho) = \pi(z)\,\mathcal{N}(\rho; 0, M)$, written as

$$ \pi(z, \rho) = \frac{e^{-H(z, \rho)}}{\int e^{-H(z, \rho)}\,dz\,d\rho} $$

where the Hamiltonian $H$ is the negative log density of the target distribution:

$$ H(z, \rho) = -\log \pi(z, \rho) = -\log \pi(z) + \frac{1}{2}\log\left((2\pi)^d |M|\right) + \frac{1}{2}\rho^T M^{-1}\rho. $$

The time evolution of $z$ and $\rho$ is given by Hamilton's equations:
$$
\begin{aligned}
\frac{\partial z}{\partial t} &= \frac{\partial H}{\partial \rho} = M^{-1} \rho, \\
\frac{\partial \rho}{\partial t} &= -\frac{\partial H}{\partial z} = \nabla_z \log \pi(z).
\end{aligned}
$$

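The exact solution of these equations has the following properties:
1. Time reversibility
2. Volume preservation
3. Conservation of the Hamiltonian

In practice, the equations are solved with a standard numerical integrator, typically the leapfrog integrator, which is exactly time reversible and volume preserving but conserves $H$ only approximately; a Metropolis-Hastings acceptance step corrects for this error.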
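The following NumPy sketch integrates Hamilton's equations with the leapfrog scheme for a toy standard-Gaussian target and an identity mass matrix (both assumptions chosen for illustration), then checks that $H$ is nearly conserved and that flipping the momentum retraces the trajectory. `leapfrog` and `hamiltonian` are hypothetical helper names; a full HMC sampler would also resample the momentum and apply a Metropolis-Hastings accept/reject step, both omitted here.

```python
import numpy as np

def leapfrog(z, rho, grad_log_pi, M_inv, step_size, n_steps):
    """Leapfrog integration of dz/dt = M^{-1} rho and d rho/dt = grad_z log pi(z)."""
    z, rho = z.copy(), rho.copy()
    rho = rho + 0.5 * step_size * grad_log_pi(z)       # half step for the momentum
    for _ in range(n_steps - 1):
        z = z + step_size * M_inv @ rho                # full step for the position
        rho = rho + step_size * grad_log_pi(z)         # full step for the momentum
    z = z + step_size * M_inv @ rho
    rho = rho + 0.5 * step_size * grad_log_pi(z)       # final half step for the momentum
    return z, rho

# Toy target: log pi(z) = -||z||^2 / 2 + const, with an identity mass matrix
d = 2
M_inv = np.eye(d)
grad_log_pi = lambda z: -z

def hamiltonian(z, rho):
    # -log pi(z) + 1/2 rho^T M^{-1} rho, dropping the constant log((2 pi)^d |M|) / 2
    return 0.5 * z @ z + 0.5 * rho @ M_inv @ rho

rng = np.random.default_rng(0)
z0, rho0 = rng.normal(size=d), rng.normal(size=d)
z1, rho1 = leapfrog(z0, rho0, grad_log_pi, M_inv, step_size=0.1, n_steps=20)
print(hamiltonian(z0, rho0), hamiltonian(z1, rho1))    # nearly equal

# Time reversibility: flip the momentum, integrate again, and we return to the start
z_back, rho_back = leapfrog(z1, -rho1, grad_log_pi, M_inv, step_size=0.1, n_steps=20)
print(np.allclose(z_back, z0), np.allclose(-rho_back, rho0))
```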

But of course, for this we need to know $M$ or $M^{-1}$. In Riemannian HMC, the constant mass matrix $M$ is replaced by a position-dependent metric $G(z)$.
In RHVAE, the inverse of this metric is parameterized as follows:

$$ G^{-1}(z) = \sum_{i=1}^N L_{\psi_i} L_{\psi_i}^T \exp\left(-\frac{\|z-c_i\|_2^2}{T^2}\right) + \lambda I_d $$

where $L_{\psi_i}$ are lower triangular matrices parameterized by a neural network, $T$ is a temperature (smoothing) parameter for the metric, $c_i$ are centroids in the latent space, and $\lambda$ is a regularization parameter.

Note that we only need the inverse of the metric here, so parameterizing $G^{-1}$ directly is sufficient.
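A direct NumPy transcription of the formula shows why this parameterization is convenient: each term $L_{\psi_i} L_{\psi_i}^T$ is positive semidefinite and the $\lambda I_d$ term makes the sum positive definite, so $G^{-1}(z)$ is always a valid inverse metric. In this sketch, random lower-triangular matrices stand in for the network outputs $L_{\psi_i}$ and random points stand in for the centroids $c_i$; `inverse_metric` is a hypothetical helper, not Pyraug's API.

```python
import numpy as np

def inverse_metric(z, L, centroids, T, lam):
    """G^{-1}(z) = sum_i L_i L_i^T exp(-||z - c_i||^2 / T^2) + lam * I_d."""
    d = z.shape[0]
    G_inv = lam * np.eye(d)
    for L_i, c_i in zip(L, centroids):
        weight = np.exp(-np.sum((z - c_i) ** 2) / T ** 2)   # Gaussian kernel around centroid c_i
        G_inv += weight * (L_i @ L_i.T)
    return G_inv

rng = np.random.default_rng(0)
N_toy, d_toy = 5, 2
L = np.tril(rng.normal(size=(N_toy, d_toy, d_toy)))    # stand-ins for L_psi_i (lower triangular)
centroids = rng.normal(size=(N_toy, d_toy))            # stand-ins for c_i

G_inv = inverse_metric(np.zeros(d_toy), L, centroids, T=1.0, lam=1e-2)
print(np.linalg.eigvalsh(G_inv))   # all eigenvalues >= lam, so G^{-1}(z) is positive definite
G = np.linalg.inv(G_inv)           # the metric itself, if it is needed

# Far from every centroid the kernel weights vanish and G^{-1}(z) -> lam * I_d
print(inverse_metric(np.full(d_toy, 100.0), L, centroids, T=1.0, lam=1e-2))
```

The last line illustrates the role of $\lambda$: away from the data the metric falls back to a scaled identity, so $G(z)$ stays invertible everywhere.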

In a VAE, the target density is the true posterior distribution $p_\theta(z|x)$. To sample from it via HMC, we need the gradient of its log density. Since $p_\theta(z|x) = p_\theta(x, z)/p_\theta(x)$ and $p_\theta(x)$ does not depend on $z$, we have $\nabla_z \log p_\theta(z|x) = \nabla_z \log p_\theta(x, z)$, so we can work with the tractable joint distribution $p_\theta(x, z)$ instead, which is given by

$$ p_\theta(x, z)= p_\theta(x|z)q_{prior}(z)$$
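A minimal PyTorch sketch of computing $\nabla_z \log p_\theta(x, z)$ with autograd (this equals the gradient of the log posterior, because $\log p_\theta(x)$ is constant in $z$). It assumes a toy linear decoder with a Bernoulli likelihood and a standard normal prior, not the RHVAE architecture; `toy_decoder`, `x_toy`, and `log_joint` are placeholder names.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
toy_decoder = torch.nn.Linear(2, 784).double()         # toy decoder: z -> Bernoulli logits
x_toy = torch.randint(0, 2, (784,)).double()           # a toy binarized "image"
prior = torch.distributions.Normal(torch.zeros(2, dtype=torch.float64),
                                   torch.ones(2, dtype=torch.float64))

def log_joint(z):
    # log p(x|z) + log p(z): Bernoulli likelihood and standard normal prior
    log_px_given_z = -F.binary_cross_entropy_with_logits(toy_decoder(z), x_toy, reduction="sum")
    return log_px_given_z + prior.log_prob(z).sum()

z_toy = torch.randn(2, dtype=torch.float64, requires_grad=True)
grad = torch.autograd.grad(log_joint(z_toy), z_toy)[0]  # = grad_z log p(z|x)
print(grad)
```

This gradient is the $\nabla_z \log \pi(z)$ term in Hamilton's equations above.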

Adding these HMC steps to the VAE improves its generative capacity.

Note that, because the metric has a closed-form expression, geodesics can be computed directly from it, and the metric $G$ itself is learned jointly with the VAE via backpropagation.
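One common way to compute a geodesic from a closed-form metric, shown here as a sketch rather than the RHVAE authors' exact procedure, is to discretize the curve and minimize its energy $E(\gamma)=\int_0^1 \dot\gamma^\top G(\gamma)\,\dot\gamma\,dt$ by gradient descent while keeping the endpoints fixed; minimizers of the energy are constant-speed geodesics. The toy metric below (large near one point, so curves bend around it) is an assumption for illustration; with RHVAE one would plug in $G(z)$, the inverse of $G^{-1}(z)$. `toy_metric`, `curve_energy`, and `geodesic_by_energy` are hypothetical helper names.

```python
import torch

def toy_metric(z):
    # hypothetical metric: distances are inflated near c, so geodesics avoid that region
    c = torch.tensor([0.0, -0.3])
    w = 1.0 + 10.0 * torch.exp(-((z - c) ** 2).sum(-1))   # (k,)
    return w[:, None, None] * torch.eye(2)                # (k, 2, 2)

def curve_energy(curve, metric):
    # discretized E = sum_k (dz_k/dt)^T G(mid_k) (dz_k/dt) * dt, with dt = 1 / (n - 1)
    dz = curve[1:] - curve[:-1]
    mid = 0.5 * (curve[1:] + curve[:-1])
    return torch.einsum('ki,kij,kj->', dz, metric(mid), dz) * (curve.shape[0] - 1)

def geodesic_by_energy(z_start, z_end, metric, n_points=20, steps=500, lr=1e-2):
    t = torch.linspace(0.0, 1.0, n_points)[:, None]
    straight = (1 - t) * z_start + t * z_end                # initialize with the straight line
    interior = straight[1:-1].clone().requires_grad_(True)  # endpoints stay fixed
    optimizer = torch.optim.Adam([interior], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        curve = torch.cat([z_start[None], interior, z_end[None]])
        energy = curve_energy(curve, metric)
        energy.backward()
        optimizer.step()
    return torch.cat([z_start[None], interior.detach(), z_end[None]]), straight

z_start, z_end = torch.tensor([-1.0, 0.0]), torch.tensor([1.0, 0.0])
path, straight = geodesic_by_energy(z_start, z_end, toy_metric)
print(curve_energy(straight, toy_metric).item(), curve_energy(path, toy_metric).item())
```

The optimized path has a lower energy than the straight line because it bends away from the high-cost region.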


Another way to think about this model is that it learns the manifold from the density of the data: because it models the joint distribution of each input point and its latent code, it performs manifold learning and density estimation in a single step.

![RHVAE](https://drive.google.com/uc?export=view&id=1gSftSSjruYcuvjQbXiqX1r8OlVGJAWWO)


To visualize the geodesics, we use the Python library Geomstats.

Now we use the [Pyraug](https://pyraug.readthedocs.io/en) library to train an RHVAE that generates MNIST digits. With this library we learn the manifold together with the density of the data (manifold learning + density estimation), which makes it possible to generate realistic samples from little data; this has many applications in medical AI. Here we use the MNIST dataset to keep this Colab simple. First, let's install the library!

In [None]:
!pip install pyraug

In [None]:
from pyraug.models.base.base_config import BaseModelConfig
from pyraug.models import BaseVAE

RHVAE has shown promising results for data augmentation.
![Data Augmentation pipeline](https://drive.google.com/uc?export=view&id=1cYZCtgoK15wy656rABBa-FRaKnEIqMCJ)

In [None]:
config = BaseModelConfig(input_dim=10)

BaseVAE(model_config=config)

We start from the BaseEncoder architecture shown above and use 200 samples from the MNIST dataset to train the RHVAE!

In [None]:
## We load the baseEncoder and then train RHVAE using the training pipeline:
## The following code is taken mostly from the pyraug library !

from pyraug.models.nn import BaseEncoder
from pyraug.trainers.training_config import TrainingConfig
from pyraug.pipelines.training import TrainingPipeline

mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)
n_samples = 200

config = TrainingConfig(
    output_dir='my_model',
    train_early_stopping=50,
    learning_rate=1e-3,
    batch_size=200, # (default: 50)
    max_epochs=500
)

# Here we train
pipeline = TrainingPipeline(training_config=config)

dataset_to_augment = mnist_trainset.data[:n_samples]
# This will launch the Pipeline on the data
pipeline(train_data=dataset_to_augment, log_output_dir='output_logs')

In [None]:
from pyraug.pipelines.generation import GenerationPipeline

from pyraug.models import RHVAE
import os

last_training = sorted(os.listdir('my_model'))[-1]

# reload the model
model = RHVAE.load_from_folder(os.path.join('my_model', last_training, 'final_model'))

# This creates the Pipeline
generation_pipe = GenerationPipeline(
    model=model
)

# This will launch the Pipeline
generation_pipe(100)

last_generation = sorted(os.listdir('dummy_output_dir'))[-1]
generated_data = torch.load(os.path.join('dummy_output_dir', last_generation, 'generated_data_100_0.pt'))

Let's look at some of the MNIST training samples:

In [None]:
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for i in range(2):
        for j in range(10):
                axes[i][j].matshow(dataset_to_augment[i*10 +j].reshape(28, 28), cmap='gray')
                axes[i][j].axis('off')

plt.tight_layout(pad=0.8)

Let's now look at some generated samples!

In [None]:
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for i in range(2):
        for j in range(10):
                axes[i][j].matshow(generated_data[i*10 +j].cpu().reshape(28, 28), cmap='gray')
                axes[i][j].axis('off')

plt.tight_layout(pad=0.8)

**Note** The above generated data samples do look somewhat like MNIST digits, but are not the best. How can we improve them?

**Exercise**

Design your own autoencoder for MNIST (better than the one used here!) and train it on a subset of the MNIST dataset!

### [Spatial VAEs](https://arxiv.org/abs/1909.11663)

In this work, the generative model is formulated as a function of the spatial coordinates! The VAE is then trained to perform approximate inference on the latent variables while explicitly constraining some of them to represent only rotation and translation.

In (a), a generative model maps spatial coordinates and latent variables to the distribution parameters of the pixel intensity at that coordinate. This model is applied to each coordinate of the pixel grid to generate a complete image.

In (b), approximate inference on the rotation, translation, and unstructured latent variables is performed using an inference network.

![Spatial VAE framework](https://drive.google.com/uc?export=view&id=1jeO4e8oiLCMuLTLnPclWVYNcQDX769Ez)


**Note** It is important to note that the coordinate transformations are applied directly to the spatial coordinates before being decoded by the generator network.

The above figure is from the Spatial VAE [paper](https://arxiv.org/abs/1909.11663)
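The coordinate transformation itself takes only a few lines. The NumPy sketch below builds the same $28 \times 28$ grid in $[-1, 1]^2$ that the training code uses and applies a rotation by $\theta$ and a translation $\Delta x$ with the same row-vector convention as `eval_minibatch` (`coords @ rot` with `rot = [[cos, sin], [-sin, cos]]`, i.e. a counter-clockwise rotation). `transform_coords` is a hypothetical helper; in Spatial-VAE, the transformed coordinates, together with the unstructured latent $z$, are what the generator network decodes.

```python
import numpy as np

n_grid = 28
gx, gy = np.meshgrid(np.linspace(-1, 1, n_grid), np.linspace(1, -1, n_grid))
coords = np.stack([gx.ravel(), gy.ravel()], 1)        # (784, 2) pixel coordinates

def transform_coords(coords, theta, dx):
    # rotate each coordinate (a row vector) by theta, then translate by dx
    rot = np.array([[np.cos(theta), np.sin(theta)],
                    [-np.sin(theta), np.cos(theta)]])
    return coords @ rot + dx

moved = transform_coords(coords, np.pi / 2, np.array([0.1, 0.0]))
print(coords.shape, moved.shape)                     # (784, 2) (784, 2)
print(transform_coords(np.array([[1.0, 0.0]]), np.pi / 2, np.zeros(2)))  # ~[[0, 1]]
```

Because the image content is never resampled, only the coordinates move; the generator then evaluates the pixel intensity at each transformed coordinate.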

In [None]:
#@title Spatial-VAE
## The following code is slightly modified from the official code of the paper!
## Define a separate Inference Network!


class ResidLinear(nn.Module):
    def __init__(self, n_in, n_out, activation=nn.Tanh):
        super(ResidLinear, self).__init__()

        self.linear = nn.Linear(n_in, n_out)
        self.act = activation()

    def forward(self, x):
        return self.act(self.linear(x) + x)


class InferenceNetwork(nn.Module):
    def __init__(self, n, latent_dim, hidden_dim, num_layers=1, activation=nn.Tanh, resid=False):
        super(InferenceNetwork, self).__init__()

        self.latent_dim = latent_dim
        self.n = n

        layers = [nn.Linear(n, hidden_dim),
                  activation(),
                 ]
        for _ in range(1, num_layers):
            if resid:
                layers.append(ResidLinear(hidden_dim, hidden_dim, activation=activation))
            else:
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                layers.append(activation())

        layers.append(nn.Linear(hidden_dim, 2*latent_dim))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        z = self.layers(x)

        ld = self.latent_dim
        z_mu = z[:,:ld]
        z_logstd = z[:,ld:]

        return z_mu, z_logstd


class SpatialGenerator(nn.Module):
    def __init__(self, latent_dim, hidden_dim, n_out=1, num_layers=1, activation=nn.Tanh
                , softplus=False, resid=False, expand_coords=False, bilinear=False):
        super(SpatialGenerator, self).__init__()

        self.softplus = softplus
        self.expand_coords = expand_coords

        in_dim = 2
        if expand_coords:
            in_dim = 5

        self.coord_linear = nn.Linear(in_dim, hidden_dim)
        self.latent_dim = latent_dim
        if latent_dim > 0:
            self.latent_linear = nn.Linear(latent_dim, hidden_dim, bias=False)

        if latent_dim > 0 and bilinear:
            self.bilinear = nn.Bilinear(in_dim, latent_dim, hidden_dim, bias=False)

        layers = [activation()]
        for _ in range(1,num_layers):
            if resid:
                layers.append(ResidLinear(hidden_dim, hidden_dim, activation=activation))
            else:
                layers.append(nn.Linear(hidden_dim,hidden_dim))
                layers.append(activation())
        layers.append(nn.Linear(hidden_dim, n_out))

        self.layers = nn.Sequential(*layers)

    def forward(self, x, z):
        if len(x.size()) < 3:
            x = x.unsqueeze(0)
        b = x.size(0)
        n = x.size(1)
        x = x.view(b*n, -1)
        if self.expand_coords:
            x2 = x**2
            xx = x[:,0]*x[:,1]
            x = torch.cat([x, x2, xx.unsqueeze(1)], 1)

        h_x = self.coord_linear(x)
        h_x = h_x.view(b, n, -1)

        h_z = 0
        if hasattr(self, 'latent_linear'):
            if len(z.size()) < 2:
                z = z.unsqueeze(0)
            h_z = self.latent_linear(z)
            h_z = h_z.unsqueeze(1)

        h_bi = 0
        if hasattr(self, 'bilinear'):
            if len(z.size()) < 2:
                z = z.unsqueeze(0)
            z = z.unsqueeze(1)
            x = x.view(b, n, -1)
            z = z.expand(b, x.size(1), z.size(2)).contiguous()
            h_bi = self.bilinear(x, z)

        h = h_x + h_z + h_bi
        h = h.view(b*n, -1)

        y = self.layers(h)
        y = y.view(b, n, -1)

        if self.softplus:
            y = torch.cat([F.softplus(y[:,:,:1]), y[:,:,1:]], 2)

        return y


class VanillaGenerator(nn.Module):
    """
    The standard MLP structure for image generation. Decodes all pixels jointly as a function of z;
    the coordinates x passed to forward() are ignored.
    """

    def __init__(self, n, latent_dim, hidden_dim, n_out=1, num_layers=1, activation=nn.Tanh
                , softplus=False, resid=False):
        super(VanillaGenerator, self).__init__()

        self.n_out = n_out
        self.softplus = softplus

        layers = [nn.Linear(latent_dim,hidden_dim),
                  activation()]
        for _ in range(1,num_layers):
            if resid:
                layers.append(ResidLinear(hidden_dim, hidden_dim, activation=activation))
            else:
                layers.append(nn.Linear(hidden_dim,hidden_dim))
                layers.append(activation())
        layers.append(nn.Linear(hidden_dim, n*n_out))
        if softplus:
            layers.append(nn.Softplus())

        self.layers = nn.Sequential(*layers)

    def forward(self, x, z):

        y = self.layers(z).view(z.size(0), -1, self.n_out)
        if self.softplus:
            y = torch.cat([F.softplus(y[:,:,:1]), y[:,:,1:]], 2)

        return y

In [None]:
def eval_minibatch(x, y, p_net, q_net, rotate=True, translate=True, dx_scale=0.1, theta_prior=np.pi, use_cuda=False):
    b = y.size(0)
    x = x.expand(b, x.size(0), x.size(1))

    # first do inference on the latent variables
    if use_cuda:
        y = y.cuda()

    z_mu,z_logstd = q_net(y)
    z_std = torch.exp(z_logstd)
    z_dim = z_mu.size(1)

    # draw samples from the variational posterior (reparameterization trick)
    # to estimate E_q[log p(y|z)]

    r = torch.randn_like(z_std)  # same shape, device, and dtype as z_std
    z = z_std*r + z_mu
    kl_div = 0
    if rotate:
        # z[0] is the rotation
        theta_mu = z_mu[:,0]
        theta_std = z_std[:,0]
        theta_logstd = z_logstd[:,0]
        theta = z[:,0]
        z = z[:,1:]
        z_mu = z_mu[:,1:]
        z_std = z_std[:,1:]
        z_logstd = z_logstd[:,1:]

        # calculate rotation matrix
        rot = Variable(theta.data.new(b,2,2).zero_())
        rot[:,0,0] = torch.cos(theta)
        rot[:,0,1] = torch.sin(theta)
        rot[:,1,0] = -torch.sin(theta)
        rot[:,1,1] = torch.cos(theta)
        x = torch.bmm(x, rot) # rotate coordinates by theta

        # calculate the KL divergence term
        sigma = theta_prior
        kl_div = -theta_logstd + np.log(sigma) + (theta_std**2 + theta_mu**2)/2/sigma**2 - 0.5

    if translate:
        # z[0,1] are the translations
        dx_mu = z_mu[:,:2]
        dx_std = z_std[:,:2]
        dx_logstd = z_logstd[:,:2]
        dx = z[:,:2]*dx_scale # scale dx by standard deviation
        dx = dx.unsqueeze(1)
        z = z[:,2:]

        x = x + dx # translate coordinates

    # reconstruct
    y_hat = p_net(x.contiguous(), z)
    y_hat = y_hat.view(b, -1)

    size = y.size(1)
    log_p_x_g_z = -F.binary_cross_entropy_with_logits(y_hat, y)*size

    # unit normal prior over z and translation
    z_kl = -z_logstd + 0.5*z_std**2 + 0.5*z_mu**2 - 0.5
    kl_div = kl_div + torch.sum(z_kl, 1)
    kl_div = kl_div.mean()

    elbo = log_p_x_g_z - kl_div

    return elbo, log_p_x_g_z, kl_div
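The rotation KL term above is the closed-form KL divergence between the Gaussian posterior $\mathcal{N}(\mu_\theta, \sigma_q^2)$ and the prior $\mathcal{N}(0, \sigma^2)$, namely $\log\frac{\sigma}{\sigma_q} + \frac{\sigma_q^2 + \mu_\theta^2}{2\sigma^2} - \frac{1}{2}$; the unit-normal `z_kl` line is the special case $\sigma = 1$. A small check of that expression against `torch.distributions.kl_divergence`, with arbitrary test values:

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

theta_mu = torch.tensor([0.3, -1.2])
theta_logstd = torch.tensor([-0.5, 0.1])
theta_std = theta_logstd.exp()
sigma = math.pi / 4                      # the theta_prior value used for training below

# same expression as in eval_minibatch
closed_form = -theta_logstd + math.log(sigma) + (theta_std**2 + theta_mu**2) / 2 / sigma**2 - 0.5
reference = kl_divergence(Normal(theta_mu, theta_std), Normal(torch.zeros(2), torch.full((2,), sigma)))
print(closed_form, reference)
```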

In [None]:
def train_epoch(iterator, x_coord, p_net, q_net, optim, rotate=True, translate=True
               , dx_scale=0.1, theta_prior=np.pi
               , epoch=10, num_epochs=1, N=1, use_cuda=False):
    p_net.train()
    q_net.train()

    c = 0
    gen_loss_total = 0
    kl_loss_total = 0
    elbo_total = 0

    for y, in iterator:
        b = y.size(0)
        x = Variable(x_coord)
        y = Variable(y)

        elbo, log_p_x_g_z, kl_div = eval_minibatch(x, y, p_net, q_net, rotate=rotate, translate=translate
                                                  , dx_scale=dx_scale, theta_prior=theta_prior
                                                  , use_cuda=use_cuda)

        loss = -elbo
        loss.backward()
        optim.step()
        optim.zero_grad()

        elbo = elbo.item()
        gen_loss = -log_p_x_g_z.item()
        kl_loss = kl_div.item()

        c += b
        delta = b*(gen_loss - gen_loss_total)
        gen_loss_total += delta/c

        delta = b*(elbo - elbo_total)
        elbo_total += delta/c

        delta = b*(kl_loss - kl_loss_total)
        kl_loss_total += delta/c

        template = '# [{}/{}] training {:.1%}, ELBO={:.5f}, Error={:.5f}, KL={:.5f}'
        line = template.format(epoch+1, num_epochs, c/N, elbo_total, gen_loss_total
                              , kl_loss_total)
        print(line, end='\r', file=sys.stderr)

    print(' '*80, end='\r', file=sys.stderr)
    return elbo_total, gen_loss_total, kl_loss_total


def eval_model(iterator, x_coord, p_net, q_net, rotate=True, translate=True
              , dx_scale=0.1, theta_prior=np.pi, use_cuda=False):
    p_net.eval()
    q_net.eval()

    c = 0
    gen_loss_total = 0
    kl_loss_total = 0
    elbo_total = 0

    for y, in iterator:
        b = y.size(0)
        x = Variable(x_coord)
        y = Variable(y)

        elbo, log_p_x_g_z, kl_div = eval_minibatch(x, y, p_net, q_net, rotate=rotate, translate=translate
                                                  , dx_scale=dx_scale, theta_prior=theta_prior
                                                  , use_cuda=use_cuda)

        elbo = elbo.item()
        gen_loss = -log_p_x_g_z.item()
        kl_loss = kl_div.item()

        c += b
        delta = b*(gen_loss - gen_loss_total)
        gen_loss_total += delta/c

        delta = b*(elbo - elbo_total)
        elbo_total += delta/c

        delta = b*(kl_loss - kl_loss_total)
        kl_loss_total += delta/c

    return elbo_total, gen_loss_total, kl_loss_total
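Both loops keep sample-weighted running averages with the incremental update `total += b * (value - total) / c`, where `b` is the batch size and `c` the number of samples seen so far. This gives the same result as a weighted mean over batches without storing every batch value, as a tiny check with made-up numbers shows (`running_weighted_mean` is a hypothetical helper):

```python
import numpy as np

def running_weighted_mean(values, weights):
    # same update rule as in train_epoch / eval_model
    total, count = 0.0, 0
    for value, b in zip(values, weights):
        count += b
        total += b * (value - total) / count
    return total

losses = [2.0, 1.0, 4.0]        # made-up per-batch losses
batch_sizes = [100, 100, 50]    # the last batch is smaller
print(running_weighted_mean(losses, batch_sizes), np.average(losses, weights=batch_sizes))  # 2.0 2.0
```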

In [None]:
#@title get Data
import torchvision

mnist_train = torchvision.datasets.MNIST('data/mnist/', train=True, download=True)
mnist_test = torchvision.datasets.MNIST('data/mnist/', train=False, download=True)

# The raw images are stored as uint8 tensors of shape (N, 28, 28)
mnist_train = mnist_train.data.numpy()
mnist_test = mnist_test.data.numpy()

In [None]:
mnist_train = torch.from_numpy(mnist_train).float()/255
mnist_test = torch.from_numpy(mnist_test).float()/255

# make a coordinate grid for computing translations and rotations
#params
n = 28
m= 28
z_dim= 2
hidden_dim= 100
num_layers =2
use_cuda= True


xgrid = np.linspace(-1, 1, m)
ygrid = np.linspace(1, -1, n)
x0,x1 = np.meshgrid(xgrid, ygrid)
x_coord = np.stack([x0.ravel(), x1.ravel()], 1)
x_coord = torch.from_numpy(x_coord).float()

y_train = mnist_train.view(-1, n*m)
y_test = mnist_test.view(-1, n*m)

In [None]:
if use_cuda==True:
  y_train = y_train.cuda()
  y_test = y_test.cuda()
  x_coord = x_coord.cuda()

data_train = torch.utils.data.TensorDataset(y_train)
data_test = torch.utils.data.TensorDataset(y_test)


In [None]:
dx_scale = 0.1
theta_prior = np.pi/4

print('# using priors: theta={}, dx={}'.format(theta_prior, dx_scale), file=sys.stderr)

N = len(mnist_train)

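# Spatial generator: decodes each pixel from its (rotated) coordinate and z.
# (VanillaGenerator ignores the coordinates, so the inferred rotation would have no effect.)
p_net = SpatialGenerator(z_dim, hidden_dim, num_layers=num_layers, activation=nn.Tanh).cuda()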

q_net = InferenceNetwork(n*m, z_dim+1, hidden_dim, num_layers=num_layers, activation=nn.Tanh).cuda()
params = list(p_net.parameters()) + list(q_net.parameters())

lr = 1e-4
optim = torch.optim.Adam(params, lr=lr)

minibatch_size = 100
max_epochs=20

# Rotated MNIST test set: rotate every test digit by a random angle in [0, 360] degrees.
# No Normalize(): the BCE-with-logits likelihood needs pixel targets in [0, 1].
random_rotation = torchvision.transforms.RandomRotation(
    [0, 360],
    interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    fill=0)
y_test_rot = torch.stack([random_rotation(img.unsqueeze(0)) for img in mnist_test]).view(-1, n*m)
if use_cuda:
    y_test_rot = y_test_rot.cuda()
data_test_rot = torch.utils.data.TensorDataset(y_test_rot)

train_iterator = torch.utils.data.DataLoader(data_train, batch_size=minibatch_size,
                                             shuffle=True)
test_iterator = torch.utils.data.DataLoader(data_test_rot, batch_size=minibatch_size)


Training Spatial-VAE on MNIST

In [None]:
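for epoch in range(max_epochs):

    elbo_total, gen_loss_total, kl_loss_total = train_epoch(train_iterator, x_coord, p_net, q_net,
                                                          optim, rotate=True, translate=False,
                                                          dx_scale=dx_scale, theta_prior=theta_prior,
                                                          epoch=epoch, num_epochs=max_epochs, N=N,
                                                          use_cuda=use_cuda)
    # train_epoch clears its progress line, so report the epoch summary here
    print('# [{}/{}] ELBO={:.5f}, Error={:.5f}, KL={:.5f}'.format(epoch+1, max_epochs, elbo_total,
                                                                  gen_loss_total, kl_loss_total))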

Testing Spatial VAE on Rotated MNIST

In [None]:
elbo_total, gen_loss_total, kl_loss_total = eval_model(test_iterator, x_coord, p_net,
                                                         q_net, rotate=True, translate=False,
                                                         dx_scale=dx_scale, theta_prior=theta_prior,
                                                         use_cuda=use_cuda
                                                        )

In [None]:
print(elbo_total)
print(gen_loss_total)
print(kl_loss_total)

#### If you have any questions, suggestions, edits for this tutorial please write to me at s.p.vadgama@uva.nl !