# CNN Variational Autoencoder (VAE) on MNIST

## References

* fastai 2022 / 2023 course part II:
    * [notebook 29](https://github.com/fastai/course22p2/blob/master/nbs/29_vae.ipynb)
    * [lesson 25](https://course.fast.ai/Lessons/lesson25.html)
* https://github.com/sksq96/pytorch-vae

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import typing as T
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn import metrics
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.optim import SGD
from torch.utils.data import DataLoader, Dataset

import random_neural_net_models.cnn_autoencoder_fastai2022 as cnn_ae
import random_neural_net_models.convolution_lecun1990 as conv_lecun1990

sns.set_theme()

In [None]:
mnist = fetch_openml("mnist_784", version=1, cache=True, parser="auto")

Setting seeds

In [None]:
torch.manual_seed(42)

random.seed(42)

np.random.seed(42)

Getting device

In [None]:
def get_device() -> str:
    """Return the name of the torch device to run on.

    Returns:
        ``"cuda"`` if ``torch.cuda.is_available()`` reports a usable GPU,
        otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = get_device()
device

In [None]:
X = mnist["data"]
y = mnist["target"]
X.shape, y.shape

Selecting a small number of images (here `n = 1`, i.e. a single image) to overfit on

In [None]:
n = 1
X0, y0 = X.iloc[:n], y.iloc[:n]
X0.shape

## Defining dataset and dataloader

In [None]:
ds = conv_lecun1990.DigitsDataset(X0, y0)

In [None]:
item = ds[0]
plt.imshow(item[0], cmap="gray", origin="upper")
plt.title(f"Label: {item[1]}")
plt.tight_layout()

defining a dataloader

In [None]:
batch_size = 1
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)

In [None]:
item[0].shape

## Model

In [None]:
class Model(nn.Module):
    """Convolutional variational autoencoder (VAE) for single-channel images.

    Structure, in the order the data flows through ``forward``:

    * encoder: add a channel axis, zero-pad by 2 pixels per side, then two
      stride-2 3x3 convolutions, each followed by a ReLU
    * latent part: flatten, predict ``mu`` and ``logvar`` with two linear
      layers, sample ``mu + eps * std`` and project the sample back with a
      third linear layer
    * decoder: two ``cnn_ae.DeConv2d`` layers, crop the padding again and
      squash the result with a sigmoid

    The flattened feature size is hard-coded to ``4 * 8 * 8``, which matches
    the encoder output for 28x28 inputs (see the shape comments below).
    """

    # https://github.com/sksq96/pytorch-vae/blob/master/vae.py
    # https://github.com/fastai/course22p2/blob/master/nbs/29_vae.ipynb
    def __init__(self):
        """Build the encoder, latent and decoder layers.

        Every layer is stored both as a named attribute (``enc_conv1``,
        ``mu``, ...) and inside the ``encoder`` / ``decoder`` containers; the
        named attributes are what the hooks and the parameter plots in this
        notebook refer to.
        """
        super(Model, self).__init__()
        ks = 3
        stride = 2
        padding = ks // 2  # "same"-style padding for an odd kernel size

        self.add_dim = Rearrange("b h w -> b 1 h w")
        self.add_padding = nn.ZeroPad2d(2)
        self.enc_conv1 = nn.Conv2d(
            1, 2, kernel_size=ks, stride=stride, padding=padding
        )
        self.enc_act1 = nn.ReLU()
        self.enc_conv2 = nn.Conv2d(
            2, 4, kernel_size=ks, stride=stride, padding=padding
        )
        self.enc_act2 = nn.ReLU()

        # Shapes below are per sample, as (channels, height, width), for a
        # 28x28 input image.
        self.encoder = nn.Sequential(
            self.add_dim,  # (28, 28) -> (1, 28, 28)
            self.add_padding,  # (1, 28, 28) -> (1, 32, 32)
            self.enc_conv1,  # (1, 32, 32) -> (2, 16, 16)
            self.enc_act1,
            self.enc_conv2,  # (2, 16, 16) -> (4, 8, 8)
            self.enc_act2,
        )

        # variational / latent part
        n_conv2 = 4 * 8 * 8  # flattened encoder output: 4 channels of 8x8
        n_latent = n_conv2  # latent size equals the flattened size (256)
        self.conv2flat = Rearrange("b c h w -> b (c h w)")
        self.mu = nn.Linear(n_conv2, n_latent)
        self.logvar = nn.Linear(n_conv2, n_latent)
        # dense layer applied to the sampled latent before reshaping it back
        self.latent2conv = nn.Linear(n_latent, n_conv2)
        self.flat2conv = Rearrange("b (c h w) -> b c h w", c=4, h=8, w=8)

        # NOTE(review): cnn_ae.DeConv2d is defined outside this notebook; the
        # decoder shape comments below assume each DeConv2d doubles height and
        # width - confirm against its implementation.
        self.dec_deconv1 = cnn_ae.DeConv2d(4, 2, kernel_size=ks, stride=1)
        self.dec_act1 = nn.ReLU()
        self.dec_deconv2 = cnn_ae.DeConv2d(2, 1, kernel_size=ks, stride=1)
        self.dec_act2 = nn.Sigmoid()
        # negative padding crops: removes the 2 pixels per side added above
        self.rm_padding = nn.ZeroPad2d(-2)
        self.rm_dim = Rearrange("b 1 h w -> b h w")

        self.decoder = nn.Sequential(
            self.dec_deconv1,  # (4, 8, 8) -> (2, 16, 16)
            self.dec_act1,
            self.dec_deconv2,  # (2, 16, 16) -> (1, 32, 32)
            self.rm_padding,  # (1, 32, 32) -> (1, 28, 28)
            self.dec_act2,
            self.rm_dim,  # (1, 28, 28) -> (28, 28)
        )

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it again.

        Args:
            x: batch of images without a channel axis (the first encoder
                step rearranges ``b h w -> b 1 h w``).

        Returns:
            Tuple ``(x_hat, mu, logvar)``: the decoder output with the
            channel axis removed again, and the two latent estimates as
            produced by the ``mu`` and ``logvar`` linear layers.
        """
        # encode
        x = self.encoder(x)

        # flatten (b, 4, 8, 8) to (b, 256)
        x = self.conv2flat(x)

        # variational / latent part
        mu = self.mu(x)
        logvar = self.logvar(x)
        std = (0.5 * logvar).exp()  # std = sqrt(exp(logvar))
        # Reparameterization: noise is drawn on every call. There is no check
        # of self.training, so the output is stochastic in eval mode as well.
        eps = torch.randn_like(std)
        x = mu + eps * std

        # project back from (b, n_latent) to (b, 4, 8, 8)
        x = self.latent2conv(x)
        x = self.flat2conv(x)

        # decode
        x = self.decoder(x)

        return x, mu, logvar


def calc_distribution_divergence_loss(
    input: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor], x: torch.Tensor
) -> torch.Tensor:
    """Compute the divergence term of the VAE loss from ``mu`` and ``logvar``.

    Evaluates ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))``, i.e. the
    mean is taken over all elements rather than summed per sample.

    Args:
        input: model output ``(x_hat, mu, logvar)``; ``x_hat`` is ignored.
        x: target batch. Unused; kept so all loss functions share one
            signature.

    Returns:
        Scalar tensor with the divergence loss.
    """
    _x_hat, mean, log_variance = input
    per_element = 1 + log_variance - mean.pow(2) - torch.exp(log_variance)
    return torch.mean(per_element) * -0.5


def calc_reconstruction_loss(
    input: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor], x: torch.Tensor
) -> torch.Tensor:
    """Compute the mean squared error between reconstruction and target.

    Args:
        input: model output ``(x_hat, mu, logvar)``; only ``x_hat`` is used.
        x: target batch the reconstruction is compared against.

    Returns:
        Scalar tensor with the element-wise mean squared error.
    """
    x_hat, _, _ = input
    # F.mse_loss expects (input, target): prediction first, ground truth
    # second. The value is symmetric, so this only fixes the argument roles.
    return F.mse_loss(x_hat, x)


def calc_vae_loss(
    input: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor], x: torch.Tensor
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the total VAE loss and its two components.

    Args:
        input: model output ``(x_hat, mu, logvar)``.
        x: target batch.

    Returns:
        Tuple ``(total, reconstruction, divergence)`` where ``total`` is the
        unweighted sum of the other two.
    """
    components = (
        calc_reconstruction_loss(input, x),
        calc_distribution_divergence_loss(input, x),
    )
    return (components[0] + components[1], *components)

## overfitting

In [None]:
model = Model()
model.double()
model.to(device);

In [None]:
opt = SGD(
    model.parameters(),
    lr=0.1,
)

In [None]:
loss_func = calc_vae_loss

In [None]:
def get_hooks(
    model: nn.Module,
    hook_func: T.Callable = partial(
        conv_lecun1990.append_stats, hist_range=(0, 4)
    ),
) -> T.List[conv_lecun1990.Hook]:
    """Create one ``conv_lecun1990.Hook`` per tracked activation layer.

    The tracked layers are the attributes ``enc_act1``, ``enc_act2`` and
    ``dec_act1`` of ``model``; the final sigmoid (``dec_act2``) is not
    tracked.

    Args:
        model: module exposing the three activation attributes named above,
            e.g. the ``Model`` class defined in this notebook.
        hook_func: callable handed to every ``Hook``. Defaults to
            ``conv_lecun1990.append_stats`` with ``hist_range=(0, 4)``.

    Returns:
        The hooks, in the order ``enc_act1``, ``enc_act2``, ``dec_act1``.

    Raises:
        AttributeError: if ``model`` lacks one of the activation attributes.
    """
    act_names = ["enc_act1", "enc_act2", "dec_act1"]
    # Resolve every layer before creating any hook, so a missing attribute
    # fails before any Hook object has been constructed.
    layers = [getattr(model, name) for name in act_names]
    return [
        conv_lecun1990.Hook(layer, hook_func, name=name)
        for name, layer in zip(act_names, layers)
    ]

In [None]:
class ParameterHistory:
    """Record histograms of absolute parameter values during training.

    Calling the instance with a model and an iteration number stores, for
    every entry of ``model.state_dict()``, a histogram of the entry's
    absolute values. The histograms are kept per state-dict key in
    ``history`` and the recorded iteration numbers in ``iter``.
    """

    def __init__(
        self,
        every_n: int = 1,
        hist_bins: int = 80,
        hist_range: T.Tuple[float, float] = (0.0, 2.0),
    ):
        """Configure the recorder.

        Args:
            every_n: record only iterations that are a multiple of this.
            hist_bins: number of histogram bins.
            hist_range: ``(min, max)`` of the histogram, applied to the
                absolute values.
        """
        # state-dict key -> list of count arrays, one per recorded iteration
        self.history = defaultdict(list)
        self.every_n = every_n
        # iteration numbers that were actually recorded
        self.iter = []
        self.hist_bins = hist_bins
        self.hist_range = hist_range

    def __call__(self, model: nn.Module, _iter: int):
        """Record one histogram per state-dict entry of ``model``.

        Does nothing when ``_iter`` is not a multiple of ``every_n``.

        Args:
            model: module whose ``state_dict()`` is inspected.
            _iter: current iteration number.
        """
        if _iter % self.every_n != 0:
            return

        lower, upper = self.hist_range[0], self.hist_range[1]
        for name, tensor in model.state_dict().items():
            values = tensor.detach().cpu()
            # State dicts can hold integer buffers (e.g. BatchNorm's
            # num_batches_tracked); cast them so histc accepts them.
            if not values.is_floating_point():
                values = values.double()
            # abs() already allocates a new tensor, so no clone is needed to
            # keep the model's parameters untouched.
            counts = (
                values.abs().flatten().histc(self.hist_bins, lower, upper)
            )
            self.history[name].append(counts.numpy())

        self.iter.append(_iter)

    def get_df(self, name: str) -> pd.DataFrame:
        """Return the recorded histograms of one entry as a long DataFrame.

        Args:
            name: state-dict key, e.g. ``"enc_conv1.weight"``.

        Returns:
            DataFrame with the columns ``iter`` and ``value``; ``value``
            holds the bin counts, with ``hist_bins`` rows per recorded
            iteration.

        Raises:
            ValueError: if nothing has been recorded under ``name``.
        """
        # .get() avoids inserting an empty list into the defaultdict when
        # the name is unknown.
        recorded = self.history.get(name)
        if not recorded:
            raise ValueError(f"No history recorded for parameter '{name}'")
        frames = [
            pd.DataFrame({"value": counts}).assign(iter=step)
            for step, counts in zip(self.iter, recorded)
        ]
        return pd.concat(frames, ignore_index=True)[["iter", "value"]]

In [None]:
loss_history = conv_lecun1990.LossHistory(every_n=1)
loss_history_reconstruction = conv_lecun1990.LossHistory(every_n=1)
divergence_loss_history = conv_lecun1990.LossHistory(every_n=1)
parameter_history = ParameterHistory(every_n=1, hist_range=(0, 2))
hooks = get_hooks(model)

In [None]:
# Overfit the model on the tiny dataset defined above.
n_epochs = 15_000
_iter = 0  # global step counter handed to the history recorders
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for xb, _ in dataloader:  # labels are not needed for reconstruction
        xb = xb.to(device)
        x_pred = model(xb)  # tuple (x_hat, mu, logvar)

        loss, reconstruction_loss, divergence_loss = loss_func(x_pred, xb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # record parameters and all three loss values after the update step
        parameter_history(model, _iter)
        loss_history(loss, _iter)
        loss_history_reconstruction(reconstruction_loss, _iter)
        divergence_loss_history(divergence_loss, _iter)

        _iter += 1

print("Done!")

plotting the loss

In [None]:
conv_lecun1990.draw_loss(loss_history, label="Train", window=100)
conv_lecun1990.draw_loss(
    loss_history_reconstruction, label="Train (reconstruction)", window=100
)
conv_lecun1990.draw_loss(
    divergence_loss_history, label="Train (divergence)", window=100
)

plotting parameters

In [None]:
def stack_weight_history(
    history: ParameterHistory, name: str, suffix: str, log1p: bool = True
) -> np.ndarray:
    """Stack the recorded histograms of one parameter into a 2D array.

    Args:
        history: recorder whose ``history`` mapping is read.
        name: layer name, e.g. ``"enc_conv1"``.
        suffix: parameter kind, e.g. ``"weight"`` or ``"bias"``; the lookup
            key is ``f"{name}.{suffix}"``.
        log1p: if true, apply ``np.log1p`` to the stacked counts.

    Returns:
        Array with one column per recorded histogram.
    """
    key = f"{name}.{suffix}"
    stacked = np.column_stack(history.history[key])
    return np.log1p(stacked) if log1p else stacked


def draw_history(
    history: ParameterHistory,
    name: str,
    figsize: T.Tuple[int, int] = (12, 4),
    hist_aspect_w: float = 25.0,
    hist_aspect_b: float = 25.0,
    log1p: bool = False,
) -> None:
    """Plot the histogram history of a layer's weight and bias.

    Draws two stacked panels sharing the x axis: the weight histograms on
    top and the bias histograms below, each shown as an image.

    Args:
        history: recorder holding the histograms.
        name: layer name; ``f"{name}.weight"`` and ``f"{name}.bias"`` are
            looked up.
        figsize: figure size passed to ``plt.subplots``.
        hist_aspect_w: ``aspect`` passed to ``imshow`` for the weight panel.
        hist_aspect_b: ``aspect`` passed to ``imshow`` for the bias panel.
        log1p: forwarded to ``stack_weight_history``.
    """
    _fig, axs = plt.subplots(figsize=figsize, nrows=2, sharex=True)

    panels = (("weight", hist_aspect_w), ("bias", hist_aspect_b))
    for ax, (suffix, aspect) in zip(axs, panels):
        counts = stack_weight_history(
            history, name=name, suffix=suffix, log1p=log1p
        )
        ax.imshow(counts, aspect=aspect, origin="lower")
        ax.set_axis_off()
        ax.set_title(f"{name} - {suffix}")

    plt.tight_layout()
    plt.show()

In [None]:
draw_history(parameter_history, "enc_conv1")
draw_history(parameter_history, "enc_conv2")
draw_history(parameter_history, "mu")
draw_history(parameter_history, "logvar")
draw_history(parameter_history, "latent2conv")
draw_history(parameter_history, "dec_deconv1")
draw_history(parameter_history, "dec_deconv2")

plotting activations

In [None]:
conv_lecun1990.draw_activations(hooks)

In [None]:
conv_lecun1990.clear_hooks(hooks)

In [None]:
train_features, _ = next(iter(dataloader))

In [None]:
model.eval();

inspecting predictions

In [None]:
train_features = train_features.to(device)
preds, _, _ = model(train_features)
preds[0, :5, :5]

In [None]:
x_pred = preds.to("cpu").detach().numpy()
x_pred[0, :3, :5]

In [None]:
img = train_features[0].cpu()
img_pred = x_pred[0]
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax = axs[0]
ax.imshow(img, cmap="gray")
ax.set_title("Input image")
ax.axis("off")
ax = axs[1]
ax.imshow(img_pred, cmap="gray")
ax.set_title("Reconstructed image")
ax.axis("off")
plt.show()

So we can overfit using this setup. Interestingly, there seem to be 3 stages of optimization. It took about 15k iterations to get there, and there still seems to be some room for improvement, so this setup needs more iterations than the one without the variational / latent component. Other notable differences from the plain autoencoder are:
* overfitting is not achieved within 10k iterations if the `mu` and `logvar` estimates are not fed into a dense layer before being reshaped back into 8x8x4 for deconvolution
* the loss is much noisier with the variational approach

## Reproducing 10 digits

In [None]:
X0, X1, y0, y1 = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_test = conv_lecun1990.DigitsDataset(X1, y1)

In [None]:
batch_size = 256
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(ds_test, batch_size=500, shuffle=False)

In [None]:
model = Model()
model.double()
model.to(device);

In [None]:
opt = SGD(
    model.parameters(),
    lr=0.1,
)

In [None]:
loss_history = conv_lecun1990.LossHistory(every_n=1)
loss_history_reconstruction = conv_lecun1990.LossHistory(every_n=1)
divergence_loss_history = conv_lecun1990.LossHistory(every_n=1)

loss_history_test = conv_lecun1990.LossHistory(every_n=1)
loss_history_reconstruction_test = conv_lecun1990.LossHistory(every_n=1)
loss_history_divergence_test = conv_lecun1990.LossHistory(every_n=1)

parameter_history = ParameterHistory(every_n=1)
hooks = get_hooks(model)

In [None]:
def calc_vae_test_loss(
    model_output: T.List[T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    x: torch.Tensor,
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the VAE loss over a list of per-batch model outputs.

    Args:
        model_output: one ``(x_hat, mu, logvar)`` tuple per batch, in the
            same order as the batches concatenated in ``x``.
        x: all target batches concatenated along dimension 0.

    Returns:
        Tuple ``(total, reconstruction, divergence)``, as returned by
        ``calc_vae_loss`` on the concatenated outputs.
    """
    # Turn the list of per-batch tuples into three full tensors, one per
    # tuple position.
    x_hat, mu, logvar = (
        torch.cat([batch_out[pos] for batch_out in model_output], dim=0)
        for pos in range(3)
    )
    # Delegate to the training loss so both share one definition.
    return calc_vae_loss((x_hat, mu, logvar), x)

In [None]:
_iter = 0

In [None]:
# Train on the full training split and evaluate on the test split after
# every epoch. `_iter` is initialised in its own cell above, so re-running
# this cell continues the step counter instead of resetting it.
n_epochs = 100

model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for xb, _ in dataloader:  # labels are not needed for reconstruction
        xb = xb.to(device)
        x_pred = model(xb)  # tuple (x_hat, mu, logvar)

        loss, reconstruction_loss, divergence_loss = loss_func(x_pred, xb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # record parameters and all three loss values after the update step
        parameter_history(model, _iter)
        loss_history(loss, _iter)
        loss_history_reconstruction(reconstruction_loss, _iter)
        divergence_loss_history(divergence_loss, _iter)

        _iter += 1

    # compute validation loss
    with torch.no_grad():
        # NOTE: Model.forward samples the latent noise regardless of the
        # train/eval mode, so the test losses are stochastic as well.
        model.eval()
        xs_pred, xs_true = [], []
        for xb, _ in dataloader_test:
            xb = xb.to(device)

            x_pred = model(xb)
            # keep the outputs and inputs of every batch; the loss is
            # computed once over the whole test set below
            xs_pred.append(x_pred)
            xs_true.append(xb)

        x_true = torch.cat(xs_true, dim=0)
        (
            loss_test,
            reconstruction_loss_test,
            divergence_loss_test,
        ) = calc_vae_test_loss(xs_pred, x_true)

        # logged at the step count reached at the end of the epoch
        loss_history_test(loss_test, _iter)
        loss_history_reconstruction_test(reconstruction_loss_test, _iter)
        loss_history_divergence_test(divergence_loss_test, _iter)
        model.train()  # back to training mode for the next epoch

print("Done!")

plotting the loss

In [None]:
conv_lecun1990.draw_loss(loss_history)
conv_lecun1990.draw_loss(loss_history_reconstruction, label="Reconstruction")
conv_lecun1990.draw_loss(divergence_loss_history, label="Divergence")

In [None]:
conv_lecun1990.draw_loss(loss_history_test, label="Test")
conv_lecun1990.draw_loss(
    loss_history_reconstruction_test, label="Test (Reconstruction)"
)
conv_lecun1990.draw_loss(
    loss_history_divergence_test, label="Test (Divergence)"
)

plotting parameters

In [None]:
draw_history(parameter_history, "enc_conv1")
draw_history(parameter_history, "enc_conv2")
draw_history(parameter_history, "mu")
draw_history(parameter_history, "logvar")
draw_history(parameter_history, "latent2conv")
draw_history(parameter_history, "dec_deconv1")
draw_history(parameter_history, "dec_deconv2")

plotting activations

In [None]:
conv_lecun1990.draw_activations(hooks)

In [None]:
# TODO: investigate why the enc_act1 and enc_act2 activations are almost all 0

In [None]:
conv_lecun1990.clear_hooks(hooks)

In [None]:
test_features, _ = next(iter(dataloader_test))

In [None]:
model.eval();

inspecting predictions

In [None]:
test_features = test_features.to(device)
preds, _, _ = model(test_features)
preds[0, :5, :5]

In [None]:
test_features[0, :3, :5]

In [None]:
x_pred = preds.to("cpu").detach().numpy()
x_pred[0, :3, :5]

In [None]:
def draw_pair(img: torch.Tensor, img_pred: torch.Tensor):
    """Show an input image and its reconstruction side by side.

    Args:
        img: input image, drawn in the left panel.
        img_pred: reconstructed image, drawn in the right panel.
    """
    _fig, (ax_input, ax_pred) = plt.subplots(
        nrows=1, ncols=2, figsize=(10, 5)
    )
    panels = (
        (ax_input, img, "Input image"),
        (ax_pred, img_pred, "Reconstructed image"),
    )
    for ax, picture, title in panels:
        ax.imshow(picture, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()


def draw_n_pairs(
    input_features: torch.Tensor, x_pred: torch.Tensor, n: int = 5
):
    """Draw the first ``n`` input / reconstruction pairs.

    Args:
        input_features: batch of input images; each drawn image is moved to
            the CPU first.
        x_pred: reconstructions, indexed in parallel with
            ``input_features``.
        n: maximum number of pairs to draw; capped at
            ``len(input_features)``.
    """
    n_pairs = min(n, len(input_features))
    print(f"Drawing {n_pairs} pairs")
    for idx in range(n_pairs):
        draw_pair(input_features[idx].cpu(), x_pred[idx])


draw_n_pairs(test_features, x_pred, n=16)

In [None]:
# TODO: find out why the reconstruction is not working - it yields white blobs instead of digits