# CNN Autoencoder on Fashion MNIST
> The only differences from `cnn_autoencoder_fastai2022.ipynb` are the cell loading the data, the number of epochs used to overfit a single image, and the number of epochs used to train on all images. Fashion MNIST seems to need a few more iterations but yields results similar to MNIST.

## References

* fastai 2022 / 2023 course part II:
    * [notebook 8](https://github.com/fastai/course22p2/blob/master/nbs/08_autoencoder.ipynb)
    * [lesson 15](https://course.fast.ai/Lessons/lesson15.html)

## Setup

In [None]:
# IPython autoreload extension, mode 2: reload all modules before executing code,
# so edits to the random_neural_net_models package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import typing as T
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn import metrics
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.optim import SGD
from torch.utils.data import DataLoader, Dataset

import random_neural_net_models.cnn_autoencoder_fastai2022 as cnn_ae
import random_neural_net_models.convolution_lecun1990 as conv_lecun1990

# Apply seaborn's default theme to the matplotlib figures drawn below.
sns.set_theme()

In [None]:
# Download (or load from the local cache) the Fashion-MNIST data set from OpenML.
# The variable keeps the name `mnist` to match the MNIST version of this notebook.
mnist = fetch_openml(
    name="Fashion-MNIST",
    version=1,
    cache=True,
    parser="auto",
)

Setting seeds

In [None]:
# Seed every random number generator used in this notebook (torch, Python's
# `random`, NumPy) with the same value so runs are reproducible.
SEED = 42
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)

Getting device

In [None]:
def get_device() -> str:
    """Return the torch device string to use: "cuda" if a GPU is available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = get_device()
device

In [None]:
# Separate the OpenML result into pixel features and labels, then show their shapes.
X, y = mnist["data"], mnist["target"]
X.shape, y.shape

Selecting a few images (here just one, `n = 1`) to overfit on

In [None]:
# Keep only the first n rows; the model will be trained to reproduce them exactly.
n = 1
X0 = X.iloc[:n]
y0 = y.iloc[:n]
X0.shape

## Defining dataset and dataloader

In [None]:
# Wrap the selected rows (X0, y0) in the project's DigitsDataset.
ds = conv_lecun1990.DigitsDataset(X0, y0)

In [None]:
# Show the first dataset item (image at index 0, label at index 1).
item = ds[0]
image, label = item[0], item[1]
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray", origin="upper")
ax.set_title(f"Label: {label}")
ax.axis("off")
fig.tight_layout()

defining a dataloader

In [None]:
# One sample per batch, in a fixed (unshuffled) order, for the overfitting run.
batch_size = 1
dataloader = DataLoader(dataset=ds, batch_size=batch_size, shuffle=False)

In [None]:
# Shape of the image returned by the dataset for item 0.
item[0].shape

## overfitting

In [None]:
# Fresh autoencoder converted to float64 and moved to the selected device
# (nn.Module.double() and .to() both return the module itself).
model = cnn_ae.Model().double().to(device)

In [None]:
# Plain SGD (no momentum) with a fixed learning rate of 0.1.
opt = SGD(params=model.parameters(), lr=0.1)

In [None]:
# Reconstruction loss: element-wise squared error averaged over all elements.
loss_func = nn.MSELoss(reduction="mean")

In [None]:
def get_hooks(
    model: cnn_ae.Model,
    hook_func: T.Callable = partial(
        conv_lecun1990.append_stats, hist_range=(0, 4)
    ),
) -> T.List[conv_lecun1990.Hook]:
    """Create one ``conv_lecun1990.Hook`` per watched activation layer of ``model``.

    The watched layers are the model attributes ``enc_act1``, ``enc_act2`` and
    ``dec_act1`` (in that order). Each hook is named after the attribute it was
    created for.

    Args:
        model: autoencoder exposing the three activation attributes above.
        hook_func: callback handed to each Hook; by default
            ``conv_lecun1990.append_stats`` with ``hist_range=(0, 4)``
            (presumably the histogram range of recorded activations — confirm
            in conv_lecun1990).

    Returns:
        The hooks, in the same order as the watched layers.
    """
    watched = ("enc_act1", "enc_act2", "dec_act1")
    return [
        conv_lecun1990.Hook(getattr(model, attr), hook_func, name=attr)
        for attr in watched
    ]

In [None]:
# Fresh recorders for this run. `every_n` presumably sets how often a value is
# kept (loss every step, parameters every 10th step) — confirm in conv_lecun1990.
loss_history = conv_lecun1990.LossHistory(every_n=1)
parameter_history = conv_lecun1990.ParameterHistory(every_n=10)
# Attach activation hooks to the current model.
hooks = get_hooks(model)

In [None]:
# Overfit the tiny data set: many epochs over the same single image.
n_epochs = 20_000
_iter = 0  # global step counter passed to the history recorders
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs"):
    for i, (xb, _) in enumerate(dataloader):
        # Clear gradients left over from the previous step before computing new ones.
        opt.zero_grad()

        xb = xb.to(device)
        # Autoencoder objective: the target is the input itself.
        x_pred = model(xb)
        loss = loss_func(x_pred, xb)
        loss.backward()
        opt.step()

        # The recorders are called every step (after the update).
        parameter_history(model, _iter)
        loss_history(loss, _iter)
        _iter += 1

print("Done!")

plotting the loss

In [None]:
# Plot the recorded training loss of the overfitting run.
conv_lecun1990.draw_loss(loss_history)

plotting parameters

In [None]:
# Plot the recorded parameter history of each convolution / deconvolution layer.
conv_lecun1990.draw_history(parameter_history, "enc_conv1")
conv_lecun1990.draw_history(parameter_history, "enc_conv2")
conv_lecun1990.draw_history(parameter_history, "dec_deconv1")
conv_lecun1990.draw_history(parameter_history, "dec_deconv2")

plotting activations

In [None]:
# Plot the activation statistics collected by the hooks during training.
conv_lecun1990.draw_activations(hooks)

In [None]:
# Clear the activation hooks before running inference below.
conv_lecun1990.clear_hooks(hooks)

In [None]:
# First batch of the overfitting dataloader (with n = 1 and batch_size = 1, the only image).
train_features, _ = next(iter(dataloader))

In [None]:
# Switch to evaluation mode for inference; the trailing ";" suppresses the module repr.
model.eval();

inspecting predictions

In [None]:
# Inference only: no_grad avoids building an autograd graph for the predictions.
with torch.no_grad():
    train_features = train_features.to(device)
    preds = model(train_features)
preds[0, :5, :5]

In [None]:
# Copy the reconstructions to host memory as a NumPy array for plotting.
x_pred = preds.detach().cpu().numpy()
x_pred[0, :3, :5]

In [None]:
# Side-by-side view of the training image and its reconstruction.
img = train_features[0].cpu()
img_pred = x_pred[0]
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
panels = [(img, "Input image"), (img_pred, "Reconstructed image")]
for ax, (panel, caption) in zip(axs, panels):
    ax.imshow(panel, cmap="gray")
    ax.set_title(caption)
    ax.axis("off")
plt.show()

So this setup can also overfit a single Fashion MNIST image, although the reconstruction looks slightly blurry.

## Reproducing all items

In [None]:
# 80/20 train/test split with a fixed random_state for reproducibility.
X0, X1, y0, y1 = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
# Wrap the train and test splits in the project's DigitsDataset.
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_test = conv_lecun1990.DigitsDataset(X1, y1)

In [None]:
batch_size = 256
# Reshuffle training batches every epoch; keep the test order fixed for evaluation.
dataloader = DataLoader(dataset=ds, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(dataset=ds_test, batch_size=500, shuffle=False)

In [None]:
# Re-initialize a fresh float64 autoencoder on the selected device for the full-data run
# (nn.Module.double() and .to() both return the module itself).
model = cnn_ae.Model().double().to(device)

In [None]:
# New optimizer bound to the new model's parameters: plain SGD, lr = 0.1.
opt = SGD(params=model.parameters(), lr=0.1)

In [None]:
# Mean squared reconstruction error over all elements.
loss_func = nn.MSELoss(reduction="mean")

In [None]:
# Recorders for the full-data run: training loss, per-epoch test loss and parameters.
# `every_n` presumably sets the recording interval — confirm in conv_lecun1990.
loss_history = conv_lecun1990.LossHistory(every_n=1)
loss_history_test = conv_lecun1990.LossHistory(every_n=1)
parameter_history = conv_lecun1990.ParameterHistory(every_n=10)
# Attach activation hooks to the new model.
hooks = get_hooks(model)

In [None]:
def _evaluate_loss(
    model: nn.Module,
    loader: DataLoader,
    loss_func: T.Callable,
    device: str,
) -> torch.Tensor:
    """Return the reconstruction loss of ``model`` over every batch in ``loader``.

    Assumes ``loss_func`` is a mean-reduced element-wise loss (like the
    ``nn.MSELoss()`` used here). Each batch loss is weighted by its element
    count, so the result equals the loss over the concatenated data set without
    keeping all predictions and targets on the device at once.

    The model is put in eval mode for the pass and restored to its previous
    mode afterwards, even if an exception is raised.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    total, n_elements = 0.0, 0
    try:
        with torch.no_grad():
            for xb, _ in loader:
                xb = xb.to(device)
                # mean loss * element count == summed loss for this batch
                total = total + loss_func(model(xb), xb) * xb.numel()
                n_elements += xb.numel()
    finally:
        model.train(was_training)
    if n_elements == 0:
        raise ValueError("loader yielded no batches; cannot compute a loss")
    return total / n_elements


n_epochs = 30
_iter = 0  # global step counter passed to the history recorders
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for xb, _ in tqdm.tqdm(dataloader, desc="Batches", total=len(dataloader)):
        xb = xb.to(device)
        x_pred = model(xb)

        # Autoencoder objective: the target is the input itself.
        loss = loss_func(x_pred, xb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        parameter_history(model, _iter)
        loss_history(loss, _iter)

        _iter += 1

    # Test loss once per epoch, logged at the current training step.
    # NOTE(review): the activation hooks from get_hooks are still attached here, so
    # these eval-mode forward passes presumably also feed the activation statistics —
    # confirm against conv_lecun1990.Hook.
    loss_test = _evaluate_loss(model, dataloader_test, loss_func, device)
    loss_history_test(loss_test, _iter)

print("Done!")

plotting the loss

In [None]:
# Plot the training loss and, on the same call sequence, the per-epoch test loss.
conv_lecun1990.draw_loss(loss_history)
conv_lecun1990.draw_loss(loss_history_test, label="Test")

plotting parameters

In [None]:
# Plot the recorded parameter history of each convolution / deconvolution layer.
conv_lecun1990.draw_history(parameter_history, "enc_conv1")
conv_lecun1990.draw_history(parameter_history, "enc_conv2")
conv_lecun1990.draw_history(parameter_history, "dec_deconv1")
conv_lecun1990.draw_history(parameter_history, "dec_deconv2")

plotting activations

In [None]:
# Plot the activation statistics collected by the hooks.
conv_lecun1990.draw_activations(hooks)

In [None]:
# Clear the activation hooks before running inference below.
conv_lecun1990.clear_hooks(hooks)

In [None]:
# First (unshuffled) batch of the test loader, used for visual inspection.
test_features, _ = next(iter(dataloader_test))

In [None]:
# Switch to evaluation mode for inference; the trailing ";" suppresses the module repr.
model.eval();

inspecting predictions

In [None]:
# Inference only: no_grad avoids building an autograd graph for the predictions.
with torch.no_grad():
    test_features = test_features.to(device)
    preds = model(test_features)
preds[0, :5, :5]

In [None]:
# Copy the test reconstructions to host memory as a NumPy array for plotting.
x_pred = preds.detach().cpu().numpy()
x_pred[0, :3, :5]

In [None]:
def draw_pair(
    img: T.Union[torch.Tensor, np.ndarray],
    img_pred: T.Union[torch.Tensor, np.ndarray],
):
    """Show an input image and its reconstruction side by side, then display the figure.

    Args:
        img: input image, anything ``plt.imshow`` accepts (draw_n_pairs passes
            a CPU tensor).
        img_pred: reconstructed image (draw_n_pairs passes a row of the NumPy
            array ``x_pred``).
    """
    _, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    panels = ((img, "Input image"), (img_pred, "Reconstructed image"))
    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()


def draw_n_pairs(
    input_features: torch.Tensor,
    x_pred: T.Union[torch.Tensor, np.ndarray],
    n: int = 5,
):
    """Draw up to ``n`` (input, reconstruction) pairs, one figure per pair.

    The number of pairs is clamped to ``[0, min(len(input_features), len(x_pred))]``.
    A negative ``n`` therefore draws nothing (and reports 0), and mismatched
    lengths never index past either sequence.

    Args:
        input_features: batch of input images; each is moved to the CPU before plotting.
        x_pred: matching batch of reconstructions (a NumPy array in this notebook).
        n: maximum number of pairs to draw.
    """
    _n = max(0, min(n, len(input_features), len(x_pred)))
    print(f"Drawing {_n} pairs")
    for img, img_pred in zip(input_features[:_n], x_pred[:_n]):
        draw_pair(img.cpu(), img_pred)


# Show up to 16 test images next to their reconstructions.
draw_n_pairs(test_features, x_pred, n=16)

As with MNIST, training on Fashion MNIST seems to go through phases where the model parameters / loss plateau before improving again. The reconstructed images are also a little blurry, but subjectively less so than in the [fashion mnist example used in the lectures](https://github.com/fastai/course22p2/blob/master/nbs/08_autoencoder.ipynb). So the lecture notebook probably just did not train for long enough.