In [1]:
from pathlib import Path

DATA_DIR = Path.cwd().parent / "data"

DATA_DIR

PosixPath('/workspaces/mnist/data')

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

device

device(type='cuda')

In [3]:
import torch

generator = torch.Generator().manual_seed(42)

# Splits

In [4]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

train_ds = MNIST(
    root=DATA_DIR / "train",
    train=True,
    download=True,
    transform=ToTensor(),
)

train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspaces/mnist/data/train
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
from torch.utils.data import random_split

DEV_SIZE = 10_000

train_ds, dev_ds = random_split(
    train_ds,
    [len(train_ds) - DEV_SIZE, DEV_SIZE],
    generator=generator,
)

len(train_ds), len(dev_ds)

(50000, 10000)

# Definition

In [6]:
import torch
import torch.nn as nn

@torch.compile
class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_stack = nn.Sequential(
            nn.BatchNorm2d(1),

            nn.Conv2d(1, 32, kernel_size=3, padding=1), # padding preserves spatial dims
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 28x28 -> 14x14

            nn.Conv2d(32, 64, kernel_size=3, padding=1), # padding preserves spatial dims
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 14x14 -> 7x7
        )

        self.flatten = nn.Flatten()

        self.linear_stack = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.conv_stack(x)
        x = self.flatten(x)
        x = self.linear_stack(x)
        return x

In [7]:
model = CNN()

model

CNN(
  (conv_stack): Sequential(
    (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_stack): Sequential(
    (0): Linear(in_features=3136, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=128, out_

# Training

In [8]:
import wandb

run = wandb.init(project="mnist", config={
    "architecture": "cnn",
    "epochs": 8,
    "batch_size": 64,
    "learning_rate": 1e-3,
})

[34m[1mwandb[0m: Currently logged in as: [33mmrbrobot[0m ([33mcloudbend[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import TypedDict

class LoopMetrics(TypedDict):
    accuracy: float
    loss: float

def train_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: torch.optim.Optimizer) -> LoopMetrics:
    model = model.to(device).train()
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        train_loss += loss.item()
        run.log({"train_batch_loss": loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
    accuracy = correct / size
    train_loss /= size

    print(f"Train Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {train_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": train_loss,
    }

@torch.no_grad
def test_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> LoopMetrics:
    model = model.to(device).eval()
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)
    
        test_loss += loss.item()
    
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

    test_loss /= size
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": test_loss,
    }

In [10]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=run.config.batch_size, shuffle=True, generator=generator)
dev_dataloader = DataLoader(dev_ds, batch_size=run.config.batch_size, shuffle=False)

model = CNN()

run.watch(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=run.config.learning_rate)

for t in range(run.config.epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_metrics = train_loop(model, train_dataloader, loss_fn, optimizer)
    test_metrics = test_loop(model, dev_dataloader, loss_fn)
    run.log({
        "train_epoch_accuracy": train_metrics["accuracy"],
        "train_epoch_loss": train_metrics["loss"],
        "test_epoch_accuracy": test_metrics["accuracy"],
        "test_epoch_loss": test_metrics["loss"],
    })

run.finish()

Epoch 1
-------------------------------
loss: 2.321584  [    0/50000]
loss: 0.176015  [ 6400/50000]
loss: 0.598732  [12800/50000]
loss: 0.250075  [19200/50000]
loss: 0.077316  [25600/50000]
loss: 0.129450  [32000/50000]
loss: 0.106766  [38400/50000]
loss: 0.111339  [44800/50000]
Train Error: 
 Accuracy: 92.3%, Avg loss: 0.003940 

Test Error: 
 Accuracy: 97.8%, Avg loss: 0.001109 

Epoch 2
-------------------------------
loss: 0.060881  [    0/50000]
loss: 0.044780  [ 6400/50000]
loss: 0.040227  [12800/50000]
loss: 0.204864  [19200/50000]
loss: 0.149636  [25600/50000]
loss: 0.166850  [32000/50000]
loss: 0.013978  [38400/50000]
loss: 0.213783  [44800/50000]
Train Error: 
 Accuracy: 97.4%, Avg loss: 0.001499 

Test Error: 
 Accuracy: 98.3%, Avg loss: 0.001031 

Epoch 3
-------------------------------
loss: 0.122946  [    0/50000]
loss: 0.047953  [ 6400/50000]
loss: 0.038037  [12800/50000]
loss: 0.050793  [19200/50000]
loss: 0.094651  [25600/50000]
loss: 0.066618  [32000/50000]
loss: 0.11

0,1
test_epoch_accuracy,▁▄▆▇▇▇██
test_epoch_loss,█▇▄▄▄▃▁▁
train_batch_loss,▇▅▅▆▄▂▆▃▂▃▂▃▂▃▃▂▄▂▁▂▂▁▃▂▁█▂▁▂▂▁▁▁▂▅▂▂▁▂▃
train_epoch_accuracy,▁▆▇▇████
train_epoch_loss,█▃▂▂▂▁▁▁

0,1
test_epoch_accuracy,0.99
test_epoch_loss,0.00057
train_batch_loss,0.06634
train_epoch_accuracy,0.9901
train_epoch_loss,0.00056


# Evaluation

In [13]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

test_ds = MNIST(
    root=DATA_DIR / "test",
    train=False,
    download=True,
    transform=ToTensor(),
)

test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: /workspaces/mnist/data/test
    Split: Test
    StandardTransform
Transform: ToTensor()

In [14]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(test_ds, batch_size=run.config.batch_size, shuffle=False)

_ = test_loop(model, test_dataloader, loss_fn)

Test Error: 
 Accuracy: 99.4%, Avg loss: 0.000387 

