# Toy Convolutional Neural Network

This notebook builds a small CNN modeled on the AlexNet architecture described in the ["ImageNet Classification with Deep Convolutional Neural Networks"](https://proceedings.neurips.cc/paper_files/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf) paper (Krizhevsky, Sutskever & Hinton, 2012).
The convolution and max-pooling layers are implemented by hand instead of using `nn.Conv2d` or `nn.MaxPool2d`.


In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt


## ManualConv2d
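This layer performs a convolution by using `torch.nn.Unfold` to extract every sliding window as a column (im2col) and then multiplying those columns by the flattened weight matrix.
Weights and biases are learnable parameters; tensor shapes are annotated in comments.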

In [None]:
class ManualConv2d(nn.Module):
    """2-D convolution expressed as ``nn.Unfold`` followed by one matrix multiply.

    Args:
        in_channels: number of channels in the incoming tensor.
        out_channels: number of filters, i.e. channels in the result.
        kernel_size: int or (kh, kw) window size.
        stride: int or (sh, sw) step between neighbouring windows.
        padding: int or (ph, pw) implicit zero padding added on each side.

    The weight has shape (out_channels, in_channels, kh, kw) and is drawn from
    N(0, 0.01**2); the bias has shape (out_channels,) and starts at zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        def as_pair(value):
            # A bare int applies to both spatial axes; anything else is kept as given.
            return (value, value) if isinstance(value, int) else value

        self.kernel_size = as_pair(kernel_size)
        self.stride = as_pair(stride)
        self.padding = as_pair(padding)
        filters = torch.randn((out_channels, in_channels, *self.kernel_size))
        self.weight = nn.Parameter(filters * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_channels))
        # Turns each receptive field into a column: (N, C_in*kh*kw, L).
        self.unfold = nn.Unfold(
            kernel_size=self.kernel_size, stride=self.stride, padding=self.padding
        )

    def forward(self, x):
        # x: (N, C_in, H, W)
        columns = self.unfold(x)  # (N, C_in*kh*kw, L)
        n_filters = self.weight.size(0)
        # Flatten each filter so it lines up with the channel-major unfold rows.
        kernel_matrix = self.weight.view(n_filters, -1)  # (C_out, C_in*kh*kw)
        # Batched matmul broadcasts the kernel matrix over N; bias is added per filter.
        result = torch.matmul(kernel_matrix, columns) + self.bias[:, None]  # (N, C_out, L)
        (kh, kw), (sh, sw), (ph, pw) = self.kernel_size, self.stride, self.padding
        rows = (x.size(2) + 2 * ph - kh) // sh + 1
        cols = (x.size(3) + 2 * pw - kw) // sw + 1
        return result.view(x.size(0), n_filters, rows, cols)


## ManualMaxPool2d
Max pooling implemented with `torch.nn.Unfold`, taking the maximum value from each window.

In [None]:
class ManualMaxPool2d(nn.Module):
    """2-D max pooling built on ``nn.Unfold``: extract every window, keep its maximum.

    Args:
        kernel_size: int or (kh, kw) pooling window.
        stride: int or (sh, sw) step between windows; defaults to ``kernel_size``
            when None.
        padding: int or (ph, pw) padding added on each side. Padded positions are
            filled with ``-inf`` (as ``nn.MaxPool2d`` does implicitly), so they can
            never be selected as a window's maximum.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        if stride is None:
            stride = kernel_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # nn.Unfold can only zero-pad, which would let a padded 0 beat every value
        # of an all-negative window. Padding is applied with -inf in forward()
        # instead, so the unfold itself uses none.
        self.unfold = nn.Unfold(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # x: (N, C, H, W)
        ph, pw = self.padding
        if ph or pw:
            # F.pad takes (left, right, top, bottom) for the last two dimensions.
            x = nn.functional.pad(x, (pw, pw, ph, ph), value=float("-inf"))
        n, c = x.size(0), x.size(1)
        kh, kw = self.kernel_size
        sh, sw = self.stride
        patches = self.unfold(x)  # (N, C*kh*kw, L)
        # Unfold rows are channel-major, so each channel's kh*kw window values are grouped.
        patches = patches.reshape(n, c, kh * kw, -1)  # (N, C, kh*kw, L)
        out, _ = patches.max(dim=2)  # (N, C, L)
        # x is already padded at this point, so there is no 2*padding term.
        h_out = (x.size(2) - kh) // sh + 1
        w_out = (x.size(3) - kw) // sw + 1
        return out.reshape(n, c, h_out, w_out)


## ToyAlexNet Architecture
An AlexNet-inspired network using the custom layers defined above. Comments indicate tensor shapes for an input of `(N, 3, 224, 224)`.

In [None]:
class ToyAlexNet(nn.Module):
    """AlexNet-style classifier assembled from ManualConv2d and ManualMaxPool2d.

    Five convolutions (three of them followed by max pooling) feed a
    three-layer fully connected head. The ``fc1`` input size (256 * 6 * 6) is
    sized for (N, 3, 224, 224) inputs, as traced in the shape comments below.

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor; shapes assume an input of (N, 3, 224, 224).
        self.conv1 = ManualConv2d(3, 64, kernel_size=11, stride=4, padding=2)  # (N, 64, 55, 55)
        self.pool1 = ManualMaxPool2d(kernel_size=3, stride=2)                  # (N, 64, 27, 27)
        self.conv2 = ManualConv2d(64, 192, kernel_size=5, padding=2)           # (N, 192, 27, 27)
        self.pool2 = ManualMaxPool2d(kernel_size=3, stride=2)                  # (N, 192, 13, 13)
        self.conv3 = ManualConv2d(192, 384, kernel_size=3, padding=1)          # (N, 384, 13, 13)
        self.conv4 = ManualConv2d(384, 256, kernel_size=3, padding=1)          # (N, 256, 13, 13)
        self.conv5 = ManualConv2d(256, 256, kernel_size=3, padding=1)          # (N, 256, 13, 13)
        self.pool5 = ManualMaxPool2d(kernel_size=3, stride=2)                  # (N, 256, 6, 6)
        # Classifier head.
        flat_features = 256 * 6 * 6
        hidden = 4096
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, num_classes)
        # Parameter-free modules reused by every stage.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()

    def forward(self, x):
        # x: (N, 3, 224, 224)
        act = self.relu
        x = self.pool1(act(self.conv1(x)))  # (N, 64, 27, 27)
        x = self.pool2(act(self.conv2(x)))  # (N, 192, 13, 13)
        for conv in (self.conv3, self.conv4, self.conv5):
            x = act(conv(x))                # ends at (N, 256, 13, 13)
        x = self.flatten(self.pool5(x))     # (N, 256*6*6)
        for fc in (self.fc1, self.fc2):
            x = self.dropout(act(fc(x)))    # (N, 4096)
        return self.fc3(x)                  # (N, num_classes)


## Demo
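The following cell creates the network and passes a single random image (a batch of size 1) through it. It plots the input alongside the first channel of the `conv1` feature map and prints the shape of the network's output.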

In [None]:
def demo():
    """Build a ToyAlexNet, show a random input and its first conv1 feature map, print the output shape."""
    net = ToyAlexNet(num_classes=10)
    # Inference only: switch off dropout so the forward pass is deterministic.
    net.eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():  # no gradients are needed for a forward-only demo
        out1 = net.relu(net.conv1(dummy))
    # imshow expects RGB floats in [0, 1]; min-max rescale the N(0, 1) noise
    # so matplotlib does not clip it (and warn) when displaying.
    img = dummy[0].permute(1, 2, 0)
    img = (img - img.min()) / (img.max() - img.min())
    _, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(img.numpy())
    axes[0].set_title("Input image")
    axes[1].imshow(out1[0, 0].numpy(), cmap="gray")
    axes[1].set_title("First feature map")
    plt.show()
    with torch.no_grad():
        out = net(dummy)
    print("Output shape:", out.shape)

demo()
