In [1]:
from pathlib import Path

DATA_DIR = Path.cwd().parent / "data"

DATA_DIR

PosixPath('/workspaces/mnist/data')

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

device

device(type='cuda')

In [3]:
import torch

generator = torch.Generator().manual_seed(42)

# Splits

In [4]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

train_ds = MNIST(
    root=DATA_DIR / "train",
    train=True,
    download=True,
    transform=ToTensor(),
)

train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspaces/mnist/data/train
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
from torch.utils.data import random_split

DEV_SIZE = 10_000

train_ds, dev_ds = random_split(
    train_ds,
    [len(train_ds) - DEV_SIZE, DEV_SIZE],
    generator=generator,
)

len(train_ds), len(dev_ds)

(50000, 10000)

# Definition

In [49]:
import torch
import torch.nn as nn

@torch.compile
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten(1, -1) # (B, 28, 28) -> (B, 28*28)
        self.linear_relu_stack = nn.Sequential(
            nn.BatchNorm1d(28*28),

            nn.Linear(28*28, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [50]:
model = MLP()

model

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Linear(in_features=784, out_features=2048, bias=True)
    (2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Dropout(p=0.5, inplace=False)
    (5): Linear(in_features=2048, out_features=1024, bias=True)
    (6): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Dropout(p=0.5, inplace=False)
    (9): Linear(in_features=1024, out_features=512, bias=True)
    (10): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Dropout(p=0.5, inplace=False)
    (13): Linear(in_features=512, out_features=128, bias=True)
    (14): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): ReLU()
    (16): Dropout(p=0.5

# Training

In [51]:
import wandb

run = wandb.init(project="mnist", config={
    "architecture": "mlp",
    "epochs": 8,
    "batch_size": 64,
    "learning_rate": 1e-3,
})

In [52]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import TypedDict

class LoopMetrics(TypedDict):
    accuracy: float
    loss: float

def train_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: torch.optim.Optimizer) -> LoopMetrics:
    model = model.to(device).train()
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        X = X.squeeze()

        pred = model(X)
        loss = loss_fn(pred, y)

        train_loss += loss.item()
        run.log({"train_batch_loss": loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
    accuracy = correct / size
    train_loss /= size

    print(f"Train Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {train_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": train_loss,
    }

@torch.no_grad
def test_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> LoopMetrics:
    model = model.to(device).eval()
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        X = X.squeeze()

        pred = model(X)
        loss = loss_fn(pred, y)
    
        test_loss += loss.item()
    
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

    test_loss /= size
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": test_loss,
    }

In [53]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=run.config.batch_size, shuffle=True, generator=generator)
dev_dataloader = DataLoader(dev_ds, batch_size=run.config.batch_size, shuffle=False)

model = MLP()

run.watch(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=run.config.learning_rate)

for t in range(run.config.epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_metrics = train_loop(model, train_dataloader, loss_fn, optimizer)
    test_metrics = test_loop(model, dev_dataloader, loss_fn)
    run.log({
        "train_epoch_accuracy": train_metrics["accuracy"],
        "train_epoch_loss": train_metrics["loss"],
        "test_epoch_accuracy": test_metrics["accuracy"],
        "test_epoch_loss": test_metrics["loss"],
    })

run.finish()

Epoch 1
-------------------------------
loss: 2.493058  [    0/50000]
loss: 0.548954  [ 6400/50000]
loss: 0.791900  [12800/50000]
loss: 0.237668  [19200/50000]
loss: 0.327181  [25600/50000]
loss: 0.260748  [32000/50000]
loss: 0.216443  [38400/50000]
loss: 0.248482  [44800/50000]
Train Error: 
 Accuracy: 89.1%, Avg loss: 0.006106 

Test Error: 
 Accuracy: 94.9%, Avg loss: 0.002956 

Epoch 2
-------------------------------
loss: 0.193191  [    0/50000]
loss: 0.181885  [ 6400/50000]
loss: 0.102930  [12800/50000]
loss: 0.119372  [19200/50000]
loss: 0.211668  [25600/50000]
loss: 0.159566  [32000/50000]
loss: 0.174050  [38400/50000]
loss: 0.061767  [44800/50000]
Train Error: 
 Accuracy: 94.1%, Avg loss: 0.003254 

Test Error: 
 Accuracy: 96.0%, Avg loss: 0.002752 

Epoch 3
-------------------------------
loss: 0.163256  [    0/50000]
loss: 0.131595  [ 6400/50000]
loss: 0.151804  [12800/50000]
loss: 0.074652  [19200/50000]
loss: 0.156065  [25600/50000]
loss: 0.203661  [32000/50000]
loss: 0.13

0,1
test_epoch_accuracy,▁▄▄▆▇▇██
test_epoch_loss,▂▁▃▄▂▃▂█
train_batch_loss,█▄▅▃▄▄▃▂▄▅▂▂▅▃▃▂▃▂▂▁▂▃▄▁▂▁▂▂▂▂▂▃▂▁▁▂▂▂▂▂
train_epoch_accuracy,▁▅▆▇▇▇██
train_epoch_loss,█▄▃▂▂▂▁▁

0,1
test_epoch_accuracy,0.97
test_epoch_loss,0.00386
train_batch_loss,0.60108
train_epoch_accuracy,0.97308
train_epoch_loss,0.00147


# Evaluation

In [54]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

test_ds = MNIST(
    root=DATA_DIR / "test",
    train=False,
    download=True,
    transform=ToTensor(),
)

test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: /workspaces/mnist/data/test
    Split: Test
    StandardTransform
Transform: ToTensor()

In [57]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(test_ds, batch_size=run.config.batch_size, shuffle=False)

_ = test_loop(model, test_dataloader, loss_fn)

Test Error: 
 Accuracy: 97.2%, Avg loss: 0.004080 

