In [1]:
from pathlib import Path

DATA_DIR = Path.cwd().parent / "data"

DATA_DIR

PosixPath('/workspaces/mnist/data')

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

device

device(type='cuda')

In [3]:
import torch

generator = torch.Generator().manual_seed(42)

# Splits

In [4]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

train_ds = MNIST(
    root=DATA_DIR / "train",
    train=True,
    download=True,
    transform=ToTensor(),
)

train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspaces/mnist/data/train
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
from torch.utils.data import random_split

DEV_SIZE = 10_000

train_ds, dev_ds = random_split(
    train_ds,
    [len(train_ds) - DEV_SIZE, DEV_SIZE],
    generator=generator,
)

len(train_ds), len(dev_ds)

(50000, 10000)

# Definition

In [6]:
import torch
import torch.nn as nn

@torch.compile
class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_stack = nn.Sequential(
            nn.BatchNorm2d(1),

            nn.Conv2d(1, 32, kernel_size=5, padding="same"), # padding preserves spatial dims
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # (B, 1, 32, 32) -> (B, 32, 16, 16)

            nn.Conv2d(32, 16, kernel_size=3, padding="same"), # padding preserves spatial dims
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # (B, 32, 16, 16) -> (B, 16, 7, 7)
        )

        self.flatten = nn.Flatten() # (B, 16, 8, 8) -> (B, 16 * 7 * 7)

        self.linear_stack = nn.Sequential(
            nn.Linear(16 * 7 * 7, 128), # (B, 784) -> (B, 128)
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(128, 64), # (B, 128) -> (B, 64)
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(64, 10), # (B, 64) -> (B, 10)
        )

    def forward(self, x):
        x = self.conv_stack(x)
        x = self.flatten(x)
        x = self.linear_stack(x)
        return x

In [7]:
model = CNN()

model

CNN(
  (conv_stack): Sequential(
    (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (6): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_stack): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=64, out_feature

Total parameters

In [None]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params

114940

# Training

In [None]:
import wandb

run = wandb.init(project="mnist", config={
    "architecture": "cnn",
    "epochs": 8,
    "batch_size": 64,
    "learning_rate": 1e-3,
    "total_params": total_params,
})

[34m[1mwandb[0m: Currently logged in as: [33mmrbrobot[0m ([33mcloudbend[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [10]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import TypedDict

class LoopMetrics(TypedDict):
    accuracy: float
    loss: float

def train_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: torch.optim.Optimizer) -> LoopMetrics:
    model = model.to(device).train()
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        train_loss += loss.item()
        run.log({"train_batch_loss": loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
    accuracy = correct / size
    train_loss /= size

    print(f"Train Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {train_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": train_loss,
    }

@torch.no_grad
def test_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> LoopMetrics:
    model = model.to(device).eval()
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)
    
        test_loss += loss.item()
    
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

    test_loss /= size
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": test_loss,
    }

In [11]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=run.config.batch_size, shuffle=True, generator=generator)
dev_dataloader = DataLoader(dev_ds, batch_size=run.config.batch_size, shuffle=False)

model = CNN()

run.watch(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=run.config.learning_rate)

for t in range(run.config.epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_metrics = train_loop(model, train_dataloader, loss_fn, optimizer)
    test_metrics = test_loop(model, dev_dataloader, loss_fn)
    run.log({
        "train_epoch_accuracy": train_metrics["accuracy"],
        "train_epoch_loss": train_metrics["loss"],
        "test_epoch_accuracy": test_metrics["accuracy"],
        "test_epoch_loss": test_metrics["loss"],
    })

run.finish()

Epoch 1
-------------------------------
loss: 2.300159  [    0/50000]
loss: 0.117482  [ 6400/50000]
loss: 0.366562  [12800/50000]
loss: 0.201011  [19200/50000]
loss: 0.042790  [25600/50000]
loss: 0.128922  [32000/50000]
loss: 0.082905  [38400/50000]
loss: 0.095699  [44800/50000]
Train Error: 
 Accuracy: 93.6%, Avg loss: 0.003324 

Test Error: 
 Accuracy: 98.0%, Avg loss: 0.001044 

Epoch 2
-------------------------------
loss: 0.013649  [    0/50000]
loss: 0.011442  [ 6400/50000]
loss: 0.036188  [12800/50000]
loss: 0.159598  [19200/50000]
loss: 0.064307  [25600/50000]
loss: 0.098858  [32000/50000]
loss: 0.014725  [38400/50000]
loss: 0.134163  [44800/50000]
Train Error: 
 Accuracy: 97.8%, Avg loss: 0.001115 

Test Error: 
 Accuracy: 98.2%, Avg loss: 0.000923 

Epoch 3
-------------------------------
loss: 0.017903  [    0/50000]
loss: 0.016991  [ 6400/50000]
loss: 0.045127  [12800/50000]
loss: 0.033309  [19200/50000]
loss: 0.063500  [25600/50000]
loss: 0.051973  [32000/50000]
loss: 0.00

0,1
test_epoch_accuracy,▁▃██▇▆██
test_epoch_loss,█▆▁▁▃▃▂▁
train_batch_loss,▄▅▇▅▂▃▂▅▁▃▂▃▂▂▁▂█▂▁▂▁▁▁▄▂▁▁▂▁▂▄▁▂▁▁▄▃▁▂▃
train_epoch_accuracy,▁▆▇▇████
train_epoch_loss,█▃▂▂▁▁▁▁

0,1
test_epoch_accuracy,0.9899
test_epoch_loss,0.0006
train_batch_loss,0.00873
train_epoch_accuracy,0.9916
train_epoch_loss,0.00043


# Evaluation

In [12]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

test_ds = MNIST(
    root=DATA_DIR / "test",
    train=False,
    download=True,
    transform=ToTensor(),
)

test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: /workspaces/mnist/data/test
    Split: Test
    StandardTransform
Transform: ToTensor()

In [13]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(test_ds, batch_size=run.config.batch_size, shuffle=False)

_ = test_loop(model, test_dataloader, loss_fn)

Test Error: 
 Accuracy: 99.1%, Avg loss: 0.000478 

