In [1]:
from pathlib import Path

DATA_DIR = Path.cwd().parent / "data"

DATA_DIR

PosixPath('/workspaces/mnist/data')

In [27]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

device

device(type='cuda')

In [3]:
import torch

generator = torch.Generator().manual_seed(42)

# Splits

In [15]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

train_ds = MNIST(
    root=DATA_DIR / "train",
    train=True,
    download=True,
    transform=ToTensor(),
)

train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspaces/mnist/data/train
    Split: Train
    StandardTransform
Transform: ToTensor()

In [None]:
from torch.utils.data import random_split

DEV_SIZE = 10_000

train_ds, dev_ds = random_split(
    train_ds,
    [len(train_ds) - DEV_SIZE, DEV_SIZE],
    generator=generator,
)

len(train_ds), len(dev_ds)

(50000, 10000)

# Definition

with relu activations, conventional normalization (e.g., converting to standard deviations) with a 0-skewed input dataset, gradients will vanish. min-max scaling seems like a better approach with relu. since the min is 0, this will end up just dividing each input by 255

learning with min-max scaling is much slower. current hypothesis is that, with input values heavily skewed towards 0, the non-zero values do not have enough influence. so weights need to be much larger to compensate. should look at batch norm next, and find other ways to test this hypothesis

In [30]:
X_train = train_ds.data.float()

min = X_train.min().item()
max = X_train.max().item()

min, max

(0.0, 255.0)

In [16]:
import torch.nn as nn

class Scale(nn.Module):
    min: float
    max: float

    def __init__(self, min: float, max: float):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        return (x - self.min) / (self.max - self.min)

In [22]:
import torch
import torch.nn as nn

@torch.compile
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = Scale(min, max)
        self.flatten = nn.Flatten(1, -1) # (B, 28, 28) -> (B, 28*28)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.scale(x)
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [23]:
model = MLP()

model

MLP(
  (scale): Scale()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): Linear(in_features=32, out_features=10, bias=True)
  )
)

# Training

In [28]:
import wandb

run = wandb.init(project="mnist", config={
    "architecture": "mlp",
    "epochs": 8,
    "batch_size": 64,
    "learning_rate": 1e-3,
})

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112994188887467, max=1.0…

In [25]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import TypedDict

class LoopMetrics(TypedDict):
    accuracy: float
    loss: float

def train_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: torch.optim.Optimizer) -> LoopMetrics:
    model = model.to(device).train()
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        X = X.squeeze()

        pred = model(X)
        loss = loss_fn(pred, y)

        train_loss += loss.item()
        run.log({"train_batch_loss": loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
    accuracy = correct / size
    train_loss /= size

    print(f"Train Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {train_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": train_loss,
    }

@torch.no_grad
def test_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> LoopMetrics:
    model = model.to(device).eval()
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        X = X.squeeze()

        pred = model(X)
        loss = loss_fn(pred, y)
    
        test_loss += loss.item()
    
        correct += (
            (pred.argmax(1) == y)
            .type(torch.float)
            .sum()
            .item()
        )

    test_loss /= size
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": test_loss,
    }

In [29]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=run.config.batch_size, shuffle=True, generator=generator)
dev_dataloader = DataLoader(dev_ds, batch_size=run.config.batch_size, shuffle=False)

model = MLP()

run.watch(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=run.config.learning_rate)

for t in range(run.config.epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_metrics = train_loop(model, train_dataloader, loss_fn, optimizer)
    test_metrics = test_loop(model, dev_dataloader, loss_fn)
    run.log({
        "train_epoch_accuracy": train_metrics["accuracy"],
        "train_epoch_loss": train_metrics["loss"],
        "test_epoch_accuracy": test_metrics["accuracy"],
        "test_epoch_loss": test_metrics["loss"],
    })

run.finish()

Epoch 1
-------------------------------
loss: 2.306723  [    0/60000]
loss: 2.289753  [ 6400/60000]
loss: 1.852922  [12800/60000]
loss: 1.698037  [19200/60000]
loss: 1.565181  [25600/60000]
loss: 1.550838  [32000/60000]
loss: 1.110758  [38400/60000]
loss: 1.119117  [44800/60000]
loss: 1.031192  [51200/60000]
loss: 0.889358  [57600/60000]
Train Error: 
 Accuracy: 41.4%, Avg loss: 0.023709 

Test Error: 
 Accuracy: 59.5%, Avg loss: 0.016755 

Epoch 2
-------------------------------
loss: 1.018343  [    0/60000]
loss: 0.926341  [ 6400/60000]
loss: 1.018430  [12800/60000]
loss: 1.056577  [19200/60000]
loss: 0.766785  [25600/60000]
loss: 1.005056  [32000/60000]
loss: 0.843413  [38400/60000]
loss: 0.847720  [44800/60000]
loss: 0.633220  [51200/60000]
loss: 0.742010  [57600/60000]
Train Error: 
 Accuracy: 71.0%, Avg loss: 0.013195 

Test Error: 
 Accuracy: 79.9%, Avg loss: 0.010100 

Epoch 3
-------------------------------
loss: 0.775520  [    0/60000]
loss: 0.392476  [ 6400/60000]
loss: 0.54

0,1
test_epoch_accuracy,▁▅▆▇▇███
test_epoch_loss,█▄▃▃▂▂▁▁
train_batch_loss,█▇█▅▄▃▄▃▂▂▃▂▄▂▂▂▂▂▃▂▃▃▂▂▃▂▂▂▁▃▂▂▁▁▂▂▁▁▁▁
train_epoch_accuracy,▁▅▇▇▇███
train_epoch_loss,█▄▃▂▂▁▁▁

0,1
test_epoch_accuracy,0.9262
test_epoch_loss,0.0039
train_batch_loss,0.19177
train_epoch_accuracy,0.92397
train_epoch_loss,0.00399
