In [0]:
!pip install -q imageio

In [0]:
%matplotlib inline

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import itertools
import torchvision
import tqdm
import sys
import uuid
import typing
import math
import matplotlib.pyplot as plt
import imageio
import IPython

from torchvision import transforms as T
from torch import nn
from IPython import display
from IPython.display import Image

# Seed both random number generators this notebook draws from: torch and NumPy
# (np.random.normal is used later to sample latent vectors).
torch.manual_seed(42)
np.random.seed(231)

In [0]:
torch.manual_seed(42)
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# MNIST training split, downloaded into ./data on first use; every image is
# converted to a tensor by ToTensor.
mnist = torchvision.datasets.MNIST(
    "./data", train=True, transform=T.ToTensor(), download=True)

# Shuffled mini-batches of 32 images, loaded by two worker processes.
train = torch.utils.data.DataLoader(
    mnist, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)

In [0]:
class Reshape(nn.Module):
    """Layer that views its input as ``(-1, *shape)``.

    The leading ``-1`` lets the first (batch) dimension be inferred, so the
    layer can sit inside an ``nn.Sequential`` between a linear layer and a
    convolutional one.
    """

    def __init__(self, *shape):
        super().__init__()
        # Target shape of everything after the inferred leading dimension.
        self._shape = tuple(shape)

    def forward(self, x: torch.Tensor):
        target_shape = (-1,) + self._shape
        return x.view(target_shape)

class BVAE(nn.Module):
    """Convolutional (beta-)VAE for single-channel 28x28 images.

    The encoder produces ``2 * latent_dim`` values per sample, which are split
    into the mean and log-variance of a diagonal Gaussian posterior. The
    decoder maps a latent vector back to ``1x28x28`` logits; no sigmoid is
    applied inside the model.

    Args:
        latent_dim: size of the latent vector ``z``.
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        self._latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64,
                      stride=2, kernel_size=3, padding=1),
            # 64x14x14 (for a 1x28x28 input)
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=256,
                      stride=2, kernel_size=3, padding=1),
            # 256x7x7
            nn.ReLU(),
            nn.Flatten(),
            # first half of the outputs is the mean, second half the log-variance
            nn.Linear(in_features=(7 * 7 * 256), out_features=(latent_dim * 2))
        )
        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=(7 * 7 * 128)),
            Reshape(128, 7, 7),
            nn.ConvTranspose2d(in_channels=128, out_channels=256,
                               kernel_size=3, stride=2,
                               padding=1, output_padding=0),
            # 256x13x13
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=256, out_channels=64,
                               kernel_size=3, stride=2,
                               padding=0, output_padding=1),
            nn.ReLU(),
            # 64x28x28
            nn.ConvTranspose2d(in_channels=64, out_channels=1,
                               kernel_size=3, stride=1,
                               padding=1, output_padding=0)
            # 1x28x28 (logits)
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            A tuple ``(logits, mean, logvar)``.
        """
        mean, logvar = self.encode(x)
        # The encoder predicts log(sigma^2) rather than sigma for numerical
        # stability.
        assert mean.shape == logvar.shape
        z = self.reparam(mean, logvar)
        return self.decode(z), mean, logvar

    def reparam(self, mean, logvar):
        """Reparameterisation trick: ``z = mean + eps * std`` with ``eps ~ N(0, I)``."""
        # logvar = log(sigma^2) = 2 log(sigma), so exp(0.5 * logvar) = sigma.
        std = torch.exp(0.5 * logvar)
        # Draw the noise with torch, on the same device and dtype as `std`.
        # This removes the dependency on a notebook-global `device` and avoids
        # generating float64 noise in NumPy and copying it over on every step.
        eps = torch.randn_like(std)
        return eps * std + mean

    def encode(self, x):
        """Return ``(mean, logvar)``, both of shape ``N x latent_dim``."""
        # self.encoder(x) has shape N x (latent_dim * 2), where N = batch size;
        # splitting along dim 1 in chunks of latent_dim yields the two halves.
        return torch.split(self.encoder(x), self._latent_dim, dim=1)

    def decode(self, z):
        """Map latent vectors ``z`` to image logits."""
        return self.decoder(z)

The objective function we want to *maximise* is:

$$
\begin{align}
\mathbb{E}_{\mathbf z \sim q_\phi(\mathbf z | x)}[\log p_\theta(x|\mathbf z)] - D_{KL}(q_\phi(\mathbf z | x)\ ||\ p_\theta(\mathbf z))
\end{align}
$$

in other words, we want to *minimise*:
$$
\begin{align}
D_{KL}(q_\phi(\mathbf z | x)\ ||\ p_\theta(\mathbf z)) - \mathbb{E}_{\mathbf z \sim q_\phi(\mathbf z | x)}[\log p_\theta(x|\mathbf z)]
\end{align}
$$

Turns out, the KL divergence term has a closed form when both $p$ and $q$ are
normal distributions. From [Kingma](https://arxiv.org/pdf/1312.6114.pdf):

$$
\begin{align}
-D_{KL}(q_\phi(\mathbf z | x)\ ||\ p_\theta(\mathbf z)) &= \int q_\phi(\mathbf z | x)\ (\log p_\theta(\mathbf z) - \log q_\phi(\mathbf z | x))\  d\mathbf z \\
&= \frac{1}{2}\sum_{j=1}^{J}(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2)
\end{align}
$$

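where the posterior $q_\phi(\mathbf z | x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ is a diagonal Gaussian, the prior $p_\theta(\mathbf z) = \mathcal{N}(0, I)$ is a standard normal, and $J$ is the dimension of the latents (`latent_dim`).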

In [0]:
def criterion(xh, x, mean, logvar, beta):
    """Beta-VAE loss averaged over the batch.

    Args:
        xh: reconstruction logits, same shape as ``x``.
        x: target images; ``x.shape[0]`` is used as the batch size.
        mean: posterior means (callers pass ``batch_size x latent_dim``).
        logvar: posterior log-variances, same shape as ``mean``.
        beta: weight applied to the KL term.

    Returns:
        Scalar tensor ``(reconstruction + beta * KL) / batch_size``.
    """
    batch_size = x.shape[0]

    # Reconstruction term -log p(x|z). For a logit x' and target x, binary
    # cross entropy is
    #     x * -log(sigmoid(x')) + (1 - x) * -log(1 - sigmoid(x')).
    # reduction='sum' treats the pixels as conditionally independent given z,
    # so the log of the product becomes a sum of logs.
    recon_nll = F.binary_cross_entropy_with_logits(xh, x, reduction='sum')

    # Closed-form KL between the diagonal Gaussian posterior and N(0, I).
    kl_terms = 1 + logvar - mean ** 2 - torch.exp(logvar)
    kl_div = -0.5 * kl_terms.sum()

    return (recon_nll + beta * kl_div) / batch_size

In [0]:
def train_model(model: nn.Module, optimiser: torch.optim.Optimizer,
                beta: int, epochs: int):
    """Train ``model`` with the beta-VAE ``criterion`` and return per-epoch losses.

    Batches come from the module-level ``train`` loader and are moved to the
    module-level ``device``.

    Args:
        model: module whose forward pass returns ``(logits, mean, logvar)``.
        optimiser: optimiser that updates the model.
        beta: KL weight forwarded to ``criterion``.
        epochs: number of passes over ``train``.

    Returns:
        List with the mean batch loss of every epoch.
    """
    model.train()
    progress = tqdm.trange(epochs, file=sys.stdout)
    epoch_means = []
    for _ in progress:
        batch_losses = []
        for images, _labels in train:
            images = images.to(device)
            logits, mean, logvar = model(images)
            loss = criterion(logits, images, mean, logvar, beta)
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            current = loss.item()
            batch_losses.append(current)
            # show the latest batch loss next to the progress bar
            progress.set_postfix(loss=current)
        epoch_means.append(np.mean(batch_losses))
    return epoch_means

# Train a $\beta$-vae model ($\beta$=10)

The reconstructed images seem blurrier than those of a regular VAE ($\beta$=1),
as the paper reported. The fidelity of these reconstructed images is expected
to be higher as $\beta$ decreases (see below where $\beta$=1).

> I tried $\beta = 20$ and the reconstructed images are beyond recognition; not sure if it requires more training epochs or if it's too large for such a simple dataset.

In [0]:
# Hyperparameters shared by every model trained in this notebook.
LATENT_DIM = 16
EPOCHS = 10

# beta-VAE with beta=10. The model is moved to `device` before the optimiser
# is built from its parameters.
bvae = BVAE(LATENT_DIM)
bvae = bvae.to(device)
optimiser = torch.optim.Adam(bvae.parameters())
losses = train_model(bvae, optimiser, beta=10, epochs=EPOCHS)

In [0]:
def imshow(fig, img: torch.Tensor, title: str=None):
    """Draw a channel-first image tensor on a matplotlib axes in grayscale.

    Args:
        fig: the axes-like object to draw on (it must provide ``imshow``,
            ``set_xticks``, ``set_yticks`` and ``set_title``).
        img: image tensor laid out as ``C x H x W``.
        title: optional title; skipped when empty or ``None``.
    """
    # Reorder to H x W x C with torch's own permute and convert explicitly.
    # Calling np.transpose on a torch.Tensor only works through NumPy's
    # TypeError fallback, which is fragile across NumPy/torch versions.
    # squeeze() drops the channel axis for single-channel images.
    pixels = img.detach().cpu().permute(1, 2, 0).squeeze().numpy()
    fig.imshow(pixels, cmap='gray')
    # hide the tick marks on both axes
    fig.set_xticks(())
    fig.set_yticks(())
    if title:
        fig.set_title(title)
    
# Take one training batch; it is reused for the regular-VAE comparison later.
batch1 = next(iter(train))[0].to(device)
with torch.no_grad():
    x_recon, *_ = bvae(batch1)
# The model returns logits; map them to (0, 1) in place.
torch.sigmoid_(x_recon)
fig = plt.figure(figsize=(20, 20))
ax1 = fig.add_subplot(121)
imshow(ax1, torchvision.utils.make_grid(x_recon), title="Reconstructed")
ax2 = fig.add_subplot(122)
imshow(ax2, torchvision.utils.make_grid(batch1), title="Original")
# plt.show() rather than fig.show(): under the inline (non-GUI) backend,
# Figure.show() cannot open a window and only emits a warning.
plt.show()

# Train a regular VAE model

As expected, the fidelity of the reconstructed images is higher.

In [0]:
# Regular VAE: the same architecture and hyperparameters, trained with beta=1.
vae = BVAE(LATENT_DIM)
vae = vae.to(device)
optimiser = torch.optim.Adam(vae.parameters())
losses = train_model(vae, optimiser, beta=1, epochs=EPOCHS)

In [0]:
# Reconstruct the batch taken earlier (batch1) with the regular VAE.
with torch.no_grad():
    x_recon, *_ = vae(batch1)
# The model returns logits; map them to (0, 1) in place.
torch.sigmoid_(x_recon)
fig = plt.figure(figsize=(20, 20))
ax1 = fig.add_subplot(121)
imshow(ax1, torchvision.utils.make_grid(x_recon), title="Reconstructed")
ax2 = fig.add_subplot(122)
imshow(ax2, torchvision.utils.make_grid(batch1), title="Original")
# plt.show() rather than fig.show(): under the inline (non-GUI) backend,
# Figure.show() cannot open a window and only emits a warning.
plt.show()

# Visualising the latent space

`LatentVisualiser` produces GIFs that show how the reconstructed images change as one latent dimension is varied step by step while all other dimensions are kept fixed.
Furthermore, the initial latent variable is initialised as a normal vector
with mean 0 and variance 1. The following section explores what happens when it's initialised as a 0 vector.

In [0]:
class LatentVisualiser:
    """Render GIFs showing how decoded images change as one latent dimension is swept.

    Starting from a fixed batch of latent vectors, ``__call__`` repeatedly adds
    ``step`` to a single dimension, decodes the batch, and saves every decoded
    grid as one frame of a GIF.

    Args:
        model: model exposing ``decode(z)`` that returns image logits.
        dim: index of the latent dimension to sweep.
        init_latent: starting latents (``N x latent_dim``). When ``None``, a
            standard-normal batch of shape ``(n_samples, LATENT_DIM)`` is drawn.
        width: figure width passed to matplotlib.
        height: figure height passed to matplotlib.
        n_samples: batch size of the random start; only used when
            ``init_latent`` is ``None``.
    """

    def __init__(self, model: nn.Module, dim: int=0,
                 init_latent: torch.Tensor=None, width: int=8,
                 height: int=8, n_samples: int=32):
        if init_latent is not None:
            # Clone so the sweep never mutates the caller's tensor.
            self._rnd = init_latent.clone()
        else:
            # Previously this branch referenced an undefined `rnd_shape` and
            # raised NameError; build the shape explicitly instead.
            rnd_shape = (n_samples, LATENT_DIM)
            self._rnd = torch.from_numpy(np.random.normal(0, 1, rnd_shape)).to(device, dtype=torch.float32)
        # Validate against the width of the latent that will actually be swept.
        assert dim < self._rnd.shape[-1]
        self._dim = dim
        self._filenames = []
        self._model = model
        self._width = width
        self._height = height

    def __call__(self, iters: int, step: float, save: str=None):
        """Sweep the chosen dimension for ``iters`` steps and write a GIF.

        Args:
            iters: number of frames to render; must be at least 1.
            step: amount added to the swept dimension before each frame.
            save: output path ending in ``.gif``; defaults to ``"_temp.gif"``.

        Returns:
            The path of the written GIF.
        """
        import os  # local import: only needed to remove the temporary frames

        output_filename = save if save else "_temp.gif"
        # Check the name before rendering anything, so a bad name fails fast.
        assert output_filename.endswith(".gif")

        crnd = self._rnd.clone()
        # Start with a clean frame list so repeated calls on the same instance
        # do not mix in frames from an earlier sweep.
        self._filenames = []
        try:
            for i in range(1, iters + 1):
                crnd[:, self._dim] += step
                with torch.no_grad():
                    recon = torch.sigmoid(self._model.decode(crnd))
                fig = plt.figure(figsize=(self._width, self._height))
                try:
                    ax = fig.add_subplot(111)
                    # random name so frames of different sweeps never collide
                    filename = f'{uuid.uuid4().hex}.png'
                    imshow(ax, torchvision.utils.make_grid(recon),
                           title=f"Dimension {self._dim}: {i * step:0.4f}")
                    # record the name first so a partial file is cleaned up too
                    self._filenames.append(filename)
                    fig.savefig(filename, bbox_inches='tight')
                finally:
                    plt.close(fig)
            self._generate_fig(output_filename)
        finally:
            # The PNG frames are only intermediates; remove them so sweeps do
            # not leave thousands of files in the working directory.
            for filename in self._filenames:
                if os.path.exists(filename):
                    os.remove(filename)
            self._filenames = []
        return output_filename

    def _generate_fig(self, output_filename):
        """Assemble the saved frame files into a GIF at ``output_filename``.

        Frames are subsampled: frame ``i`` is kept only when
        ``round(2 * sqrt(i))`` exceeds the rounded value of the last kept
        frame, so later frames are sampled more sparsely. The final frame is
        appended once more after the loop.
        """
        assert len(self._filenames)
        # taken from https://www.tensorflow.org/tutorials/generative/cvae#generate_a_gif_of_all_the_saved_images
        with imageio.get_writer(output_filename, mode='I') as writer:
            last = -1
            for i, filename in enumerate(self._filenames):
                frame = 2*(i**0.5)
                if round(frame) > round(last):
                    last = frame
                else:
                    continue
                writer.append_data(imageio.imread(filename))
            # append the last frame once more
            writer.append_data(imageio.imread(self._filenames[-1]))

    @staticmethod
    def combine_gifs(filenames: typing.List[str], output: str, cols: int) -> str:
        """Tile several equally long GIFs into one grid GIF.

        Args:
            filenames: input GIF paths, placed row by row; the count must be a
                non-zero multiple of ``cols``.
            output: path of the combined GIF.
            cols: number of GIFs per row.

        Returns:
            ``output``.
        """
        assert cols > 0 and len(filenames) > 0
        assert len(filenames) % cols == 0
        rows = len(filenames) // cols
        readers = []
        try:
            for name in filenames:
                readers.append(imageio.get_reader(name))
            frames = readers[0].get_length()
            assert all(r.get_length() == frames for r in readers)

            with imageio.get_writer(output, mode='I') as writer:
                for _ in range(frames):
                    # next frame of every input, in the order given
                    tiles = [r.get_next_data() for r in readers]
                    grid_rows = [np.hstack(tiles[row * cols:(row + 1) * cols])
                                 for row in range(rows)]
                    writer.append_data(np.vstack(grid_rows))
        finally:
            # Close every reader that was opened, even if an assertion or a
            # read failed part-way through.
            for reader in readers:
                reader.close()
        return output

In [0]:
# Number of sweep iterations per GIF; each iteration shifts the swept
# (non-fixed) latent dimension by one step.
ITERS = 100
# Step size applied in each iteration (some cells below pass -STEP to sweep
# downwards instead).
STEP = 0.2

## Visualising the latent space of $\beta=10$ (normal vector)

In [0]:
# One batch of 32 standard-normal latent vectors, shared by the random-start
# sweeps of both models.
rnd = torch.from_numpy(np.random.normal(0, 1, size=(32, LATENT_DIM))).to(device, dtype=torch.float32)
gifs = []
for dim in range(LATENT_DIM):
    # beta=10 model: sweep latent dimension `dim` upwards from the random start
    gifs.append(
        LatentVisualiser(bvae, dim=dim, init_latent=rnd)(
            iters=ITERS, step=STEP, save=f"beta10_dim{dim}.gif"))

# combine_gifs returns the path it wrote; tile the per-dimension GIFs two per row
with open(LatentVisualiser.combine_gifs(gifs, "beta10.gif", 2), 'rb') as f:
    display.display(Image(data=f.read(), format='png'))

## Visualising the latent space of $\beta=1$ (normal vector)


In [0]:
gifs = []
for dim in range(LATENT_DIM):
    # beta=1 model: sweep latent dimension `dim` upwards from the random start
    gifs.append(
        LatentVisualiser(vae, dim=dim, init_latent=rnd)(
            iters=ITERS, step=STEP, save=f"beta1_dim{dim}.gif"))

# combine_gifs returns the path it wrote; tile the per-dimension GIFs two per row
with open(LatentVisualiser.combine_gifs(gifs, "beta1.gif", 2), 'rb') as f:
    display.display(Image(data=f.read(), format='png'))

## Visualising the latent space of $\beta=10$ (zero vector)

In [0]:
# All-zero starting latents with the same shape, dtype and device as `rnd`.
zrs = torch.zeros_like(rnd)
gifs = []
for dim in range(LATENT_DIM):
    # beta=10 model: sweep latent dimension `dim` upwards from the zero vector
    gifs.append(
        LatentVisualiser(bvae, dim=dim, init_latent=zrs)(
            iters=ITERS, step=STEP, save=f"beta10_dim{dim}_zrs.gif"))

# combine_gifs returns the path it wrote; tile the per-dimension GIFs two per row
with open(LatentVisualiser.combine_gifs(gifs, "beta10_zrs.gif", 2), 'rb') as f:
    display.display(Image(data=f.read(), format='png'))

## Visualising the latent space of $\beta=1$ (zero vector)

In [0]:
gifs = []
for dim in range(LATENT_DIM):
    # beta=1 model: sweep latent dimension `dim` upwards from the zero vector
    gifs.append(
        LatentVisualiser(vae, dim=dim, init_latent=zrs)(
            iters=ITERS, step=STEP, save=f"beta1_dim{dim}_zrs.gif"))

# combine_gifs returns the path it wrote; tile the per-dimension GIFs two per row
with open(LatentVisualiser.combine_gifs(gifs, "beta1_zrs.gif", 2), 'rb') as f:
    display.display(Image(data=f.read(), format='png'))

## Visualising the latent space of $\beta=10$ (zero vector) with decreasing values

Instead of increasing the amount by `STEP`, we decrease it.

In [0]:
gifs = []
for dim in range(LATENT_DIM):
    # beta=10 model: sweep latent dimension `dim` downwards (step = -STEP)
    # from the zero vector
    gifs.append(
        LatentVisualiser(bvae, dim=dim, init_latent=zrs)(
            iters=ITERS, step=-STEP, save=f"beta10_dim{dim}_zrs_n.gif"))

# combine_gifs returns the path it wrote; tile the per-dimension GIFs two per row
with open(LatentVisualiser.combine_gifs(gifs, "beta10_zrs_n.gif", 2), 'rb') as f:
    display.display(Image(data=f.read(), format='png'))

## Visualising the latent space of $\beta=1$ (zero vector) with decreasing values

In [0]:
gifs = []
for dim in range(LATENT_DIM):
    # beta=1 model: sweep latent dimension `dim` downwards (step = -STEP)
    # from the zero vector
    gifs.append(
        LatentVisualiser(vae, dim=dim, init_latent=zrs)(
            iters=ITERS, step=-STEP, save=f"beta1_dim{dim}_zrs_n.gif"))

# combine_gifs returns the path it wrote; tile the per-dimension GIFs two per row
with open(LatentVisualiser.combine_gifs(gifs, "beta1_zrs_n.gif", 2), 'rb') as f:
    display.display(Image(data=f.read(), format='png'))

# Ideas

1. Plot a t-SNE or PCA on the latents to see if similar digits cluster together.

2. What's the difference between $\beta=2$, $\beta=3$, and so on? Is (say) $\beta=4$ any different from $\beta=10$? (Inspect visually; a quantitative comparison probably needs the metric proposed in the paper.)