In [None]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device

In [None]:
from pathlib import Path

import os

# By default, datasets live in `data/` two levels above the kernel's working directory.
# That only resolves correctly when the notebook is run from its own folder, so the
# location can be overridden with the MNIST_DATA_DIR environment variable.
DATA_DIR = Path(os.environ.get("MNIST_DATA_DIR", Path.cwd().parent.parent / "data"))

DATA_DIR

In [4]:
import torch

SEED = 42

# Seed PyTorch's global RNG so that model weight initialisation (e.g. MLP()) is
# reproducible across runs. Also keep a dedicated generator for DataLoader shuffling.
torch.manual_seed(SEED)
generator = torch.Generator().manual_seed(SEED)

# Data loading and train/dev splits

In [None]:

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# MNIST training split: downloaded on first use, with images converted to tensors.
train_ds = MNIST(
    DATA_DIR / "train",
    train=True,
    transform=ToTensor(),
    download=True,
)

train_ds

In [None]:
import torch
from torch.utils.data import random_split

from torch.utils.data import Subset

DEV_SIZE = 10_000

# This cell rebinds `train_ds`. Without this guard, re-running it would split the
# already-reduced subset a second time. Always split the underlying full dataset
# instead; with the fixed seed, re-runs reproduce the same split.
full_train_ds = train_ds.dataset if isinstance(train_ds, Subset) else train_ds

train_ds, dev_ds = random_split(
    full_train_ds,
    [len(full_train_ds) - DEV_SIZE, DEV_SIZE],
    generator=torch.Generator().manual_seed(42),
)
assert len(train_ds) + len(dev_ds) == len(full_train_ds)

len(train_ds), len(dev_ds)

In [None]:
from torch.utils.data import DataLoader

# Take a single, fixed (unshuffled) mini-batch of 8 training samples for inspection.
dataloader = DataLoader(train_ds, shuffle=False, batch_size=8)
x, y = next(iter(dataloader))

(x.shape, y.shape)

In [None]:
import matplotlib.pyplot as plt

# Show the first image of the sample batch together with its label.
# Index the batch *before* dropping the channel axis: `x.squeeze()[0]` would select
# an image row instead of an image if the batch ever held a single sample.
# Assumes x is (B, 1, H, W) as produced by ToTensor; confirm this if the transforms change.
image = x[0].squeeze(0).cpu()

fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title(f"Label: {y[0].item()}")
ax.axis("off")
plt.show()

In [None]:
# Show the second image of the sample batch together with its label.
# Index the batch *before* dropping the channel axis: `x.squeeze()[1]` would select
# an image row instead of an image if the batch ever held a single sample.
# Assumes x is (B, 1, H, W) as produced by ToTensor; confirm this if the transforms change.
image = x[1].squeeze(0).cpu()

fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title(f"Label: {y[1].item()}")
ax.axis("off")
plt.show()

# Model definition

**TODO**
- input normalization
- batch normalization
- hyperparameter tuning
- quantization

In [10]:
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Fully connected classifier with ReLU between layers.

    With the default arguments this is the original 784 -> 64 -> 32 -> 10 network.
    Everything after the batch axis is flattened, so the input may be (B, 28, 28) or
    (B, 1, 28, 28). ``forward`` returns unnormalised logits of shape (B, num_classes).

    Compilation is left to the caller (``model.compile()`` or ``torch.compile(model)``)
    rather than decorating the class. The documented usage compiles a module instance
    or function, and compiling at class level forces every instance through the
    compiler, including quick CPU smoke tests.

    Args:
        in_features: Number of input features after flattening.
        hidden_sizes: Widths of the hidden layers, in order. Each is followed by a ReLU.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28, hidden_sizes: tuple = (64, 32), num_classes: int = 10):
        super().__init__()
        self.flatten = nn.Flatten(1, -1)  # (B, *rest) -> (B, prod(rest))
        layers = []
        width = in_features
        for hidden in hidden_sizes:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        layers.append(nn.Linear(width, num_classes))
        # Layer order matches the original hand-written Sequential, so with the
        # defaults the state_dict keys (linear_relu_stack.0.weight, ...) are unchanged.
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        f = self.flatten(x)
        logits = self.linear_relu_stack(f)
        return logits

In [None]:
model = MLP()  # instantiate the network; the bare expression below displays its layers

model

In [None]:
# Smoke-test a forward pass on the sample batch: expect one row of logits per image.
logits = model(x)
logits.size()

In [None]:
logits.argmax(dim=1)  # predicted class per sample (model untrained here, so not meaningful yet)

# Training

In [None]:
import wandb

# Start a wandb run; the hyperparameters below are read back via run.config.
run = wandb.init(
    project="mnist",
    config=dict(
        architecture="mlp",
        epochs=5,
        batch_size=64,
        learning_rate=1e-3,
    ),
)

In [24]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import TypedDict

# Epoch-level metrics returned by the train/test loops:
#   accuracy - fraction of correctly classified samples
#   loss     - average loss value
LoopMetrics = TypedDict("LoopMetrics", {"accuracy": float, "loss": float})

def train_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: torch.optim.Optimizer) -> LoopMetrics:
    """Run one training epoch over ``dataloader`` and return epoch-level metrics.

    Moves ``model`` to the notebook-global ``device`` in train mode. For every batch it
    runs the forward pass, computes the loss, and performs backward + optimizer step.
    Each batch loss is sent to the notebook-global wandb ``run`` as
    ``"train_batch_loss"``, and progress is printed every 100 batches.

    Args:
        model: Network to train; its parameters are updated in place by ``optimizer``.
        dataloader: Yields ``(X, y)`` batches of images and class labels.
        loss_fn: Criterion that returns the *mean* loss over a batch (for example
            ``nn.CrossEntropyLoss()`` with its default ``reduction="mean"``).
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        ``{"accuracy": ..., "loss": ...}``: the fraction of correctly classified samples
        and the per-sample average loss. Both use predictions made before each batch's
        weight update.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model = model.to(device).train()
    size = len(dataloader.dataset)
    total_loss, correct, seen = 0.0, 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Drop only the channel axis (assumes X is (B, 1, H, W)). A bare squeeze()
        # would also drop the batch axis when the final batch has a single sample.
        X = X.squeeze(1)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = len(y)
        batch_loss = loss.item()  # read the scalar once per batch
        run.log({"train_batch_loss": batch_loss})

        # loss_fn averages over the batch, so weight each batch by its size to get a
        # true per-sample average for the epoch (the last batch may be smaller).
        total_loss += batch_loss * batch_size
        correct += (pred.argmax(1) == y).sum().item()
        seen += batch_size

        if batch % 100 == 0:
            print(f"loss: {batch_loss:>7f}  [{seen:>5d}/{size:>5d}]")

    if seen == 0:
        raise ValueError("dataloader yielded no samples")

    accuracy = correct / seen
    train_loss = total_loss / seen

    print(f"Train Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {train_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": train_loss,
    }

@torch.no_grad()
def test_loop(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> LoopMetrics:
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Moves ``model`` to the notebook-global ``device`` in eval mode, accumulates loss
    and accuracy over every batch, prints a summary, and returns it.

    Args:
        model: Network to evaluate. It is left in eval mode on return.
        dataloader: Yields ``(X, y)`` batches of images and class labels.
        loss_fn: Criterion that returns the *mean* loss over a batch (for example
            ``nn.CrossEntropyLoss()`` with its default ``reduction="mean"``).

    Returns:
        ``{"accuracy": ..., "loss": ...}``: the fraction of correctly classified samples
        and the per-sample average loss.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model = model.to(device).eval()
    total_loss, correct, seen = 0.0, 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Drop only the channel axis (assumes X is (B, 1, H, W)), so a size-1 final
        # batch keeps its batch axis.
        X = X.squeeze(1)

        pred = model(X)
        batch_size = len(y)
        # loss_fn averages over the batch; weight by batch size for a per-sample mean.
        total_loss += loss_fn(pred, y).item() * batch_size
        correct += (pred.argmax(1) == y).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("dataloader yielded no samples")

    test_loss = total_loss / seen
    accuracy = correct / seen
    print(f"Test Error: \n Accuracy: {100 * accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return {
        "accuracy": accuracy,
        "loss": test_loss,
    }

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=run.config.batch_size, shuffle=True, generator=generator)
dev_dataloader = DataLoader(dev_ds, batch_size=run.config.batch_size, shuffle=False)

model = MLP()
# Optional: call model.compile() here to run this instance through torch.compile.

run.watch(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=run.config.learning_rate)

for epoch in range(1, run.config.epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train_metrics = train_loop(model, train_dataloader, loss_fn, optimizer)
    # Evaluated on the held-out dev split. The "test_epoch_*" keys are kept unchanged
    # so that existing wandb charts keep working.
    test_metrics = test_loop(model, dev_dataloader, loss_fn)
    # train_loop also calls run.log once per batch, so wandb's default step axis counts
    # batches. Log the epoch explicitly so these curves can be plotted per epoch.
    run.log({
        "epoch": epoch,
        "train_epoch_accuracy": train_metrics["accuracy"],
        "train_epoch_loss": train_metrics["loss"],
        "test_epoch_accuracy": test_metrics["accuracy"],
        "test_epoch_loss": test_metrics["loss"],
    })

# TODO: export the trained model (e.g. to ONNX via torch.onnx.export) and log the file
# as a wandb.Artifact. run.log_artifact expects an Artifact or a path, not an nn.Module.
run.finish()