# Basic Training Routine

**NOTE**: This notebook is no longer up to date with the source code (i.e., the LeNet architecture, the loading of the FashionMNIST dataset, and the training algorithms have changed)!


## Imports

To start off, we import all the necessary modules.

In [1]:
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2 as transforms
from tqdm import tqdm

## Constants & Preparations

In [2]:
# Timestamp for logging purposes
now = datetime.today()

# Paths
DATA_DIR = Path("../data")
RUNS_DIR = Path(f"../logs/tensorboard/{now.strftime('%Y-%m-%d')}/{now.strftime('%H-%M-%S')}")

# Params
BATCH_SIZE = 128
LR = 1e-3
MOMENTUM = 0.9
NUM_EPOCHS = 10
NUM_WORKERS = 0

# Training
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Logging
EPOCH_INDEX = 0
INTRA_EPOCH_UPDATES = 10

Next, we set up a **TensorBoard writer** to log the training process later on.

In [3]:
writer = SummaryWriter(RUNS_DIR)

## Data Preparation

We create a transform that converts the inputs (`PIL.Image.Image`) to `Image` instances (precisely, `torchvision.tv_tensors.Image`),
which are largely interchangeable with regular tensors, and then scales and normalizes the pixel values. See [here](https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html#what-are-tvtensors) for details.

In [4]:
# Effectively the same as the 'ToTensor' transformation in v1, followed by normalization
transform = transforms.Compose([
    transforms.ToImage(),                           # convert to Image
    transforms.ToDtype(torch.float32, scale=True),  # scale data to have values in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),           # normalize to [-1, 1] via (x - 0.5) / 0.5
])
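
As a quick check of the arithmetic behind this pipeline, the following pure-Python sketch applies the same two steps to individual 8-bit pixel values: scaling by 1/255, then subtracting the mean 0.5 and dividing by the standard deviation 0.5. The helper name `normalize_pixel` is made up for illustration and is not part of torchvision.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Map an 8-bit pixel value to the range produced by the transform above.

    Hypothetical helper for illustration only; it mirrors the arithmetic,
    not the torchvision implementation.
    """
    scaled = value / 255.0          # scaling step: [0, 255] -> [0, 1]
    return (scaled - mean) / std    # normalization step: [0, 1] -> [-1, 1]


for pixel in (0, 128, 255):
    print(pixel, round(normalize_pixel(pixel), 4))
```

Black pixels end up at -1, white pixels at 1, and mid-gray pixels close to 0.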

We create separate datasets for training and validation.
- `train=True` creates the dataset from `train-images-idx3-ubyte` (60k training images)
- `train=False` creates the dataset from `t10k-images-idx3-ubyte` (10k test images)

In [5]:
# Create datasets
train_set = torchvision.datasets.FashionMNIST(DATA_DIR, train=True, transform=transform, download=True)
val_set = torchvision.datasets.FashionMNIST(DATA_DIR, train=False, transform=transform, download=True)

# Create dataloaders from datasets, shuffle only during training
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
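
Because the dataloaders above keep the default `drop_last=False`, the last batch of each epoch can be smaller than `BATCH_SIZE`. The sketch below works out the number of batches per epoch and the size of the final batch for both splits; `batch_stats` is a hypothetical helper name used only here.

```python
import math

BATCH_SIZE = 128


def batch_stats(num_samples, batch_size=BATCH_SIZE):
    """Return (number of batches, size of the last batch) when no batch is dropped."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch


print(batch_stats(60_000))  # training split
print(batch_stats(10_000))  # validation split
```

With these sizes, a training epoch has 469 batches (the last one holding 96 samples) and a validation epoch has 79 batches (the last one holding 16 samples).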

Next, we manually define the class labels used by the FashionMNIST dataset.

In [6]:
CLASS_LABELS = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot"
]
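
To illustrate how these labels are meant to be used, here is a small sketch that maps a vector of class scores to a label by taking the index of the largest score. The scores are made-up values and `predicted_label` is a hypothetical helper, not part of the notebook's code.

```python
CLASS_LABELS = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot",
]


def predicted_label(logits):
    """Return the label whose score is largest (argmax over the classes)."""
    best_index = max(range(len(logits)), key=lambda i: logits[i])
    return CLASS_LABELS[best_index]


# Made-up scores for a single image; index 4 has the largest value
example_logits = [0.1, -1.2, 0.3, 0.0, 2.5, -0.7, 1.1, -2.0, 0.4, 0.2]
print(predicted_label(example_logits))
```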

Finally, we visualize a few images from the validation set using TensorBoard.

In [7]:
# Grab sample images and labels from validation set
sample_val_imgs, sample_val_labels = next(iter(val_loader))
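# Map the normalized pixel values from [-1, 1] back to [0, 1] for display
img_grid = torchvision.utils.make_grid(sample_val_imgs, normalize=True, value_range=(-1, 1))

# Write to TensorBoard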
writer.add_image("FashionMNIST Sample Validation Images", img_grid)
writer.flush()

## Neural Networks

We implement a slight modification of the **LeNet** model proposed by [LeCun et al. (1998)](https://direct.mit.edu/neco/article-abstract/1/4/541/5515/Backpropagation-Applied-to-Handwritten-Zip-Code?redirectedFrom=fulltext). This version uses ReLU activations and max pooling and takes the 28×28 FashionMNIST images as input.

In [8]:
class LeNet(nn.Module):
    """LeNet-5 architecture proposed by LeCun et al. (1998)."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)  # flatten all but the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
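
The `16 * 4 * 4` input size of `fc1` follows from the spatial sizes of the feature maps. As a sanity check, the sketch below traces a 28×28 input through the layers and counts the trainable parameters; `conv_out` and `num_params` are hypothetical helpers written for this illustration.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def num_params(fan_in, fan_out, kernel_size=1):
    """Weights plus biases of a conv layer (or a linear layer if kernel_size=1)."""
    return fan_in * fan_out * kernel_size ** 2 + fan_out


size = 28                                        # FashionMNIST images are 28x28
size = conv_out(size, kernel_size=5)             # conv1: 28 -> 24
size = conv_out(size, kernel_size=2, stride=2)   # pool:  24 -> 12
size = conv_out(size, kernel_size=5)             # conv2: 12 -> 8
size = conv_out(size, kernel_size=2, stride=2)   # pool:   8 -> 4
flat_features = 16 * size * size                 # 16 channels of 4x4 maps

total_params = (
    num_params(1, 6, kernel_size=5)      # conv1
    + num_params(6, 16, kernel_size=5)   # conv2
    + num_params(flat_features, 120)     # fc1
    + num_params(120, 84)                # fc2
    + num_params(84, 10)                 # fc3
)
print(size, flat_features, total_params)
```

The final feature maps are 4×4, so the flattened vector has 256 entries, and the whole network has 44,426 trainable parameters.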


## Utility Functions

We start with a function that **trains/validates a model for a single epoch**. We also report progress to TensorBoard.

In [9]:
def run_epoch(model, dataloader, loss_fn, optimizer=None):
    # Initialize dictionary to log intra-epoch results
    logs = {
        "global_step": [],
        "loss": [],
        "accuracy": []
    }
    
    # Running totals to report progress to TensorBoard
    running_samples = 0
    running_loss = 0.
    running_correct = 0

    # Set training/evaluation mode
    is_training = optimizer is not None
    model.train(is_training)

    # Determine batch indices at which to log to TensorBoard
    num_batches = len(dataloader)
    log_indices = torch.linspace(
        0, num_batches - 1, INTRA_EPOCH_UPDATES + 1
    ).int().tolist()
    if EPOCH_INDEX != 0:
        log_indices = log_indices[1:]

    # Set tags for TensorBoard logging
    tag_loss = f"Loss/{'Train' if is_training else 'Val'}"
    tag_accuracy = f"Accuracy/{'Train' if is_training else 'Val'}"

    # Prepare progress bar
    desc = (f"Epoch [{EPOCH_INDEX + 1:02}/{NUM_EPOCHS}]  "
            f"{'Train' if is_training else 'Val'}")
    pbar = tqdm(dataloader, desc=desc, leave=False, unit="batch")

    # Disable gradients during evaluation
    with torch.set_grad_enabled(is_training):
        for batch_index, (inputs, targets) in enumerate(pbar):
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            # Keep track of the number of samples
            samples = len(targets)
            running_samples += samples

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Accumulate loss
            running_loss += loss.item() * samples

            # Compute accuracy
            _, predictions = torch.max(outputs, 1)
            correct = (predictions == targets).sum().item()
            running_correct += correct

            # Update progress bar (averages over all samples since the last TensorBoard log)
            avg_batch_loss = running_loss / running_samples
            avg_batch_accuracy = (running_correct / running_samples) * 100  # in pct
            pbar.set_postfix(
                loss=avg_batch_loss,
                accuracy=avg_batch_accuracy
            )

            # Backward pass and optimization
            if is_training:
                optimizer.zero_grad()  # zero gradients
                loss.backward()        # compute gradients
                optimizer.step()       # update weights

            # Log batch loss and accuracy
            if batch_index in log_indices:
                # Log to TensorBoard
                global_step = EPOCH_INDEX * num_batches + batch_index + 1
                writer.add_scalar(tag_loss, avg_batch_loss, global_step)
                writer.add_scalar(tag_accuracy, avg_batch_accuracy, global_step)

                # Log to dictionary
                logs["global_step"].append(global_step)
                logs["loss"].append(avg_batch_loss)
                logs["accuracy"].append(avg_batch_accuracy)

                # Reset running totals
                running_samples = 0
                running_loss = 0.
                running_correct = 0

    # Flush writer after epoch for live updates
    writer.flush()
    
    return logs
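
One detail in `run_epoch` is worth spelling out: `nn.CrossEntropyLoss` returns the mean loss of a batch by default, so the function multiplies it by the batch size before accumulating. The sketch below, with made-up loss values and a hypothetical helper `sample_weighted_mean`, shows why this matters when batches differ in size (as the last batch of an epoch usually does).

```python
def sample_weighted_mean(batch_means, batch_sizes):
    """Average per-batch means, weighting each batch by its number of samples."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Made-up values: a full batch and a smaller final batch
batch_means = [0.5, 1.0]
batch_sizes = [128, 96]

weighted = sample_weighted_mean(batch_means, batch_sizes)
naive = sum(batch_means) / len(batch_means)
print(f"weighted: {weighted:.4f}, naive: {naive:.4f}")
```

The weighted value is the true per-sample average (160 / 224), whereas the naive average of batch means gives the smaller batch too much influence.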


To **train our model**, we call the `run_epoch` function twice per epoch: once with the optimizer (training) and once without it (validation).

In [10]:
def train_model(model, train_loader, val_loader, loss_fn, optimizer):
    global EPOCH_INDEX

    # Initialize dictionary to log results
    logs = {
        "train": {
            "global_step": [],
            "loss": [],
            "accuracy": []
        },
        "val": {
            "global_step": [],
            "loss": [],
            "accuracy": []
        }
    }
    
    for _ in range(NUM_EPOCHS):
        # Train and validate model
        train_logs = run_epoch(model, train_loader, loss_fn, optimizer)
        val_logs = run_epoch(model, val_loader, loss_fn)

        # Log results
        logs["train"]["global_step"].extend(train_logs["global_step"])
        logs["train"]["loss"].extend(train_logs["loss"])
        logs["train"]["accuracy"].extend(train_logs["accuracy"])
        logs["val"]["global_step"].extend(val_logs["global_step"])
        logs["val"]["loss"].extend(val_logs["loss"])
        logs["val"]["accuracy"].extend(val_logs["accuracy"])

        # Increment epoch index
        EPOCH_INDEX += 1

    # Close TensorBoard writer and inform user of training completion
    writer.close()
    print("Training complete!")

    return logs
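
The nested dictionary returned by `train_model` can be inspected after training. As a minimal sketch, assuming logs in the structure built above, the hypothetical helper `best_validation_point` looks up the logged step with the highest validation accuracy; the example values are made up and are not results from this notebook.

```python
def best_validation_point(logs):
    """Return (global_step, accuracy) of the highest logged validation accuracy."""
    val = logs["val"]
    best = max(range(len(val["accuracy"])), key=lambda i: val["accuracy"][i])
    return val["global_step"][best], val["accuracy"][best]


# Made-up logs in the same structure that train_model returns
example_logs = {
    "train": {"global_step": [1, 47], "loss": [2.30, 1.90], "accuracy": [10.2, 35.0]},
    "val": {"global_step": [1, 8, 16], "loss": [1.80, 1.20, 1.25], "accuracy": [40.0, 61.5, 60.9]},
}
print(best_validation_point(example_logs))
```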


We also need a function to regularly **save checkpoints during training**. It is not implemented yet; the cell below is only a placeholder.

In [11]:
def save_checkpoint():
    pass
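
One possible way to fill in this placeholder is sketched below. It is not the notebook's implementation: the signature, the dictionary keys, and the companion `load_checkpoint` function are assumptions made for illustration. The sketch stores the model and optimizer state dictionaries together with the epoch index using `torch.save`, and restores them with `torch.load`.

```python
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(model, optimizer, epoch, path):
    """Save model/optimizer state so training can be resumed later (hypothetical signature)."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(model, optimizer, path, device="cpu"):
    """Restore the state saved by save_checkpoint and return the stored epoch."""
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]
```

In the training loop, such a function could be called once per epoch, for example with a file name that contains the epoch index. Saving the optimizer state as well keeps the momentum buffers intact when training is resumed.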


## Training

First, we check the **target device for training**.

In [12]:
print(DEVICE)

mps


We create an **instance of the LeNet model architecture**, move the network to the target device, and visualize the network's
architecture using TensorBoard.

In [13]:
network = LeNet().to(DEVICE)

# Visualize architecture using TensorBoard
writer.add_graph(network, sample_val_imgs.to(DEVICE))
writer.flush()

Next, we set up our **loss function** and **optimizer**.

In [14]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=LR, momentum=MOMENTUM)
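
To make the role of `MOMENTUM` concrete, here is a scalar sketch of the update that SGD with momentum performs, assuming no dampening, weight decay, or Nesterov momentum (the defaults used above). The function `sgd_momentum_step` is a hypothetical illustration, not PyTorch code.

```python
def sgd_momentum_step(param, grad, velocity, lr=1e-3, momentum=0.9):
    """One SGD-with-momentum update for a single scalar parameter."""
    velocity = momentum * velocity + grad   # accumulate a decaying sum of gradients
    param = param - lr * velocity           # step against the accumulated direction
    return param, velocity


param, velocity = 0.0, 0.0
history = []
for step in range(2):
    # Constant made-up gradient of 1.0 at every step
    param, velocity = sgd_momentum_step(param, grad=1.0, velocity=velocity)
    history.append((param, velocity))
    print(f"step {step + 1}: param={param:.4f}, velocity={velocity:.2f}")
```

With a constant gradient, the second step is larger than the first (0.0019 versus 0.001), because the velocity carries over 90% of the previous update direction.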

We open **TensorBoard** to track the training progress.

In [15]:
%load_ext tensorboard
%tensorboard --logdir ../logs/tensorboard

Finally, we start the training loop.

In [16]:
logs = train_model(network, train_loader, val_loader, loss_fn, optimizer)

Training complete!


