# CNN Variational Autoencoder on MNIST

## References

* fastai 2022 / 2023 course part II:
    * [notebook 29](https://github.com/fastai/course22p2/blob/master/nbs/29_vae.ipynb)
    * [lesson 25](https://course.fast.ai/Lessons/lesson25.html)
* https://github.com/sksq96/pytorch-vae

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import re
import typing as T
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn import metrics
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, Dataset

import random_neural_net_models.cnn_autoencoder_fastai2022 as cnn_ae
import random_neural_net_models.convolution_lecun1990 as conv_lecun1990

sns.set_theme()

In [None]:
# Fetch the 784-column MNIST table from OpenML (version 1); cache=True reuses
# a previously downloaded copy instead of downloading again.
mnist = fetch_openml("mnist_784", version=1, cache=True, parser="auto")

Setting seeds

In [None]:
# Use one shared seed for PyTorch, the stdlib RNG and NumPy so that runs of
# this notebook are repeatable.
SEED = 42

torch.manual_seed(SEED)

random.seed(SEED)

np.random.seed(SEED)

Getting device

In [None]:
def get_device() -> str:
    """Return ``"cuda"`` when a CUDA device is available, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = get_device()
device

In [None]:
# Split the fetched bunch into the pixel table and the label column, then
# display both shapes as a sanity check.
X = mnist["data"]
y = mnist["target"]
X.shape, y.shape

Selecting a few images to overfit on

In [None]:
# Tiny split for the overfitting experiment: the first n0 rows are the images
# to memorise, the next n1 rows serve as held-out data.
n0 = 10
n1 = 1_000
overfit_rows = slice(None, n0)
heldout_rows = slice(n0, n0 + n1)
X0, y0 = X.iloc[overfit_rows], y.iloc[overfit_rows]
X1, y1 = X.iloc[heldout_rows], y.iloc[heldout_rows]
X0.shape, X1.shape

## Defining dataset and dataloader

In [None]:
# Wrap the overfit subset (X0, y0) and the held-out subset (X1, y1) in the
# project's DigitsDataset.
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_test = conv_lecun1990.DigitsDataset(X1, y1)

In [None]:
# Visual sanity check of the first dataset item: element 0 is drawn as the
# image and element 1 is shown as its label.
item = ds[0]
plt.imshow(item[0], cmap="gray", origin="upper")
plt.title(f"Label: {item[1]}")
plt.axis("off")
plt.tight_layout()

defining a dataloader

In [None]:
# batch_size equals n0 (10), so the overfit set is served as a single batch
# per epoch. Neither loader shuffles, which keeps the batch order fixed.
batch_size = 10
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)
dataloader_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

## Model

In [None]:
class Model(nn.Module):
    """Small convolutional variational autoencoder for batches of 28x28 images.

    The encoder zero-pads the input to 32x32 and applies two stride-2
    convolutions, producing a 4x8x8 feature map. Two linear heads, each
    followed by batch norm, yield ``mu`` and ``logvar``. A latent sample is
    drawn with the reparameterisation trick and decoded by a dense layer and
    two ``cnn_ae.DeConv2d`` blocks, after which the padding is cropped again.

    Args:
        n_latent: size of the latent vector (defaults to 200).
    """

    # https://github.com/sksq96/pytorch-vae/blob/master/vae.py
    # https://github.com/fastai/course22p2/blob/master/nbs/29_vae.ipynb
    def __init__(self, n_latent: int = 200):
        super().__init__()
        ks = 3
        stride = 2
        padding = ks // 2

        self.add_dim = Rearrange("b h w -> b 1 h w")
        self.add_padding = nn.ZeroPad2d(2)
        self.enc_conv1 = nn.Conv2d(
            1, 2, kernel_size=ks, stride=stride, padding=padding
        )
        self.enc_act1 = nn.ReLU()
        # TODO: figure out which batchnorm to use for conv2d
        self.enc_bn1 = nn.BatchNorm2d(num_features=2)
        self.enc_conv2 = nn.Conv2d(
            2, 4, kernel_size=ks, stride=stride, padding=padding
        )
        self.enc_act2 = nn.ReLU()
        self.enc_bn2 = nn.BatchNorm2d(num_features=4)

        nn.init.kaiming_normal_(self.enc_conv1.weight)
        nn.init.kaiming_normal_(self.enc_conv2.weight)

        # Shapes below are (batch, channels, height, width).
        self.encoder = nn.Sequential(
            self.add_dim,  # b x 28 x 28 -> b x 1 x 28 x 28
            self.add_padding,  # b x 1 x 28 x 28 -> b x 1 x 32 x 32
            self.enc_conv1,  # b x 1 x 32 x 32 -> b x 2 x 16 x 16
            self.enc_act1,
            self.enc_bn1,
            self.enc_conv2,  # b x 2 x 16 x 16 -> b x 4 x 8 x 8
            self.enc_act2,
            self.enc_bn2,
        )

        # variational / latent part
        n_conv2 = 4 * 8 * 8
        self.conv2flat = Rearrange("b c h w -> b (c h w)")

        self.mu = nn.Linear(n_conv2, n_latent)
        self.logvar = nn.Linear(n_conv2, n_latent)
        nn.init.kaiming_normal_(self.logvar.weight)
        nn.init.kaiming_normal_(self.mu.weight)

        self.mu_bn = nn.BatchNorm1d(n_latent)
        self.logvar_bn = nn.BatchNorm1d(n_latent)

        self.dec_dense1 = nn.Linear(n_latent, n_conv2)
        self.dec_act1 = nn.ReLU()
        self.dec_bn1 = nn.BatchNorm1d(n_conv2)

        self.flat2conv = Rearrange("b (c h w) -> b c h w", c=4, h=8, w=8)

        self.dec_deconv1 = cnn_ae.DeConv2d(4, 2, kernel_size=ks, stride=1)
        self.dec_act2 = nn.ReLU()
        self.dec_bn2 = nn.BatchNorm2d(num_features=2)

        self.dec_deconv2 = cnn_ae.DeConv2d(2, 1, kernel_size=ks, stride=1)
        # NOTE(review): batch norm is applied after the final Sigmoid, so the
        # model output is not confined to [0, 1] - confirm this is intended.
        self.dec_act3 = nn.Sigmoid()
        self.dec_bn3 = nn.BatchNorm2d(num_features=1)

        # Negative padding crops the 2-pixel border added by add_padding.
        self.rm_padding = nn.ZeroPad2d(-2)
        self.rm_dim = Rearrange("b 1 h w -> b h w")

        # Each deconvolution gets its own Kaiming init (previously
        # dec_deconv1 was initialised twice and dec_deconv2 never).
        nn.init.kaiming_normal_(self.dec_deconv1.weight)
        nn.init.kaiming_normal_(self.dec_deconv2.weight)
        nn.init.kaiming_normal_(self.dec_dense1.weight)

        # The DeConv2d output sizes below assume each block doubles the
        # spatial size - TODO confirm against cnn_ae.DeConv2d.
        self.decoder = nn.Sequential(
            self.dec_dense1,  # b x n_latent -> b x 256
            self.dec_act1,
            self.dec_bn1,
            self.flat2conv,  # b x 256 -> b x 4 x 8 x 8
            self.dec_deconv1,  # b x 4 x 8 x 8 -> b x 2 x 16 x 16
            self.dec_act2,
            self.dec_bn2,
            self.dec_deconv2,  # b x 2 x 16 x 16 -> b x 1 x 32 x 32
            self.rm_padding,  # b x 1 x 32 x 32 -> b x 1 x 28 x 28
            self.dec_act3,
            self.dec_bn3,
            self.rm_dim,  # b x 1 x 28 x 28 -> b x 28 x 28
        )

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: batch of images of shape (batch, 28, 28).

        Returns:
            Tuple ``(x_hat, mu, logvar)``: the reconstruction and the
            batch-normalised mean and log-variance of the latent distribution.
        """
        # encode
        x = self.encoder(x)

        # flatten b x 4 x 8 x 8 -> b x 256
        x = self.conv2flat(x)

        # variational / latent part
        mu = self.mu(x)
        logvar = self.logvar(x)
        mu = self.mu_bn(mu)
        logvar = self.logvar_bn(logvar)
        # Reparameterisation trick: z = mu + eps * std. The noise is drawn
        # in both train and eval mode.
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        z = mu + eps * std

        # decode; the decoder itself projects z back to b x 4 x 8 x 8
        x_hat = self.decoder(z)

        return x_hat, mu, logvar


def calc_distribution_divergence_loss(
    input: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor], x: torch.Tensor
) -> torch.Tensor:
    """Return the divergence term of the VAE loss.

    Computes ``-0.5 * mean(1 + logvar - mu**2 - exp(logvar))`` over all
    elements of ``mu`` / ``logvar``.

    Args:
        input: model output ``(x_hat, mu, logvar)``; ``x_hat`` is ignored.
        x: unused here; accepted so all loss functions share one signature.

    Returns:
        Scalar tensor holding the divergence loss.
    """
    _, mu, logvar = input
    per_element = 1 + logvar - torch.pow(mu, 2) - torch.exp(logvar)
    return -0.5 * torch.mean(per_element)


def calc_reconstruction_loss(
    input: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor], x: torch.Tensor
) -> torch.Tensor:
    """Return the mean squared error between ``x`` and its reconstruction.

    Args:
        input: model output ``(x_hat, mu, logvar)``; only ``x_hat`` is used.
        x: the original images the model was fed.

    Returns:
        Scalar tensor holding the reconstruction loss.
    """
    reconstruction, _mu, _logvar = input
    return F.mse_loss(x, reconstruction)


def calc_vae_loss(
    input: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor], x: torch.Tensor
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the VAE loss and its two components.

    Args:
        input: model output ``(x_hat, mu, logvar)``.
        x: the original images the model was fed.

    Returns:
        Tuple ``(total, reconstruction, divergence)`` where ``total`` is the
        unweighted sum of the other two.
    """
    recon_term = calc_reconstruction_loss(input, x)
    kl_term = calc_distribution_divergence_loss(input, x)
    return recon_term + kl_term, recon_term, kl_term


def calc_vae_test_loss(
    model_output: T.List[T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    x: torch.Tensor,
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the VAE loss over a list of per-batch model outputs.

    The per-batch ``(x_hat, mu, logvar)`` tuples are concatenated along the
    batch dimension and passed to ``calc_vae_loss``, so train and test losses
    are computed by the same code.

    Args:
        model_output: one ``(x_hat, mu, logvar)`` tuple per evaluated batch.
        x: the matching inputs, already concatenated along dimension 0.

    Returns:
        Tuple ``(total, reconstruction, divergence)``.

    Raises:
        ValueError: if ``model_output`` is empty.
    """
    if not model_output:
        raise ValueError("model_output must contain at least one batch")

    # Transpose [(x_hat, mu, logvar), ...] into three per-quantity sequences.
    x_hats, mus, logvars = zip(*model_output)
    merged = (
        torch.cat(x_hats, dim=0),
        torch.cat(mus, dim=0),
        torch.cat(logvars, dim=0),
    )
    return calc_vae_loss(merged, x)

## overfitting

In [None]:
def check_module_name_is_activation(name: str) -> bool:
    """Tell whether a module name ends in ``act`` followed by one digit.

    Names such as ``"act1"`` or ``"enc_act2"`` qualify; ``"act1_bla"`` does
    not, because the digit must be the last character.
    """
    matched = re.match(r".*act\d$", name)
    return matched is not None


print(
    check_module_name_is_activation("act1"),
    check_module_name_is_activation("blub_act1"),
    check_module_name_is_activation("blub"),
    check_module_name_is_activation("act1_bla"),
)

In [None]:
def check_module_name_grad_relevant(name: str) -> bool:
    """Tell whether gradients should be tracked for the module called ``name``.

    Reshaping / cropping helpers, the ``encoder`` / ``decoder`` containers and
    activation modules (names ending in ``act<digit>``) are excluded;
    everything else counts as relevant.
    """
    excluded_names = (
        "rm_dim",
        "rm_padding",
        "conv2flat",
        "flat2conv",
        "encoder",
        "decoder",
    )
    if name in excluded_names:
        return False
    return re.match(r".*act\d$", name) is None


print(
    check_module_name_grad_relevant("rm_dim"),
    check_module_name_grad_relevant("encoder"),
    check_module_name_grad_relevant("decoder"),
    check_module_name_grad_relevant("dec_bn3"),
    check_module_name_grad_relevant("dec_act3"),
)

In [None]:
# Build the VAE and wrap it in the project's telemetry helper. The two
# predicates select modules by attribute name, and loss_names labels the
# three values returned by the loss function.
model = Model()
model = conv_lecun1990.ModelTelemetry(
    model,
    func_is_act=check_module_name_is_activation,
    func_is_grad_relevant=check_module_name_grad_relevant,
    loss_names=("total", "reconstruction", "divergence"),
)
# NOTE(review): double() presumably casts the parameters to float64 to match
# the dtype delivered by the dataset - confirm.
model.double()
model.to(device);

In [None]:
# Alternative optimizer kept for reference:
# opt = SGD(
#     model.parameters(),
#     lr=0.1,
# )
# Adam with learning rate 3e-2 and eps 1e-5 over all model parameters.
opt = Adam(model.parameters(), lr=3e-2, eps=1e-5)
opt

In [None]:
# Training uses the per-batch loss; validation uses the variant that takes a
# list of per-batch model outputs.
loss_func = calc_vae_loss
loss_func_test = calc_vae_test_loss

In [None]:
_iter = 0

In [None]:
n_epochs = 1200

model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    # Training pass: one optimisation step per batch. The labels are ignored
    # because the autoencoder's target is its own input.
    for xb, _ in dataloader:
        xb = xb.to(device)
        # x_pred is the (x_hat, mu, logvar) tuple that loss_func unpacks.
        x_pred = model(xb)

        loss, reconstruction_loss, divergence_loss = loss_func(x_pred, xb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Record losses and parameter statistics against the global
        # iteration counter, which keeps counting across epochs.
        model.loss_history_train(
            (loss, reconstruction_loss, divergence_loss), _iter
        )
        model.parameter_history(_iter)

        _iter += 1

    # Validation pass once per epoch: no gradients, eval mode. The per-batch
    # outputs are collected and a single loss is computed over the whole
    # held-out set.
    with torch.no_grad():
        model.eval()
        xs_pred, xs_true = [], []
        for xb, _ in dataloader_test:
            xb = xb.to(device)

            x_pred = model(xb)
            xs_pred.append(x_pred)
            xs_true.append(xb)

        x_true = torch.cat(xs_true, dim=0)
        (
            loss_test,
            reconstruction_loss_test,
            divergence_loss_test,
        ) = loss_func_test(xs_pred, x_true)

        model.loss_history_test(
            (loss_test, reconstruction_loss_test, divergence_loss_test), _iter
        )
        # Switch back to training mode for the next epoch.
        model.train()

print("Done!")

plotting gradients

In [None]:
model.draw_gradient_stats(yscale="log", figsize=(12, 20))

plotting activations

In [None]:
model.draw_activation_stats(yscale="log")

drawing histograms of the weights and biases across training iterations

In [None]:
model.draw_parameter_stats(
    "enc_conv1",
    "enc_conv2",
    "mu",
    "logvar",
    "dec_dense1",
    "dec_deconv1",
    "dec_deconv2",
)

plotting losses

In [None]:
model.draw_loss_history_train()

In [None]:
model.draw_loss_history_test()  # yscale="log"

In [None]:
model.clean_hooks()

In [None]:
train_features, _ = next(iter(dataloader))

In [None]:
next(iter(dataloader))

In [None]:
model.eval();

inspecting predictions

In [None]:
# Run the model on the training batch. The output is unpacked into three
# values, of which only the first (the reconstruction) is kept.
train_features = train_features.to(device)
preds, _, _ = model(train_features)
# Peek at the top-left 5x5 corner of the first reconstruction.
preds[0, :5, :5]

In [None]:
# The decoder already contains its own Sigmoid stage and the reconstruction
# loss is computed on the raw model output, so the prediction is used as-is;
# applying another sigmoid here would distort the pixel values.
x_pred = preds.detach().cpu().numpy()
x_pred[0, :3, :5]

In [None]:
# Show the first training image next to its reconstruction.
img = train_features[0].cpu()
img_pred = x_pred[0]
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
for ax, picture, title in zip(
    axs, (img, img_pred), ("Input image", "Reconstructed image")
):
    ax.imshow(picture, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

So we can overfit using this setup. Interestingly, there seem to be three stages of optimization; it took about 15k iterations to get there, and there still seems to be some room for improvement. That is more iterations than were needed without the variational / latent component. Other notable differences to the plain autoencoder are:
* overfitting is not achieved within 10k iterations if the `mu` and `logvar` estimates are not fed into a dense layer before being reshaped back into 8x8x4 for deconvolution
* the loss is much noisier with the variational approach

## Reproducing 10 digits

In [None]:
# Hold out 20% of the full data set for validation; random_state fixes the split.
X0, X1, y0, y1 = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
# Rebuild both datasets from the new split; this replaces the small
# overfitting datasets defined earlier under the same names.
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_test = conv_lecun1990.DigitsDataset(X1, y1)

In [None]:
# Shuffled training batches of 256; the test loader uses fixed-order batches
# of 500. drop_last=True discards a trailing incomplete batch in both loaders.
batch_size = 256
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)
dataloader_test = DataLoader(
    ds_test, batch_size=500, shuffle=False, drop_last=True
)

In [None]:
# Fresh model for the full-data run, wrapped in the telemetry helper again.
# NOTE(review): the *_every_n=100 arguments presumably thin out telemetry to
# every 100th iteration - confirm against ModelTelemetry.
model = Model()
model = conv_lecun1990.ModelTelemetry(
    model,
    func_is_act=check_module_name_is_activation,
    func_is_grad_relevant=check_module_name_grad_relevant,
    loss_names=("total", "reconstruction", "divergence"),
    gradients_every_n=100,
    activations_every_n=100,
    parameter_every_n=100,
)
# NOTE(review): double() presumably casts the parameters to float64 to match
# the dtype delivered by the dataset - confirm.
model.double()
model.to(device);

In [None]:
# Alternative optimizer kept for reference:
# opt = SGD(
#     model.parameters(),
#     lr=0.1,
# )
# New Adam instance bound to the fresh model's parameters, with the same
# settings as in the overfitting run.
opt = Adam(model.parameters(), lr=3e-2, eps=1e-5)
opt

In [None]:
_iter = 0

In [None]:
n_epochs = 50

model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    # Training pass: one optimisation step per batch. The labels are ignored
    # because the autoencoder's target is its own input.
    for xb, _ in dataloader:
        xb = xb.to(device)
        # x_pred is the (x_hat, mu, logvar) tuple that loss_func unpacks.
        x_pred = model(xb)

        loss, reconstruction_loss, divergence_loss = loss_func(x_pred, xb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Record losses and parameter statistics against the global
        # iteration counter, which keeps counting across epochs.
        model.loss_history_train(
            (loss, reconstruction_loss, divergence_loss), _iter
        )
        model.parameter_history(_iter)

        _iter += 1

    # Validation pass once per epoch: no gradients, eval mode. The per-batch
    # outputs are collected and a single loss is computed over the whole
    # test set.
    # NOTE(review): Model.forward draws latent noise regardless of train/eval
    # mode, so the validation loss is stochastic.
    with torch.no_grad():
        model.eval()
        xs_pred, xs_true = [], []
        for xb, _ in dataloader_test:
            xb = xb.to(device)

            x_pred = model(xb)
            xs_pred.append(x_pred)
            xs_true.append(xb)

        x_true = torch.cat(xs_true, dim=0)
        (
            loss_test,
            reconstruction_loss_test,
            divergence_loss_test,
        ) = loss_func_test(xs_pred, x_true)

        model.loss_history_test(
            (loss_test, reconstruction_loss_test, divergence_loss_test), _iter
        )
        # Switch back to training mode for the next epoch.
        model.train()

print("Done!")

plotting gradients

In [None]:
model.draw_gradient_stats(yscale="log", figsize=(12, 20))

plotting activations

In [None]:
model.draw_activation_stats(yscale="log")

drawing histograms of the weights and biases across training iterations

In [None]:
model.draw_parameter_stats(
    "enc_conv1",
    "enc_conv2",
    "mu",
    "logvar",
    "dec_dense1",
    "dec_deconv1",
    "dec_deconv2",
)

plotting losses

In [None]:
model.draw_loss_history_train(yscale="log")

In [None]:
model.draw_loss_history_test(yscale="log")

In [None]:
# TODO: why does the validation loss explode initially?

In [None]:
# TODO: enc_act1 and enc_act2 are pretty much 0, why?

In [None]:
test_features, _ = next(iter(dataloader_test))

In [None]:
model.eval();

inspecting predictions

In [None]:
# Run the model on the first test batch. The output is unpacked into three
# values, of which only the first (the reconstruction) is kept.
test_features = test_features.to(device)
preds, _, _ = model(test_features)
# Peek at the top-left 5x5 corner of the first reconstruction.
preds[0, :5, :5]

In [None]:
test_features[0, :3, :5]

In [None]:
# The decoder already contains its own Sigmoid stage and the reconstruction
# loss is computed on the raw model output, so the prediction is used as-is;
# applying another sigmoid here would distort the pixel values.
x_pred = preds.detach().cpu().numpy()
x_pred[0, :3, :5]

In [None]:
def draw_pair(img: torch.Tensor, img_pred: torch.Tensor):
    """Show an input image and its reconstruction side by side in grayscale.

    Args:
        img: the original image, drawn in the left panel.
        img_pred: the reconstructed image, drawn in the right panel.
    """
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    panels = (
        (axs[0], img, "Input image"),
        (axs[1], img_pred, "Reconstructed image"),
    )
    for ax, picture, title in panels:
        ax.imshow(picture, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()


def draw_n_pairs(
    input_features: torch.Tensor,
    x_pred: T.Union[torch.Tensor, np.ndarray],
    n: int = 5,
):
    """Draw up to ``n`` input / reconstruction pairs, one figure per pair.

    Args:
        input_features: batch of input images; each one is moved to the CPU
            before plotting.
        x_pred: batch of reconstructions, as a tensor or a NumPy array.
        n: maximum number of pairs to draw. The actual count is limited by
            the shorter of the two batches and is never negative.
    """
    _n = max(0, min(n, len(input_features), len(x_pred)))
    print(f"Drawing {_n} pairs")
    for img, img_pred in zip(input_features[:_n], x_pred[:_n]):
        if isinstance(img_pred, torch.Tensor):
            # matplotlib cannot plot tensors that live on an accelerator or
            # still require grad.
            img_pred = img_pred.detach().cpu()
        draw_pair(img.detach().cpu(), img_pred)


draw_n_pairs(test_features, x_pred, n=16)

In [None]:
# TODO: what is broken that the reconstruction is not working - yields white blobs?