In [1]:
from pathlib import Path

DATA_DIR = Path.cwd().parent / "data"

DATA_DIR

PosixPath('/workspaces/mnist/data')

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

device

device(type='cuda')

In [3]:
import torch

generator = torch.Generator().manual_seed(42)

# Splits

In [4]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

train_ds = MNIST(
    root=DATA_DIR / "train",
    train=True,
    download=True,
    transform=ToTensor(),
)

train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspaces/mnist/data/train
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
from torch.utils.data import random_split

DEV_SIZE = 10_000

train_ds, dev_ds = random_split(
    train_ds,
    [len(train_ds) - DEV_SIZE, DEV_SIZE],
    generator=generator,
)

len(train_ds), len(dev_ds)

(50000, 10000)

# Definition

with relu activations, conventional normalization (e.g., converting to standard deviations) with a 0-skewed input dataset, gradients will vanish. min-max scaling seems like a better approach with relu. since the min is 0, this will end up just dividing each input by 255

learning with min-max scaling is much slower. current hypothesis is that, with input values heavily skewed towards 0, the non-zero values do not have enough influence. so weights need to be much larger to compensate. should look at batch norm next, and find other ways to test this hypothesis

In [6]:
X_train = train_ds.dataset.data.float()

min = X_train.min().item()
max = X_train.max().item()

min, max

(0.0, 255.0)

In [7]:
import torch.nn as nn

class Scale(nn.Module):
    min: float
    max: float

    def __init__(self, min: float, max: float):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        return (x - self.min) / (self.max - self.min)

In [None]:
import torch
import torch.nn as nn

@torch.compile
class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.scale = Scale(min, max)

        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), # padding maintains spatial dims
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 28x28 -> 14x14

            nn.Conv2d(32, 64, kernel_size=3, padding=1), # padding maintains spatial dims
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 14x14 -> 7x7

            nn.Conv2d(64, 128, kernel_size=3, padding=1), # padding maintains spatial dims
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 7x7 -> 3x3
        )

        self.flatten = nn.Flatten()

        self.linear_stack = nn.Sequential(
            nn.Linear(128 * 3 * 3, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.scale(x)
        x = self.conv_stack(x)
        x = self.flatten(x)
        x = self.linear_stack(x)
        return x

In [51]:
model = CNN()

model

CNN(
  (scale): Scale()
  (conv_stack): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_stack): Sequential(
    (0): Linear(in_features=1152, out_features=512, bias=True)
    (1): R

# Training

In [52]:
import wandb

run = wandb.init(project="mnist", config={
    "architecture": "cnn",
    "epochs": 8,
    "batch_size": 64,
    "learning_rate": 1e-3,
})

In [53]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import TypedDict

class LoopMetrics(TypedDict):
    accuracy: float
    loss: float

def train_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: torch.optim.Optimizer) -> LoopMetrics:
    model = model.to(device).train()
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        train_loss += loss.item()
        run.log({"train_batch_loss": loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
    accuracy = correct / size
    train_loss /= size

    print(f"Train Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {train_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": train_loss,
    }

@torch.no_grad
def test_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> LoopMetrics:
    model = model.to(device).eval()
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)
    
        test_loss += loss.item()
    
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

    test_loss /= size
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": test_loss,
    }

In [54]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=run.config.batch_size, shuffle=True, generator=generator)
dev_dataloader = DataLoader(dev_ds, batch_size=run.config.batch_size, shuffle=False)

model = CNN()

run.watch(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=run.config.learning_rate)

for t in range(run.config.epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_metrics = train_loop(model, train_dataloader, loss_fn, optimizer)
    test_metrics = test_loop(model, dev_dataloader, loss_fn)
    run.log({
        "train_epoch_accuracy": train_metrics["accuracy"],
        "train_epoch_loss": train_metrics["loss"],
        "test_epoch_accuracy": test_metrics["accuracy"],
        "test_epoch_loss": test_metrics["loss"],
    })

run.finish()

Epoch 1
-------------------------------
loss: 2.332319  [    0/50000]
loss: 0.098487  [ 6400/50000]
loss: 0.050322  [12800/50000]
loss: 0.056963  [19200/50000]
loss: 0.061045  [25600/50000]
loss: 0.055132  [32000/50000]
loss: 0.095527  [38400/50000]
loss: 0.090038  [44800/50000]
Train Error: 
 Accuracy: 95.9%, Avg loss: 0.002066 

Test Error: 
 Accuracy: 21.7%, Avg loss: 0.141404 

Epoch 2
-------------------------------
loss: 0.101278  [    0/50000]
loss: 0.003409  [ 6400/50000]
loss: 0.132791  [12800/50000]
loss: 0.012607  [19200/50000]
loss: 0.006297  [25600/50000]
loss: 0.023155  [32000/50000]
loss: 0.035595  [38400/50000]
loss: 0.022563  [44800/50000]
Train Error: 
 Accuracy: 98.4%, Avg loss: 0.000831 

Test Error: 
 Accuracy: 45.6%, Avg loss: 0.026543 

Epoch 3
-------------------------------
loss: 0.032252  [    0/50000]
loss: 0.037924  [ 6400/50000]
loss: 0.019074  [12800/50000]
loss: 0.023290  [19200/50000]
loss: 0.007244  [25600/50000]
loss: 0.173536  [32000/50000]
loss: 0.00

0,1
test_epoch_accuracy,▂▄▂▁█▄▁▁
test_epoch_loss,▄▁▃▇▁▂▃█
train_batch_loss,▅▃▂▂▄▇▄▃▃▂▃▅▂▄▂▁▁▁▂▅▂▂▁▁▃▃▁▁▄▄▁█▃▁▁▃▁▁▂▁
train_epoch_accuracy,▁▆▇▇████
train_epoch_loss,█▃▂▂▁▁▁▁

0,1
test_epoch_accuracy,0.0969
test_epoch_loss,0.37509
train_batch_loss,0.00322
train_epoch_accuracy,0.99478
train_epoch_loss,0.00024
