## Second Model

The first model suffers from poor predictions in many cases, and particular issues when not drawing in the centers.

This model aims to improve upon the first model in a few ways:
* It uses more input transforms to create a translational invariance in the training data.
* It better designs the convolution / pooling layers to better align for digit recognition, and translation equivariance/invariance


In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.transforms import v2
from importlib import resources

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu")
print(f"Using device: {device}")

CURRENT_WEIGHTS_PATH = "nn-invariant-current.pth"

Using device: mps


In [67]:
training_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32),
    v2.RandomResize(28, 40),
    v2.RandomRotation(30),
    v2.RandomResizedCrop(size = 28, scale = (28.0/40, 28.0/40)),
    # Not sure if this is needed...
    v2.Normalize((0.1307,), (0.3081,))
])

test_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32),
    v2.Normalize((0.1307,), (0.3081,))
])

train_set = torchvision.datasets.MNIST(
    "./data",
    download=True,
    transform=training_transform,
    train=True,
)

test_set = torchvision.datasets.MNIST(
    "./data",
    download=True,
    transform=test_transform,
    train=False,
)
print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")

Training set size: 60000
Test set size: 10000


In [56]:
class AssertShape(nn.Module):
    def __init__(
        self,
        size: tuple | list[int],
        label: str = None,
        ignore_batch_size: bool = True,
    ):
        super().__init__()
        self.size = tuple(size)
        self.ignore_batch_size = ignore_batch_size
        if label is not None:
            self._label_prefix = f"{label}: "
        else:
            self._label_prefix = ""
        self._cached_check_shape = torch.Size(size)
        self._last_batch_size = None

    def forward(self, x):
        if self.ignore_batch_size:
            if self._last_batch_size != x.shape[0]:
                self._cached_check_shape = torch.Size(tuple([x.shape[0],] + [d for d in self.size]))
            if x.shape != self._cached_check_shape:
                raise ValueError(f"{self._label_prefix}Expected tensor shape (*, {", ".join(str(d) for d in self.size)}), got ({", ".join(str(d) for d in x.shape)})")
        else:
            if x.shape != self._cached_check_shape:
                raise ValueError(f"{self._label_prefix}Expected tensor shape ({", ".join(str(d) for d in self.size)}), got ({", ".join(str(d) for d in x.shape)})")
        return x

class Network(nn.Module):
    def __init__(self):
        # Inspired by thinking in article:
        # * https://chriswolfvision.medium.com/what-is-translation-equivariance-and-why-do-we-use-convolutions-to-get-it-6f18139d4c59
        # A reproduction of:
        # * https://www.kaggle.com/code/minggyul/digit-recognizer-using-cnn
        super().__init__()
        self.layers = nn.Sequential(
            AssertShape([1, 28, 28]),
            nn.Conv2d(1, 32, kernel_size=3),  AssertShape([32, 26, 26]),
            nn.ReLU(),
            nn.BatchNorm2d(32),               AssertShape([32, 26, 26]),
            nn.Conv2d(32, 32, kernel_size=3), AssertShape([32, 24, 24]),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               AssertShape([32, 12, 12]),
            nn.Dropout(0.25),

            nn.Conv2d(32, 64, kernel_size=3), AssertShape([64, 10, 10]),
            nn.ReLU(),
            nn.BatchNorm2d(64),               AssertShape([64, 10, 10]),
            nn.Conv2d(64, 64, kernel_size=3), AssertShape([64, 8, 8]),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               AssertShape([64, 4, 4]),
            nn.Dropout(0.25),

            nn.Flatten(),                     AssertShape([64 * 4 * 4]),
            nn.Linear(64 * 4 * 4, 100),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        return self.layers(x)

network = Network().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(network.parameters())

In [None]:
# Load weights
state_dict = torch.load(resources.open_binary(__package__, CURRENT_WEIGHTS_PATH))
network.load_state_dict(state_dict)

In [None]:
# Save weights as backup
torch.save(network.state_dict(), "../src/model/nn-invariant-backup-2-0.079.pth")

In [64]:
batch_size = 20

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True, # Speed up CUDA
)

for epoch in range(100):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = network(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 1000 == 999:    # print every 1000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

    torch.save(network.state_dict(), CURRENT_WEIGHTS_PATH)
    print(f'[{epoch + 1}, DONE!] model weights saved')

print('Finished Training')

[1,  1000] loss: 0.057
[1,  2000] loss: 0.054
[1,  3000] loss: 0.055
[1, DONE!] model weights saved
[2,  1000] loss: 0.057
[2,  2000] loss: 0.054
[2,  3000] loss: 0.059
[2, DONE!] model weights saved
[3,  1000] loss: 0.056
[3,  2000] loss: 0.052
[3,  3000] loss: 0.059
[3, DONE!] model weights saved
[4,  1000] loss: 0.055
[4,  2000] loss: 0.054
[4,  3000] loss: 0.053
[4, DONE!] model weights saved
[5,  1000] loss: 0.054
[5,  2000] loss: 0.051
[5,  3000] loss: 0.049
[5, DONE!] model weights saved
[6,  1000] loss: 0.049
[6,  2000] loss: 0.055
[6,  3000] loss: 0.051
[6, DONE!] model weights saved
[7,  1000] loss: 0.054
[7,  2000] loss: 0.050
[7,  3000] loss: 0.050
[7, DONE!] model weights saved
[8,  1000] loss: 0.050
[8,  2000] loss: 0.051
[8,  3000] loss: 0.049
[8, DONE!] model weights saved
[9,  1000] loss: 0.044
[9,  2000] loss: 0.050
[9,  3000] loss: 0.049
[9, DONE!] model weights saved
[10,  1000] loss: 0.050
[10,  2000] loss: 0.043
[10,  3000] loss: 0.053
[10, DONE!] model weights sa

KeyboardInterrupt: 

In [68]:
import numpy as np

test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)
test_network = Network().to(device).eval()
test_network.load_state_dict(torch.load(CURRENT_WEIGHTS_PATH))

correct = 0
total = 0
failed_ids = []
for i, data in enumerate(test_loader):
    inputs, labels = data
    outputs = network(inputs.to(device))
    # Batch size of 0
    label = labels[0]
    output = outputs[0]

    total += 1
    if np.argmax(F.softmax(output, dim=0).tolist()) == label:
        correct += 1
    else:
        failed_ids.append(i)

print(f"Accuracy: {correct}/{total} ({(float(correct) / total):2.1%})")

Accuracy: 9699/10000 (97.0%)


In [40]:
# Prints a random failed example
import random

failed_id = random.choice(failed_ids)
(example, label) = test_set[failed_id]
inputs = torch.stack([example])
output = F.softmax(network(inputs.to(device)), dim=1)[0]
labels = sorted(enumerate(output), key=lambda x: x[1], reverse=True)
print(f"Expected label: {label}")
print(f"Top labels: {", ".join(f"{label} ({prob:2.1%})" for label, prob in labels[0:3])}")
v2.ToPILImage()(test_set.data[failed_id]).show()


Expected label: 0
Top labels: 6 (44.7%), 0 (26.0%), 8 (20.1%)
