<a href="https://colab.research.google.com/github/ziatdinovmax/atomai/blob/master/examples/notebooks/atomai_rVAE_digits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational autoencoders and their extensions

Notebook prepared by Maxim Ziatdinov 

Email: ziatdinovmax@gmail.com


---

This notebook demonstrates the application of different variational autoencoders (VAEs) to rotated images. Specifically, we discuss a rotationally invariant version of VAE (rVAE) and a class-conditioned rVAE.

---



Install AtomAI:

In [None]:
!pip install atomai

Imports:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from PIL import Image
import atomai as aoi

(Down)load MNIST:

In [None]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)

In [None]:
# `data` and `targets` replace the deprecated `train_data` and `train_labels` attributes
imstack_train_o = mnist_trainset.data.numpy()
labels_train = mnist_trainset.targets.numpy()

Apply random rotations (integer angles drawn from the range [-90°, 90°)):

In [None]:
imstack_train = np.zeros_like(imstack_train_o)

for i, digit in enumerate(imstack_train_o):
    im = Image.fromarray(digit)
    im = im.rotate(np.random.randint(-90, 90), resample=Image.BICUBIC)
    imstack_train[i] = im

imstack_train = imstack_train / imstack_train.max()

In [None]:
fig, axes = plt.subplots(10, 10, figsize=(8, 8),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for ax in axes.flat:
    i = np.random.randint(len(imstack_train))
    ax.imshow(imstack_train[i], cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(labels_train[i]),
            transform=ax.transAxes, color='green')

We first apply a regular VAE to the rotated digits. The results of a VAE on the non-rotated MNIST data can be found [elsewhere](https://keras.io/examples/generative/vae/).

In [None]:
input_dim = (28, 28)

# Initialize model
vae = aoi.models.VAE(input_dim)
# Train
vae.fit(imstack_train, training_cycles=100, batch_size=100)

In [None]:
vae.manifold2d(d=12, origin='upper')

Looks like it didn't do a very good job. To tackle this problem and analyze imaging data more generally, we use the rotationally invariant extension of VAE (rVAE). The rVAE is based on the concept of a [spatial decoder](https://arxiv.org/abs/1909.11663) introduced by Bepler *et al.* and represents a special class of VAEs where three of the latent variables are the rotation angle and the x- and y-offsets, complemented by the classical latent variables associated with image content. Thus, rVAE adds rotational and (in this case) offset invariance to the analysis workflow. In other words, it is expected to recognize images even if they are shifted and rotated with respect to each other.
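
To make the role of the three special latent variables more concrete, here is a minimal NumPy sketch of how a rotation angle and x/y offsets can act on a pixel-coordinate grid before it is passed to a spatial decoder. The helper `transform_grid` and its arguments are hypothetical names used for illustration only. They are not part of AtomAI, and the actual rVAE implementation may differ in its details.

```python
import numpy as np

def transform_grid(height, width, angle, shift):
    """Rotate and shift a normalized pixel-coordinate grid.

    Hypothetical helper for illustration; not an AtomAI function.
    """
    # Coordinates in [-1, 1] for every pixel, flattened to (height * width, 2)
    ys = np.linspace(-1, 1, height)
    xs = np.linspace(-1, 1, width)
    xx, yy = np.meshgrid(xs, ys)
    grid = np.stack([xx.ravel(), yy.ravel()], axis=1)
    # 2D rotation matrix built from the latent angle (in radians)
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    # Rotate every coordinate, then add the latent x/y offsets
    return grid @ rot.T + np.asarray(shift)

# Identity transform versus a 90-degree rotation with a small offset
grid0 = transform_grid(28, 28, angle=0.0, shift=(0.0, 0.0))
grid1 = transform_grid(28, 28, angle=np.pi / 2, shift=(0.1, -0.2))
print(grid0.shape, grid1.shape)
```

In such a scheme the decoder receives the transformed coordinates together with the remaining ("content") latent variables, so rotation and translation are accounted for by the grid rather than by the content variables.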

Initialize and train rVAE model:

In [None]:
input_dim = (28, 28)

# Initialize model
rvae = aoi.models.rVAE(input_dim)
# Train
rvae.fit(imstack_train, rotation_prior=np.pi/4, training_cycles=100, batch_size=100)

View results:

In [None]:
rvae.manifold2d(d=12, origin='upper')

Looks much better! Now let's train a class-conditioned rVAE.
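
As a rough illustration of what class conditioning means, the sketch below one-hot encodes the digit labels and concatenates them with the latent vectors to form the decoder input. This is an assumption about the general technique, not a description of AtomAI internals; `one_hot`, `z`, and `decoder_input` are hypothetical names, and `z` is a zero-filled placeholder rather than real encoder output.

```python
import numpy as np

def one_hot(labels, nb_classes):
    """Convert integer class labels to one-hot vectors (illustrative helper)."""
    labels = np.asarray(labels)
    encoded = np.zeros((labels.size, nb_classes))
    # Set a single 1 per row at the column given by the label
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

labels = np.array([3, 0, 9])   # example digit labels
z = np.zeros((3, 2))           # placeholder for two latent variables per image
# The decoder sees the latent vector together with the class information
decoder_input = np.concatenate([z, one_hot(labels, 10)], axis=1)
print(decoder_input.shape)
```

Because the class is supplied explicitly, the latent variables are free to describe the variation within each digit class, which is why the manifold can be plotted separately for each label.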

In [None]:
input_dim = (28, 28)

# Initialize model
rvae = aoi.models.rVAE(input_dim, nb_classes=10)
# Train
rvae.fit(imstack_train, labels_train, rotation_prior=np.pi/2, training_cycles=100, batch_size=100)

View results:

In [None]:
for i in range(10):
    rvae.manifold2d(label=i, d=12, origin="upper")

Finally, we can run a regular VAE with class conditioning to confirm that class conditioning alone does *not* work for the rotated digits (as one would expect):

In [None]:
input_dim = (28, 28)

# Initialize model
vae = aoi.models.VAE(input_dim, nb_classes=10)
# Train
vae.fit(imstack_train, labels_train, training_cycles=100, batch_size=100)

In [None]:
for i in range(10):
    vae.manifold2d(label=i, d=12, origin="upper")