# First convolution-based neural net

## References

* LeCun et al. 1990, _Handwritten Digit Recognition: Applications of Neural Net Chips and Automatic Learning_, [Neurocomputing](https://link.springer.com/chapter/10.1007/978-3-642-76153-9_35)

In [None]:
# IPython magics: automatically reload imported modules before executing code
%load_ext autoreload
%autoreload 2

In [None]:
import random
import typing as T
from collections import defaultdict

# plot first item in dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

# load mnist using scikit-learn
from sklearn.datasets import fetch_openml
from torch.optim import SGD
from torch.utils.data import DataLoader, Dataset

sns.set_theme()  # use seaborn's default styling for all subsequent plots

## LeCun et al. 1990, "Handwritten Digit Recognition: Applications of Neural Net Chips and Automatic Learning"
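> The following tries to reproduce the original paper. Note that the digits dataset actually used in the paper could not be found, so [MNIST 784](https://www.openml.org/search?type=data&status=active&id=554) is used instead. MNIST images are 28 x 28 pixels rather than 16 x 16, so with a padding of 2 the two convolutions below produce 14 x 14 and 7 x 7 feature maps instead of 8 x 8 and 4 x 4.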

Specifics from the paper:

* neural net
    * weight initialization: uniformly at random $\in [-2.4 / F_i, 2.4 / F_i]$ with $F_i = $ number of inputs of the unit
    * "tanh activation": $A \cdot \tanh (S \cdot a)$ with $A = 1.716$, $S = 2/3$ and $a = \text{weights} \cdot \text{input}$
    * 256 input (16 x 16 pixel images)
    * layer #1: 
        * convolution with 12 5x5-kernels and stride 2 (output: 8 x 8 x 12 = 768 "units")
        * tanh activation
        * $F_i = 5 \cdot 5 \cdot n_\text{input-channels} = 5 \cdot 5 \cdot 1 = 25$
    * layer #2: 
        * convolution with 12 5x5-kernels and stride 2 (output: 4 x 4 x 12 = 192 "units")
        * tanh activation
        * $F_i = 5 \cdot 5 \cdot n_\text{input-channels} = 5 \cdot 5 \cdot 12 = 300$
    * layer #3:
        * dense with 30 neurons
        * tanh activation
        * $F_i = 4 \cdot 4 \cdot 12 = 192$
    * layer #4:
        * dense output layer with 10 neurons
        * tanh activation
        * $F_i = 30$
* target: vector of 10 values, each either 1 or -1 (nine entries of -1 and a single 1 at the index of the true class)
* loss: mean squared error between prediction and target (paper reached 1.8e-2 on test and 2.5e-3 on train)
* error rates: 0.14% on train, 5% on test
* training:
    * stochastic gradient descent (1 sample per backpropagation)
    * samples always in the same order, no shuffling
    * 23 or 30 epochs (the paper is ambiguous)
    * learning rate was set using an unspecified second-order (derivative-based) method

In [None]:
# Fetch MNIST ("mnist_784") from OpenML; cache=True reuses a local copy on later runs.
mnist = fetch_openml("mnist_784", version=1, cache=True, parser="auto")

In [None]:
# Seed every random number generator the notebook relies on
# (torch, Python's `random`, NumPy) so runs are reproducible.
SEED = 42
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
def get_device() -> str:
    """Return ``"cuda"`` when a CUDA device is available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = get_device()
device  # shown as the cell output

In [None]:
# Features (flattened pixel values; DigitsDataset reshapes each row to edge x edge)
# and digit labels.
X = mnist["data"]
y = mnist["target"]
X.shape, y.shape

In [None]:
# Train on a small subset of MNIST only, to keep the experiment fast.
# `.iloc` makes the positional "first ix0 rows" slicing explicit instead of
# relying on `obj[:n]` semantics.
ix0 = 100
X0, y0 = X.iloc[:ix0], y.iloc[:ix0]

In [None]:
class DigitsDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Each row of ``X`` holds the flattened pixels of one square image. It is
    scaled by 1/255 and reshaped to ``(edge, edge)`` as a float64 tensor.
    The matching entry of ``y`` is converted to ``int``.
    """

    def __init__(self, X: pd.DataFrame, y: pd.Series, edge: int = 28):
        self.X, self.y, self.edge = X, y, edge

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx: int) -> T.Tuple[torch.Tensor, int]:
        scaled_pixels = self.X.iloc[idx].values / 255.0  # normalizing
        side = self.edge
        image = torch.from_numpy(scaled_pixels).reshape(side, side).double()
        target = int(self.y.iloc[idx])
        return image, target

In [None]:
ds = DigitsDataset(X0, y0)  # dataset over the first `ix0` samples only

In [None]:
# Show one sample (index 4) together with its label.
item = ds[4]
plt.imshow(item[0], cmap="gray", origin="upper")
plt.title(f"Label: {item[1]}")
plt.tight_layout()

In [None]:
# The paper uses pure stochastic gradient descent (one sample per update)
# and always presents the samples in the same order, so batch_size=1 and
# no shuffling.
batch_size = 1
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)

In [None]:
train_features, train_labels = next(iter(dataloader))  # one batch for inspection

In [None]:
# Inspect the batch drawn above: shapes, first image and its label.
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]  # first image of the batch, shape (edge, edge)
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

In [None]:
def calc_conv_output_dim(input_dim, kernel_size, padding, stride, dilation=1):
    """Return the output size of a convolution along one spatial axis.

    Uses the same formula as ``torch.nn.Conv2d`` / ``torch.nn.Unfold``::

        floor((input_dim + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1

    Args:
        input_dim: input size along the axis.
        kernel_size: kernel size along the axis.
        padding: zero padding added on each side.
        stride: step between window positions.
        dilation: spacing between kernel elements (1 = dense kernel).

    Returns:
        The number of window positions (an ``int``). A value < 1 means the
        kernel does not fit into the padded input.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    # Integer floor division avoids float rounding and matches PyTorch's
    # floor semantics (int() truncates toward zero instead).
    return (input_dim + 2 * padding - effective_kernel) // stride + 1


calc_conv_output_dim(28, 5, 2, 2), calc_conv_output_dim(14, 5, 2, 2)  # conv1: 28 -> 14, conv2: 14 -> 7

In [None]:
class MyConv2d(torch.nn.Module):
    """2D convolution built from ``nn.Unfold`` + matrix multiply + ``nn.Fold``.

    Computes the same result as ``torch.nn.functional.conv2d`` (no groups)
    for square inputs of side ``edge``. Parameters are float64.

    Args:
        edge: height/width of the square inputs this layer is built for.
        n_in_channels: number of input channels.
        n_out_channels: number of output channels (filters).
        kernel_width: kernel width ``kw``.
        kernel_height: kernel height ``kh``.
        stride: step of the sliding window.
        padding: zero padding added on every side of the input.
        dilation: spacing between kernel elements.
        lecun_init: if True, draw weights and biases uniformly from
            ``[-2.4 / F_i, 2.4 / F_i]`` with fan-in
            ``F_i = n_in_channels * kh * kw`` (LeCun et al. 1990);
            otherwise draw them from a standard normal distribution.

    Raises:
        ValueError: if the (padded, dilated) kernel does not fit into an
            ``edge`` x ``edge`` input, i.e. the output would be empty.
    """

    def __init__(
        self,
        edge: int,
        n_in_channels: int = 1,
        n_out_channels: int = 1,
        kernel_width: int = 5,
        kernel_height: int = 5,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        lecun_init: bool = True,
    ):
        super().__init__()
        self.edge = edge
        # PyTorch weight layout (out_channels, in_channels, kh, kw). Flattening
        # it row-major gives the same (C, kh, kw) order that nn.Unfold uses for
        # its patch columns, which forward() relies on.
        self.weight = nn.Parameter(
            torch.empty(
                n_out_channels,
                n_in_channels,
                kernel_height,
                kernel_width,
                dtype=torch.double,
            )
        )
        self.bias = nn.Parameter(
            torch.empty(n_out_channels, dtype=torch.double)
        )
        if lecun_init:
            fan_in = n_in_channels * kernel_width * kernel_height
            s = 2.4 / fan_in
            self.weight.data.uniform_(-s, s)
            self.bias.data.uniform_(-s, s)
        else:
            self.weight.data.normal_(0, 1.0)
            self.bias.data.normal_(0, 1.0)

        self.unfold = torch.nn.Unfold(
            kernel_size=(kernel_height, kernel_width),
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
        out_h = self._out_dim(edge, kernel_height, padding, stride, dilation)
        out_w = self._out_dim(edge, kernel_width, padding, stride, dilation)
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"Kernel {kernel_height}x{kernel_width} (dilation {dilation}) "
                f"does not fit into a {edge}x{edge} input with padding {padding}"
            )
        # Every column of the matmul result is one output pixel, so a 1x1 fold
        # simply reshapes (N, C_out, out_h*out_w) into (N, C_out, out_h, out_w).
        self.fold = torch.nn.Fold(
            output_size=(out_h, out_w),
            kernel_size=(1, 1),
            dilation=1,
            padding=0,
            stride=1,
        )

    @staticmethod
    def _out_dim(size, kernel, padding, stride, dilation):
        """Output length along one axis (same formula as ``nn.Conv2d``)."""
        return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Convolve a batch ``input`` of shape (N, C_in, H, W) with H == W.

        Returns:
            Tensor of shape (N, C_out, out_h, out_w).

        Raises:
            ValueError: if the input is not square, has the wrong number of
                channels, or yields a different number of patches than the
                layer was built for (see ``edge``).
        """
        # inspiration from: https://discuss.pytorch.org/t/make-custom-conv2d-layer-efficient-wrt-speed-and-memory/70175/2
        batch_size, in_channels, in_h, in_w = input.shape
        out_channels, in_channels_weight, _, _ = self.weight.shape

        if in_h != in_w:
            raise ValueError(
                f"Input height {in_h} is not equal to width {in_w}"
            )
        if in_channels != in_channels_weight:
            raise ValueError(
                f"Input channels {in_channels} is not equal to weight input channels {in_channels_weight}"
            )

        # (N, C, in_h, in_w) -> (N, C*kh*kw, num_patches)
        patches = self.unfold(input)
        out_h, out_w = self.fold.output_size
        if patches.shape[-1] != out_h * out_w:
            raise ValueError(
                f"Input of size {in_h}x{in_w} yields {patches.shape[-1]} patches, "
                f"but the layer was built for edge={self.edge} "
                f"({out_h}x{out_w} = {out_h * out_w} patches)"
            )

        # (C_out, C*kh*kw) @ (N, C*kh*kw, num_patches) -> (N, C_out, num_patches)
        output_unfolded = self.weight.view(out_channels, -1).matmul(patches)

        output = self.fold(output_unfolded)  # (N, out_channels, out_h, out_w)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1, 1)

        if output.shape[0] != batch_size:
            raise ValueError(
                f"Batch size {batch_size} is not equal to output batch size {output.shape[0]}"
            )
        if output.shape[1] != out_channels:
            raise ValueError(
                f"Output channels {out_channels} is not equal to output channels {output.shape[1]}"
            )

        return output


# Smoke test: run a single-channel MyConv2d (conv1-like settings) on one
# batch and print the shapes involved. The layer creates and initialises its
# own weights, so no separate weight/bias tensors are needed here.
kh = kw = 5
n_in_channels = 1
n_out_channels = 1
train_features, train_labels = next(iter(dataloader))
train_features = train_features.unsqueeze(dim=1)  # (N, H, W) -> (N, 1, H, W)
print(f"{train_features.shape=}")
myconv2d = MyConv2d(
    edge=28,
    n_in_channels=n_in_channels,
    n_out_channels=n_out_channels,
    kernel_width=kw,
    kernel_height=kh,
    stride=2,
    padding=2,
    dilation=1,
)
print(f"{myconv2d.weight.shape=}")
conv_features = myconv2d(train_features)
print(f"{conv_features.shape=}")

In [None]:
# Compare the input image (top) with MyConv2d's single-channel output (bottom).
label = train_labels[0]
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(5, 5))
ax = axs[0]
img = train_features[0][0]  # first sample, channel 0
ax.imshow(img, cmap="gray")
ax = axs[1]
img = conv_features.detach().numpy()[0][0]  # detach: the output requires grad
ax.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

In [None]:
class ParameterHistory:
    """Record flattened snapshots of a model's ``state_dict`` during training.

    A snapshot is taken only when ``_iter`` is a multiple of ``every_n``.

    Attributes:
        history: maps each state-dict key (e.g. ``"conv1.weight"``) to a list
            of 1-D numpy arrays, one per recorded iteration.
        every_n: recording interval in iterations.
        iter: iteration numbers at which snapshots were taken.
    """

    def __init__(self, every_n: int = 1):
        self.history = defaultdict(list)
        self.every_n = every_n
        self.iter = []

    def __call__(self, model: nn.Module, _iter: int):
        """Snapshot ``model``'s state dict if ``_iter`` is on the recording grid."""
        if _iter % self.every_n != 0:
            return

        for name, tensor in model.state_dict().items():
            # clone(): later in-place optimizer updates must not alter the
            # snapshot. cpu(): numpy cannot read CUDA tensors.
            snapshot = tensor.detach().clone().cpu().numpy().ravel()
            self.history[name].append(snapshot)

        self.iter.append(_iter)


class LossHistory:
    """Collect scalar loss values every ``every_n``-th iteration.

    Attributes:
        history: recorded ``loss.item()`` values.
        iter: iteration numbers matching ``history``.
        every_n: recording interval in iterations.
    """

    def __init__(self, every_n: int = 1):
        self.every_n = every_n
        self.history = []
        self.iter = []

    def __call__(self, loss: torch.Tensor, _iter: int):
        """Store ``loss`` (a scalar tensor) if ``_iter`` is on the recording grid."""
        on_grid = _iter % self.every_n == 0
        if on_grid:
            self.history.append(loss.item())
            self.iter.append(_iter)

In [None]:
class TanhLeCun1990(nn.Module):
    """Scaled hyperbolic tangent ``A * tanh(S * x)`` from LeCun et al. (1990).

    Args:
        A: output scale (default 1.716).
        S: input slope (default 2/3).
    """

    def __init__(self, A: float = 1.716, S: float = 2 / 3):
        super().__init__()
        self.A, self.S = A, S

    def forward(self, x: torch.Tensor):
        slope_scaled = self.S * x
        return self.A * torch.tanh(slope_scaled)


class Model(nn.Module):
    """Convolutional net following LeCun et al. (1990).

    Architecture:
    - two 5x5 / stride-2 / padding-2 convolutions with 12 channels each,
    - a 30-unit dense layer,
    - an ``n_classes``-unit dense output layer.

    Every layer is followed by a tanh activation.

    Args:
        edge: side length of the square input images.
        n_classes: number of output units.
        lecun_init: if True, initialise every layer's weights and biases
            uniformly in ``[-2.4 / F_i, 2.4 / F_i]``, where ``F_i`` is the
            fan-in (number of inputs) of a unit.
        lecun_act: if True use ``1.716 * tanh(2/3 * x)``, else plain tanh.
    """

    _KERNEL = 5
    _STRIDE = 2
    _PADDING = 2

    def __init__(
        self,
        edge: int = 28,
        n_classes: int = 10,
        lecun_init: bool = True,
        lecun_act: bool = True,
    ):
        super().__init__()

        self.conv1 = MyConv2d(
            edge=edge,
            n_in_channels=1,
            n_out_channels=12,
            kernel_width=self._KERNEL,
            kernel_height=self._KERNEL,
            stride=self._STRIDE,
            padding=self._PADDING,
            lecun_init=lecun_init,
        )
        edge = self._conv_output_edge(edge)
        self.conv2 = MyConv2d(
            edge=edge,
            n_in_channels=12,
            n_out_channels=12,
            kernel_width=self._KERNEL,
            kernel_height=self._KERNEL,
            stride=self._STRIDE,
            padding=self._PADDING,
            lecun_init=lecun_init,
        )
        edge = self._conv_output_edge(edge)
        self.lin1 = nn.Linear(edge * edge * 12, 30)
        self.lin2 = nn.Linear(30, n_classes)

        if lecun_init:
            # F_i is the number of inputs of a unit: in_features
            # (= weight.shape[1]), not out_features. Biases are initialised
            # too, matching MyConv2d's LeCun initialisation.
            for lin in (self.lin1, self.lin2):
                s = 2.4 / lin.in_features
                lin.weight.data.uniform_(-s, s)
                lin.bias.data.uniform_(-s, s)

        self.act = TanhLeCun1990() if lecun_act else nn.Tanh()

    @classmethod
    def _conv_output_edge(cls, edge: int) -> int:
        """Output side length of one conv layer (same formula as nn.Conv2d).

        Unlike ``edge // 2``, this is also correct for odd input sizes.
        """
        return (edge + 2 * cls._PADDING - cls._KERNEL) // cls._STRIDE + 1

    def forward(self, x: torch.Tensor):
        """Map a batch of images (N, edge, edge) to scores (N, n_classes)."""
        x = x.unsqueeze(dim=1)  # add the channel axis: (N, 1, edge, edge)
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = torch.flatten(x, 1)
        x = self.act(self.lin1(x))
        x = self.lin2(x)
        # tanh outputs are compared against +/-1 targets with MSE (paper setup)
        return self.act(x)

In [None]:
model = Model(lecun_init=True, lecun_act=True)
model.double()  # MyConv2d parameters are float64; cast the dense layers to match

In [None]:
# Plain SGD with a fixed learning rate; the paper's second-order
# learning-rate scheme is not reproduced here.
opt = SGD(
    model.parameters(),
    lr=0.1,
)

In [None]:
# alternative, not used in the paper: loss_func = nn.CrossEntropyLoss()
loss_func = nn.MSELoss()  # MSE against +/-1 target vectors, as in the paper

In [None]:
model.to(device);  # trailing ';' suppresses the notebook output

In [None]:
# record parameters and loss every 10 iterations
parameter_history = ParameterHistory(every_n=10)
loss_history = LossHistory(every_n=10)

In [None]:
def densify_y(y: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Encode integer class labels as +/-1 target vectors (LeCun et al. 1990).

    Args:
        y: integer labels in ``[0, num_classes)``.
        num_classes: length of each target vector.

    Returns:
        float64 tensor of shape ``(*y.shape, num_classes)``. Each vector is +1
        at the label's index and -1 everywhere else.
    """
    one_hot = F.one_hot(y, num_classes=num_classes).double()
    return one_hot * 2.0 - 1.0  # {0, 1} -> {-1, +1}


train_labels[0:3], densify_y(train_labels[0:3])  # sanity check: labels vs. +/-1 targets

In [None]:
# Training loop: pure SGD (batch_size=1), one optimizer step per sample.
n_epochs = 20  # the paper reports 23 or 30 epochs (ambiguous)
_iter = 0  # global step counter across all epochs
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    # leave=False: discard each finished inner bar instead of keeping one per epoch
    for xb, yb in tqdm.tqdm(
        dataloader, desc="Batches", total=len(dataloader), leave=False
    ):
        xb = xb.to(device)
        yb = densify_y(yb.to(device))  # labels -> +/-1 target vectors
        loss = loss_func(model(xb), yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # each recorder only stores every `every_n`-th step
        parameter_history(model, _iter)
        loss_history(loss, _iter)

        _iter += 1

print("Done!")

In [None]:
# Raw recorded loss and its rolling mean.
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
df = pd.DataFrame({"iter": loss_history.iter, "loss": loss_history.history})
# window=10 counts recorded points (one per `loss_history.every_n` iterations)
df_roll = df.rolling(window=10, on="iter", min_periods=1).mean()

sns.lineplot(data=df, x="iter", y="loss", ax=ax, label="Train")
sns.lineplot(
    data=df_roll, x="iter", y="loss", ax=ax, label="Train (rolling mean)"
)
ax.set(xlabel="Iter", ylabel="Loss", title="Loss History")
plt.tight_layout()

display(df_roll.tail())  # `display` is provided by IPython/Jupyter

In [None]:
def hist2dfy(history: ParameterHistory, name: str) -> pd.DataFrame:
    """Flatten the recorded snapshots of one parameter into a long DataFrame.

    Returns a frame with columns ``iter`` (recording iteration, repeated for
    every element of the snapshot) and ``value`` (one row per element).
    """
    per_step = (
        pd.DataFrame({"value": values}).assign(iter=step)
        for step, values in zip(history.iter, history.history[name])
    )
    return pd.concat(per_step, ignore_index=True).loc[:, ["iter", "value"]]


def draw_history(
    history: ParameterHistory,
    name: str,
    figsize: T.Tuple[int, int] = (12, 4),
    weight_bins: int = 20,
    bias_bins: int = 10,
) -> None:
    """Plot 2D histograms of a layer's weights (top) and biases (bottom).

    Each recorded iteration gets its own x-bin; the y-axis bins the values.

    Args:
        history: recorded parameter snapshots.
        name: layer prefix in the state dict, e.g. ``"conv1"``.
        figsize: matplotlib figure size.
        weight_bins: number of value bins in the weight panel.
        bias_bins: number of value bins in the bias panel.
    """
    _, (weight_ax, bias_ax) = plt.subplots(figsize=figsize, nrows=2, sharex=True)
    shared = dict(x="iter", y="value", thresh=None, cmap="plasma")

    weight_df = hist2dfy(history, f"{name}.weight")
    time_bins = weight_df["iter"].nunique()  # reused for the bias panel
    sns.histplot(data=weight_df, ax=weight_ax, bins=(time_bins, weight_bins), **shared)
    weight_ax.set_ylabel("weight")
    weight_ax.set_title(name)

    bias_df = hist2dfy(history, f"{name}.bias")
    sns.histplot(data=bias_df, ax=bias_ax, bins=(time_bins, bias_bins), **shared)
    bias_ax.set_xlabel("iter")
    bias_ax.set_ylabel("bias")

    plt.tight_layout()
    plt.show()


# Parameter distributions over training, per layer; lin1 has the most
# weights, so it gets finer value bins.
draw_history(parameter_history, "conv1")
draw_history(parameter_history, "conv2")
draw_history(parameter_history, "lin1", weight_bins=100)
draw_history(parameter_history, "lin2")

In [None]:
train_features, train_labels = next(iter(dataloader))  # batch for inference

In [None]:
model.eval();  # switch modules to evaluation mode before inference

In [None]:
train_features = train_features.to(device)
with torch.no_grad():  # inference only: don't build an autograd graph
    pred_probs = model(train_features)
# NOTE: these are scaled-tanh scores (about [-1.716, 1.716]), not probabilities.
pred_probs

In [None]:
# predicted class = index of the largest output
y_pred = pred_probs.to("cpu").detach().numpy().argmax(axis=1)
y_pred

In [None]:
train_labels  # ground-truth labels for comparison

In [None]:
# Show the first image of the batch with its true and predicted label.
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].cpu()  # move to CPU for matplotlib
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}, pred: {y_pred[0]}")

In [None]:
# Visualise the first layer's activation maps for the first image.
n_filters = model.conv1.weight.shape[0]
n_cols = 3
n_rows = -(-n_filters // n_cols)  # ceiling division: works for any filter count

fig, axs = plt.subplots(
    nrows=n_rows, ncols=n_cols, figsize=(12, 12), squeeze=False
)
with torch.no_grad():
    # .cpu(): matplotlib cannot plot tensors that live on a CUDA device
    conv_features = model.act(model.conv1(train_features.unsqueeze(1))).cpu()

for i, ax in enumerate(axs.flatten()):
    ax.axis("off")
    if i >= n_filters:  # leave unused grid cells empty
        continue
    ax.imshow(conv_features[0][i], cmap="gray")
    ax.set_title(f"Filter {i+1}")

plt.show()
print(f"Label: {label}")