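# Proof of Concept
Train a LeNet-style convolutional network on FashionMNIST and report the training and validation loss for each epoch.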
## Imports

In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import v2 as transforms

## Define Constants

In [2]:
BATCH_SIZE = 16
DATA_DIR = Path("../data/raw")
NUM_EPOCHS = 5
LOSS_FN = nn.CrossEntropyLoss()
LR = 0.001
MOMENTUM = 0.9
SEED = 42

# Seed torch's global RNG so that weight initialisation and DataLoader shuffling
# are reproducible under "Restart Kernel -> Run All". The trailing ';' keeps
# Jupyter from displaying the returned Generator object.
torch.manual_seed(SEED);

## Data Preparation

We create a transform that converts the input images (`PIL.Image.Image`) to `Image` instances (precisely, `torchvision.tv_tensors.Image`), scales their pixel values to $[0, 1]$, and then normalizes them to $[-1, 1]$.
`Image` instances are largely interchangeable with regular tensors. See [here](https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html#what-are-tvtensors) for details.

In [3]:
# v2 counterpart of v1's `ToTensor()` followed by `Normalize(...)`
transform = transforms.Compose(
    [
        transforms.ToImage(),                           # PIL image -> tv_tensors.Image
        transforms.ToDtype(torch.float32, scale=True),  # integer pixels -> float32 in [0, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1]
    ]
)

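We create separate datasets for training and validation. FashionMNIST only ships a training split and a test split, so the test split serves as our validation set:
- `train=True` creates the dataset from `train-images-idx3-ubyte` (60k training images)
- `train=False` creates the dataset from `t10k-images-idx3-ubyte` (10k test images)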

In [4]:
# Build both splits with identical preprocessing (training split first)
train_set, val_set = (
    torchvision.datasets.FashionMNIST(DATA_DIR, train=split_is_train, transform=transform, download=True)
    for split_is_train in (True, False)
)

# Wrap each split in a loader; only the training loader shuffles its samples
train_loader, val_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
    for dataset, shuffle in ((train_set, True), (val_set, False))
)

Finally, we manually define the class labels used by the FashionMNIST dataset.

In [5]:
# Human-readable names in FashionMNIST's label order (0-9), so that
# CLASS_LABELS[label] gives the name of an integer target.
CLASS_LABELS = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle Boot",   # 9
]

## Neural Networks
We implement a slight modification of the **LeNet-5** model proposed by LeCun et al. (1998), which builds on the convolutional network introduced in [LeCun et al. (1989)](https://direct.mit.edu/neco/article-abstract/1/4/541/5515/Backpropagation-Applied-to-Handwritten-Zip-Code?redirectedFrom=fulltext).

In [6]:
class LeNet(nn.Module):
    """LeNet-5-style CNN (after LeCun et al., 1998) for 28x28 images.

    Unlike the original LeNet-5, this variant uses ReLU activations and max
    pooling and takes 28x28 inputs directly. Spatial sizes through the
    convolutional stack are:

        28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4

    Args:
        in_channels (int): Number of input image channels (default 1, grayscale).
        num_classes (int): Number of output logits (default 10).
    """

    # Flattened size of the final feature maps for a 28x28 input: 16 channels x 4 x 4.
    _FLAT_FEATURES = 16 * 4 * 4

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=self._FLAT_FEATURES, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x (torch.Tensor): Batch of images of shape (N, in_channels, 28, 28).

        Returns:
            torch.Tensor: Unnormalized logits of shape (N, num_classes).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. A mis-sized input then
        # fails loudly in fc1, whereas view(-1, 256) would silently regroup the
        # values into a different (wrong) number of rows.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


## Utility Functions

In [7]:
def train_one_epoch(network, dataloader, loss_fn, optimizer, log_every=250, device=None):
    """Train a network on a training set for one full epoch.

    After every ``log_every`` batches, the mean batch loss of that window is printed.

    Args:
        network (nn.Module): Model to train; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        loss_fn: Callable mapping ``(outputs, labels)`` to a scalar loss tensor.
        optimizer (torch.optim.Optimizer): Optimizer over ``network``'s parameters.
        log_every (int): Number of batches per reporting window (default 250).
        device: Device each batch is moved to. Defaults to the device of the
            network's first parameter (CPU if the network has no parameters).

    Returns:
        float: The average batch loss over the last completed window of
        ``log_every`` batches. If the epoch ends before a full window is
        completed, the average over the batches seen is returned instead
        (0.0 if the dataloader yielded nothing).

    Raises:
        ValueError: If ``log_every`` is not positive.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be a positive integer, got {log_every}")
    if device is None:
        # Follow the model: batches must live on the same device as its weights.
        first_param = next(network.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.
    batches_in_window = 0
    avg_batch_loss = None

    network.train()

    for batch_index, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()            # zero gradients
        outputs = network(inputs)        # perform forward pass
        loss = loss_fn(outputs, labels)  # compute batch loss
        loss.backward()                  # compute gradients
        optimizer.step()                 # adjust network parameters

        running_loss += loss.item()  # accumulate loss over the current window
        batches_in_window += 1

        if (batch_index + 1) % log_every == 0:
            avg_batch_loss = running_loss / log_every  # avg loss per batch
            print(f"   Batch {batch_index + 1:04}   Loss: {avg_batch_loss:.4f}")
            running_loss = 0.  # start a new window
            batches_in_window = 0

    if avg_batch_loss is None:
        # No full window was completed (small dataset or large log_every):
        # report the partial window rather than a misleading 0.0.
        avg_batch_loss = running_loss / batches_in_window if batches_in_window else 0.
    return avg_batch_loss


def test_network(network, dataloader, loss_fn):
    """Test a network on a test set using the provided dataloader and loss function.
    
    Returns:
        float: The average loss over the last set of 250 batches.
    """
    running_loss = 0.
    avg_batch_loss = 0.

    network.eval()

    with torch.no_grad():
        # NOTE: Using 'enumerate(dataloader)' let's us track the batch we're
        #       currently in for intra-epoch reporting.
        for batch_index, data in enumerate(dataloader):
            inputs, labels = data
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = network(inputs)  # perform forward pass
            loss = loss_fn(outputs, labels)  # compute batch loss
            
            running_loss += loss.item()  # accumulate loss over multiple batches

            if (batch_index + 1) % 250 == 0:
                avg_batch_loss = running_loss / 250  # avg loss per batch
                running_loss = 0.  # reset running loss

    return avg_batch_loss


def save_checkpoint():
    """Placeholder for checkpoint saving; currently does nothing and returns None.

    TODO: persist model/optimizer state so that training can later be resumed.
    """
    return None


def train_network(network, train_loader, val_loader, loss_fn, optimizer, num_epochs, start_epoch=0):
    """Run ``num_epochs`` rounds of training, each followed by validation.

    Printed epoch numbers start at ``start_epoch + 1``.
    """
    for epoch_index in range(start_epoch, start_epoch + num_epochs):
        print(f"EPOCH {epoch_index + 1}")

        # One pass over the training data; its returned loss is not used here.
        train_one_epoch(network, train_loader, loss_fn, optimizer)

        # Score the updated network on the held-out data and report it.
        avg_val_loss = test_network(network, val_loader, loss_fn)
        print(f"   Validation   Loss: {avg_val_loss:.4f}")


## Train the Network
First, we set the **target device for training**.

In [8]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(DEVICE)

mps


We create an **instance of the LeNet model architecture**, and move the network to the target device.

In [9]:
network = LeNet()
network = network.to(DEVICE)  # move all parameters and buffers to the training device

Next, we set up our **optimizer**.  
**IMPORTANT**: The optimizer has to be initialized with the network's parameters **after** the model has been moved to the target device for training.

In [10]:
optimizer = torch.optim.SGD(
    params=network.parameters(),
    lr=LR,
    momentum=MOMENTUM,
)

Finally, we start the training loop.

In [11]:
train_network(
    network,
    train_loader,
    val_loader,
    loss_fn=LOSS_FN,
    optimizer=optimizer,
    num_epochs=NUM_EPOCHS,
)

EPOCH 1
   Batch 0250   Loss: 2.2934
   Batch 0500   Loss: 2.2413
   Batch 0750   Loss: 1.6177
   Batch 1000   Loss: 0.9227
   Batch 1250   Loss: 0.8137
   Batch 1500   Loss: 0.7336
   Batch 1750   Loss: 0.7092
   Batch 2000   Loss: 0.6724
   Batch 2250   Loss: 0.6819
   Batch 2500   Loss: 0.6406
   Batch 2750   Loss: 0.6075
   Batch 3000   Loss: 0.5852
   Batch 3250   Loss: 0.5760
   Batch 3500   Loss: 0.5552
   Batch 3750   Loss: 0.5789
   Validation   Loss: 0.5684
EPOCH 2
   Batch 0250   Loss: 0.5487
   Batch 0500   Loss: 0.5582
   Batch 0750   Loss: 0.5320
   Batch 1000   Loss: 0.5063
   Batch 1250   Loss: 0.4930
   Batch 1500   Loss: 0.5061
   Batch 1750   Loss: 0.4823
   Batch 2000   Loss: 0.5065
   Batch 2250   Loss: 0.4883
   Batch 2500   Loss: 0.4762
   Batch 2750   Loss: 0.4581
   Batch 3000   Loss: 0.4775
   Batch 3250   Loss: 0.4628
   Batch 3500   Loss: 0.4401
   Batch 3750   Loss: 0.4477
   Validation   Loss: 0.4847
EPOCH 3
   Batch 0250   Loss: 0.4313
   Batch 0500   Los