# Training Script
Adapted from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [1]:
from pathlib import Path
from typing import Literal, Tuple, Union

import torch
from torch import nn

### Hyperparameters

In [2]:
NUM_EPOCHS = 1
LEARNING_RATE = 1e-3
BATCH_SIZE = 32

### Devices

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


### Model

In [4]:
from torchvision import models

def load_model(model_name: str, n_classes: int) -> Tuple[nn.Module, int]:
    if model_name == "resnet":
        model = models.resnet50(pretrained=True)
        n_features_in = model.fc.in_features
        model.fc = nn.Linear(n_features_in, 10)
    elif model_name == "mobilenet":
        model = models.mobilenet.mobilenet_v3_small(pretrained=True)
        n_features_in = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(n_features_in, n_classes)
    else:
        raise NotImplementedError

    input_size = 224
    return model, input_size

model, input_size = load_model('mobilenet', n_classes=10)
model = model.to(device)

### Image Preprocessing

In [5]:
from torchvision import transforms as T

train_transform = T.Compose(
    [
        T.Resize((input_size, input_size)),
        T.RandomAffine(10, translate=(0.02, 0.02), scale=(0.9, 1.1)),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
non_train_transform = T.Compose([T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

### MNIST Dataset

In [6]:
from torch.utils.data import Dataset

class MNISTDataset(Dataset):
    def __init__(self, set: Union[Literal["train"], Literal["val"], Literal["test"]]):
        DATA_DIR = Path("data/MNIST")
        self.images = torch.concat(
            [torch.load(DATA_DIR / f"{set}_images.pt")] * 3, dim=1
        )
        self.labels = torch.load(DATA_DIR / f"{set}_labels.pt")

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.images[index], self.labels[index]


train_dataset = MNISTDataset("train")
val_dataset = MNISTDataset("val")
test_dataset = MNISTDataset("test")

print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(val_dataset)}")
print(f"Number of test examples: {len(test_dataset)}")

Number of training examples: 50000
Number of validation examples: 10000
Number of test examples: 10000


### DataLoader

In [7]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, shuffle=False)

### Criterion & Optimizer

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

### Training & Evaluation Loop

In [9]:
import time

def evaluate(model: nn.Module, dataloader: DataLoader) -> None:
    correct = 0
    total = 0

    with torch.no_grad():
        for (X, y) in dataloader:
            X = X.to(device)
            y = y.to(device)

            X = non_train_transform(X)
            y_pred = model(X)

            _, y_pred = torch.max(y_pred.data, 1)
            total += y.size(0)
            correct += (y_pred == y).sum().item()

    print(
        "Accuracy of the network on the images: %d %%"
        % (100 * correct / total)
    )

start = time.perf_counter()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0

    for train_idx, (X_train, y_train) in enumerate(train_dataloader):
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward pass + loss calc + backprop + update model params
        X_train = train_transform(X_train)
        y_pred = model(X_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if train_idx % 100 == 0:
            print(
                "[%d, %5d] loss: %.3f" % (epoch + 1, train_idx + 1, running_loss / 2000)
            )
            running_loss = 0.0

evaluate(model, test_dataloader)

end = time.perf_counter()
print(f"Training job completed in : {end - start:.1f}s")

[1,     1] loss: 0.001
[1,   101] loss: 0.018
[1,   201] loss: 0.006
[1,   301] loss: 0.005
[1,   401] loss: 0.004
[1,   501] loss: 0.004
[1,   601] loss: 0.003
[1,   701] loss: 0.003
[1,   801] loss: 0.005
[1,   901] loss: 0.003
[1,  1001] loss: 0.003
[1,  1101] loss: 0.003
[1,  1201] loss: 0.002
[1,  1301] loss: 0.003
[1,  1401] loss: 0.003
[1,  1501] loss: 0.003
Training job completed in : 143.4s
