In [1]:
from pathlib import Path

DATA_DIR = Path.cwd().parent / "data"

DATA_DIR

PosixPath('/workspaces/mnist/data')

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

device

device(type='cuda')

In [3]:
import torch

generator = torch.Generator().manual_seed(42)

# Splits

In [4]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

train_ds = MNIST(
    root=DATA_DIR / "train",
    train=True,
    download=True,
    transform=ToTensor(),
)

train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspaces/mnist/data/train
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
from torch.utils.data import random_split

DEV_SIZE = 10_000

train_ds, dev_ds = random_split(
    train_ds,
    [len(train_ds) - DEV_SIZE, DEV_SIZE],
    generator=generator,
)

len(train_ds), len(dev_ds)

(50000, 10000)

# Definition

In [44]:
import torch
import torch.nn as nn

@torch.compile
class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_stack = nn.Sequential(
            nn.BatchNorm2d(1),

            nn.Conv2d(1, 16, kernel_size=5, padding="same"), # padding preserves spatial dims
            nn.BatchNorm2d(16),
            nn.ReLU(), # (B, 1, 28, 28) -> (B, 16, 28, 28)

            nn.Conv2d(16, 32, kernel_size=5, padding="same"),
            nn.BatchNorm2d(32),
            nn.ReLU(), # (B, 16, 28, 28) -> (B, 32, 28, 28)

            nn.MaxPool2d(kernel_size=2, stride=2), # (B, 32, 28, 28) -> (B, 32, 14, 14)

            nn.Conv2d(32, 16, kernel_size=3, padding="same"),
            nn.BatchNorm2d(16),
            nn.ReLU(), # (B, 32, 14, 14) -> (B, 16, 14, 14)

            nn.Conv2d(16, 8, kernel_size=3, padding="same"),
            nn.BatchNorm2d(8),
            nn.ReLU(), # (B, 16, 14, 14) -> (B, 8, 14, 14)

            nn.Conv2d(8, 4, kernel_size=3, padding="same"),
            nn.BatchNorm2d(4),
            nn.ReLU(), # (B, 8, 14, 14) -> (B, 4, 14, 14)

            nn.MaxPool2d(kernel_size=2, stride=2), # (B, 8, 14, 14) -> (B, 4, 7, 7)
        )

        self.flatten = nn.Flatten() # (B, 4, 7, 7) -> (B, 4 * 7 * 7)

        self.linear_stack = nn.Sequential(
            nn.Linear(4 * 7 * 7, 64), # (B, 196) -> (B, 64)
            nn.ReLU(),
            nn.Dropout(0.5),
            
            nn.Linear(64, 10), # (B, 64) -> (B, 10)
        )

    def forward(self, x):
        x = self.conv_stack(x)
        x = self.flatten(x)
        x = self.linear_stack(x)
        return x

In [45]:
model = CNN()

model

CNN(
  (conv_stack): Sequential(
    (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (9): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (12): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (15): 

Total parameters

In [46]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params

32736

# Training

In [47]:
import wandb

run = wandb.init(project="mnist", config={
    "architecture": "cnn",
    "epochs": 8,
    "batch_size": 64,
    "learning_rate": 1e-3,
    "total_params": total_params,
})

In [48]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import TypedDict

class LoopMetrics(TypedDict):
    accuracy: float
    loss: float

def train_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: torch.optim.Optimizer) -> LoopMetrics:
    model = model.to(device).train()
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        train_loss += loss.item()
        run.log({"train_batch_loss": loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
    accuracy = correct / size
    train_loss /= size

    print(f"Train Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {train_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": train_loss,
    }

@torch.no_grad
def test_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> LoopMetrics:
    model = model.to(device).eval()
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)
    
        test_loss += loss.item()
    
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

    test_loss /= size
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": test_loss,
    }

In [49]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=run.config.batch_size, shuffle=True, generator=generator)
dev_dataloader = DataLoader(dev_ds, batch_size=run.config.batch_size, shuffle=False)

model = CNN()

run.watch(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=run.config.learning_rate)

for t in range(run.config.epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_metrics = train_loop(model, train_dataloader, loss_fn, optimizer)
    test_metrics = test_loop(model, dev_dataloader, loss_fn)
    run.log({
        "train_epoch_accuracy": train_metrics["accuracy"],
        "train_epoch_loss": train_metrics["loss"],
        "test_epoch_accuracy": test_metrics["accuracy"],
        "test_epoch_loss": test_metrics["loss"],
    })

run.finish()

Epoch 1
-------------------------------
loss: 2.345578  [    0/50000]
loss: 0.509311  [ 6400/50000]
loss: 0.126886  [12800/50000]
loss: 0.206105  [19200/50000]
loss: 0.150900  [25600/50000]
loss: 0.111874  [32000/50000]
loss: 0.051089  [38400/50000]
loss: 0.170755  [44800/50000]
Train Error: 
 Accuracy: 92.0%, Avg loss: 0.004459 

Test Error: 
 Accuracy: 98.1%, Avg loss: 0.000979 

Epoch 2
-------------------------------
loss: 0.320460  [    0/50000]
loss: 0.052733  [ 6400/50000]
loss: 0.111996  [12800/50000]
loss: 0.073829  [19200/50000]
loss: 0.054652  [25600/50000]
loss: 0.029629  [32000/50000]
loss: 0.065865  [38400/50000]
loss: 0.133971  [44800/50000]
Train Error: 
 Accuracy: 97.4%, Avg loss: 0.001478 

Test Error: 
 Accuracy: 98.5%, Avg loss: 0.000799 

Epoch 3
-------------------------------
loss: 0.102876  [    0/50000]
loss: 0.053742  [ 6400/50000]
loss: 0.081412  [12800/50000]
loss: 0.097992  [19200/50000]
loss: 0.041854  [25600/50000]
loss: 0.049106  [32000/50000]
loss: 0.02

0,1
test_epoch_accuracy,▁▄▆▆▅█▆▇
test_epoch_loss,█▅▃▃▃▁▄▂
train_batch_loss,█▄▂▁▂▂▁▂▁▂▁▁▂▁▁▁▁▁▁▂▂▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁
train_epoch_accuracy,▁▇▇▇████
train_epoch_loss,█▂▂▂▁▁▁▁

0,1
test_epoch_accuracy,0.9897
test_epoch_loss,0.0006
train_batch_loss,0.00855
train_epoch_accuracy,0.9865
train_epoch_loss,0.00071


# Evaluation

In [42]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

test_ds = MNIST(
    root=DATA_DIR / "test",
    train=False,
    download=True,
    transform=ToTensor(),
)

test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: /workspaces/mnist/data/test
    Split: Test
    StandardTransform
Transform: ToTensor()

In [43]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(test_ds, batch_size=run.config.batch_size, shuffle=False)

_ = test_loop(model, test_dataloader, loss_fn)

Test Error: 
 Accuracy: 99.1%, Avg loss: 0.000450 

