# Proof of Concept
## Imports

To start off, we import all the necessary modules.

In [1]:
from pathlib import Path

import torch
import torchvision
from torchvision.transforms import v2 as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

## Constants & Preparations

In [2]:
DATA_DIR = Path("../data/raw")
RUNS_DIR = Path("../runs/FashionMNIST")

NUM_EPOCHS = 10
BATCH_SIZE = 64
NUM_WORKERS = 16
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
    )

LOSS_FN = nn.CrossEntropyLoss()
LR = 1e-3
MOMENTUM = 0.9

REPORTING_FREQ = 100

Next, we set up a **TensorBoard writer** to log the training process later on.

In [3]:
writer = SummaryWriter(RUNS_DIR)

## Data Preparation
We create a transform that transforms the inputs (`PIL.Image.Image`) to `Image` instances (precisely, `torchvision.tv_tensors.Image`),
which are largely interchangeable with regular tensors. See [here](https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html#what-are-tvtensors) for details.

In [4]:
# Effectively the same as the 'ToTensor' transformation in v1, followed by normalization
transform = transforms.Compose([
    transforms.ToImage(),                           # convert to Image
    transforms.ToDtype(torch.float32, scale=True),  # scale data to have values in [0, 1]
    transforms.Normalize((0.5,), (0.5,))            # normalize
    ])

We create separate datasets for training and validation.
- `train=True` creates dataset from `train-images-idx3-ubyte` (60k training images)
- `train=False` creates dataset from `t10k-images-idx3-ubyte` (10k test images)

In [5]:
# Create datasets
train_set = torchvision.datasets.FashionMNIST(DATA_DIR, train=True, transform=transform, download=True)
val_set = torchvision.datasets.FashionMNIST(DATA_DIR, train=False, transform=transform, download=True)

# Create dataloaders from datasets, shuffle only during training
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Next, we manually define the class labels used by the FashionMNIST dataset.

In [6]:
CLASS_LABELS = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot"
    ]

Finally, we visualize a few images from the validation set using TensorBoard.

In [7]:
# Grab sample images and labels from validation set
sample_val_imgs, sample_val_labels = next(iter(val_loader))
img_grid = torchvision.utils.make_grid(sample_val_imgs)

# Write to Tensorboard
writer.add_image("FashionMNIST Sample Validation Images", img_grid)
writer.flush()

## Neural Networks
We implement a slight modification of the **LeNet** model proposed by [LeCun et al. (1998)](https://direct.mit.edu/neco/article-abstract/1/4/541/5515/Backpropagation-Applied-to-Handwritten-Zip-Code?redirectedFrom=fulltext).

In [8]:
class LeNet(nn.Module):
    """LeNet-5 architecture proposed by LeCun et al. (1998)."""
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


## Utility Functions
We start with a function that **trains a network for a single epoch**. We report the training progress to TensorBoard and
additionally print to the console.

In [9]:
def train_one_epoch(
    network, dataloader, loss_fn, optimizer,
    num_epochs, epoch_index):
    """Train a network on a training set for one full epoch."""
    # Running totals to report training progress to TensorBoard
    running_samples = 0
    running_loss = 0.
    running_correct = 0

    # Enable training mode
    network.train()

    for batch_index, (inputs, labels) in enumerate(dataloader):
        # Move data to target device
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Keep track of number of samples
        samples = len(labels)
        running_samples += samples

        optimizer.zero_grad()            # zero gradients
        outputs = network(inputs)        # perform forward pass
        loss = loss_fn(outputs, labels)  # compute batch loss
        loss.backward()                  # compute gradients
        optimizer.step()                 # adjust network parameters

        # Accumulate loss
        running_loss += loss.item() * samples

        # Count correct predictions and add to running total
        _, predictions = torch.max(outputs.data, dim=1)
        running_correct += (predictions == labels).sum().item()

        # Report training progress to TensorBoard
        if (batch_index + 1) % REPORTING_FREQ == 0:
            # Compute current global step (i.e., across epochs)
            global_step = epoch_index * len(dataloader) + batch_index
            
            # Compute average loss and average accuracy per batch
            avg_batch_loss = running_loss / running_samples
            avg_batch_accuracy = (running_correct / running_samples) * 100  # in pct

            # Write to TensorBoard
            writer.add_scalar("Training Loss", avg_batch_loss, global_step)
            writer.add_scalar("Training Accuracy", avg_batch_accuracy, global_step)

            # Print results to console
            print(f"Epoch [{epoch_index + 1:02}/{num_epochs}]   "
                  f"Batch [{batch_index + 1:03}/{len(dataloader)}]   "
                  f"Loss: {avg_batch_loss:.4f}   "
                  f"Acc: {avg_batch_accuracy:02.2f}")

            # Reset running totals
            running_samples = 0
            running_loss = 0.
            running_correct = 0


Next, we implement a function that lets us **evaluate a network on a validation set**.

In [10]:
def eval_network(
    network, dataloader, loss_fn,
    num_epochs, epoch_index):
    """Evaluate a network on a test set using the provided dataloader and loss function."""
    # Running totals to report training progress to TensorBoard
    running_samples = 0
    running_loss = 0.
    running_correct = 0

    # Total values to compute loss and accuracy over entire test set
    total_samples = 0
    total_loss = 0.
    total_correct = 0

    # Enable evaluation mode
    network.eval()

    with torch.no_grad():
        for batch_index, (inputs, labels) in enumerate(dataloader):
            # Move data to target device
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            # Keep track of number of samples
            samples = len(labels)
            running_samples += samples
            total_samples += samples

            outputs = network(inputs)  # perform forward pass
            loss = loss_fn(outputs, labels)  # compute batch loss

             # Accumulate loss
            running_loss += loss.item() * samples
            total_loss += loss.item() * samples

            # Count correct predictions and add to running total(s)
            _, predictions = torch.max(outputs.data, dim=1)
            correct = (predictions == labels).sum().item()
            running_correct += correct
            total_correct += correct

            # Report results to TensorBoard
            if (batch_index + 1) % REPORTING_FREQ == 0:
                # Compute current global step (i.e., across epochs)
                global_step = epoch_index * len(dataloader) + batch_index
                
                # Compute average loss and average accuracy per batch
                avg_batch_loss = running_loss / running_samples
                avg_batch_accuracy = (running_correct / running_samples) * 100  # in pct
    
                # Write to TensorBoard
                writer.add_scalar("Validation Loss", avg_batch_loss, global_step)
                writer.add_scalar("Validation Accuracy", avg_batch_accuracy, global_step)
    
                # Reset running totals
                running_samples = 0
                running_loss = 0.
                running_correct = 0

    # Compute loss and accuracy over entire test set
    avg_loss = total_loss / total_samples
    avg_accuracy = (total_correct / total_samples) * 100  # in pct

    return avg_loss, avg_accuracy


We also need a function to regularly **save checkpoints during training**.

In [11]:
def save_checkpoint():
    pass


Finally, we put everything together into a single function that can be used to **train a network for multiple epochs**.

In [12]:
def train_network(
        network, train_loader, val_loader, loss_fn, optimizer,
        num_epochs, start_epoch=0):
    """Train and validate a network for multiple epochs at once."""
    epoch_index = start_epoch
    
    for epoch in range(num_epochs):
        # Iteratively train and validate the network
        print("Training...")
        train_one_epoch(network, train_loader, loss_fn, optimizer, num_epochs, epoch_index)
        print("Validating...")
        avg_val_loss, avg_val_accuracy = eval_network(network, val_loader, loss_fn, num_epochs, epoch_index)

        # Report results for validation set
        print(f"Epoch [{epoch_index + 1:02}/{num_epochs}]   "
                  "Validation        "
                  f"Loss: {avg_val_loss:.4f}   "
                  f"Acc: {avg_val_accuracy:02.2f}\n")

        epoch_index += 1


## Training
First, we check the **target device for training**.

In [13]:
print(DEVICE)

mps


We create an **instance of the LeNet model architecture**, move the network to the target device, and visualize the network's
architecture using TensorBoard.

In [14]:
network = LeNet().to(DEVICE)

# Visualize architecture using TensorBoard
writer.add_graph(network, sample_val_imgs.to(DEVICE))
writer.flush()

Next, we set up our **optimizer**.

In [15]:
optimizer = optim.SGD(network.parameters(), lr=LR, momentum=MOMENTUM)

Finally, we start the training loop.

In [16]:
train_network(network, train_loader, val_loader, LOSS_FN, optimizer, NUM_EPOCHS)

Training...
Epoch [01/10]   Batch [100/938]   Loss: 2.3031   Acc: 12.58
Epoch [01/10]   Batch [200/938]   Loss: 2.2990   Acc: 18.95
Epoch [01/10]   Batch [300/938]   Loss: 2.2942   Acc: 27.28
Epoch [01/10]   Batch [400/938]   Loss: 2.2855   Acc: 26.02
Epoch [01/10]   Batch [500/938]   Loss: 2.2689   Acc: 25.30
Epoch [01/10]   Batch [600/938]   Loss: 2.2190   Acc: 28.64
Epoch [01/10]   Batch [700/938]   Loss: 2.0040   Acc: 34.23
Epoch [01/10]   Batch [800/938]   Loss: 1.4393   Acc: 53.36
Epoch [01/10]   Batch [900/938]   Loss: 1.0289   Acc: 62.89
Validating...
Epoch [01/10]   Validation        Loss: 0.9140   Acc: 66.59

Training...
Epoch [02/10]   Batch [100/938]   Loss: 0.8472   Acc: 69.48
Epoch [02/10]   Batch [200/938]   Loss: 0.8132   Acc: 70.56
Epoch [02/10]   Batch [300/938]   Loss: 0.7862   Acc: 70.45
Epoch [02/10]   Batch [400/938]   Loss: 0.7630   Acc: 72.00
Epoch [02/10]   Batch [500/938]   Loss: 0.7269   Acc: 72.80
Epoch [02/10]   Batch [600/938]   Loss: 0.7060   Acc: 73.41
E