# Lesson 29: PyTorch CIFAR-10 classifier activity

In this activity, you will build a deep neural network classifier for the CIFAR-10 dataset using PyTorch. The data loading and preparation code is provided. Your task is to:

1. **Define the model** - Build a DNN using `nn.Sequential`
2. **Train the model** - Write a training loop with validation tracking
3. **Evaluate the model** - Assess performance on the test set

## Notebook set-up

### Imports

In [None]:
# Standard library imports
from pathlib import Path

# Third party imports
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set random seeds for reproducibility
torch.manual_seed(315)
np.random.seed(315)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

### Hyperparameters

In [None]:
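batch_size = 10000     # samples per batch (large, since all data is pre-loaded on the device)
learning_rate = 1e-2   # step size for the Adam optimizer
epochs = 100           # full passes over the training set
print_every = 10       # print progress every 10 epochs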
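Because the batch size is so large, each epoch contains only a few optimizer updates. The quick check below assumes the 50,000-image CIFAR-10 training set and the 80/20 train/validation split made in section 1.3:

```python
import math

# Assumed sizes: CIFAR-10 ships 50,000 training images; section 1.3 keeps 80% for training
n_train_full = 50_000
n_train = int(0.8 * n_train_full)  # 40,000 training samples

batch_size = 10000
epochs = 100

# DataLoader keeps a final partial batch by default (drop_last=False), hence the ceiling
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs

print(f'{steps_per_epoch} optimizer steps per epoch, {total_steps} in total')
```

A smaller batch size gives more, noisier updates per epoch. With these settings, 100 epochs amount to only 400 weight updates.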

## 1. Load and preprocess CIFAR-10 data

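CIFAR-10 contains 60,000 32x32 color images across 10 classes: 50,000 for training and 10,000 for testing. We convert the images to grayscale for this exercise, which reduces each image from 3 color channels to 1.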

### 1.1. Define transformations and class names

In [None]:
# Define class names
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Preprocessing: convert to 1-channel grayscale, convert to a tensor in [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

### 1.2. Load datasets

In [None]:
# Make sure data directory exists
data_dir = Path('./data')
data_dir.mkdir(parents=True, exist_ok=True)

# Load training and test datasets
train_dataset = datasets.CIFAR10(
    root=data_dir,
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.CIFAR10(
    root=data_dir,
    train=False,
    download=True,
    transform=transform
)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Image shape: {train_dataset[0][0].shape}')
print(f'Number of classes: {len(class_names)}')

### 1.3. Pre-load data and create data loaders

In [None]:
# Apply the transforms once and pre-load the entire dataset onto the device,
# so training avoids per-batch image processing and host-to-device copies
X_train_full = torch.stack([img for img, _ in train_dataset]).to(device)
y_train_full = torch.tensor([label for _, label in train_dataset]).to(device)

X_test = torch.stack([img for img, _ in test_dataset]).to(device)
y_test = torch.tensor([label for _, label in test_dataset]).to(device)

# Split training data into train and validation sets (80/20 split)
n_train = int(0.8 * len(X_train_full))
indices = torch.randperm(len(X_train_full))

X_train = X_train_full[indices[:n_train]]
y_train = y_train_full[indices[:n_train]]
X_val = X_train_full[indices[n_train:]]
y_val = y_train_full[indices[n_train:]]

print(f'X_train shape: {X_train.shape}, device: {X_train.device}')
print(f'y_train shape: {y_train.shape}, device: {y_train.device}')
print(f'X_val shape: {X_val.shape}, device: {X_val.device}')
print(f'y_val shape: {y_val.shape}, device: {y_val.device}')
print(f'X_test shape: {X_test.shape}, device: {X_test.device}')
print(f'y_test shape: {y_test.shape}, device: {y_test.device}')

In [None]:
# Create TensorDatasets
train_tensor_dataset = torch.utils.data.TensorDataset(X_train, y_train)
val_tensor_dataset = torch.utils.data.TensorDataset(X_val, y_val)
test_tensor_dataset = torch.utils.data.TensorDataset(X_test, y_test)

# Create DataLoaders
train_loader = DataLoader(
    train_tensor_dataset,
    batch_size=batch_size,
    shuffle=True
)

val_loader = DataLoader(
    val_tensor_dataset,
    batch_size=batch_size,
    shuffle=False
)

test_loader = DataLoader(
    test_tensor_dataset,
    batch_size=batch_size,
    shuffle=False
)

print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')
print(f'Test batches: {len(test_loader)}')
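`len(loader)` counts batches, not samples. With the default `drop_last=False`, it equals the number of samples divided by `batch_size`, rounded up, and the last batch may be smaller than the others. The following sketch uses hypothetical toy tensors:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 hypothetical samples with 3 features each, plus integer class labels
features = torch.randn(10, 3)
targets = torch.arange(10) % 2

loader = DataLoader(TensorDataset(features, targets), batch_size=4, shuffle=False)

batch_sizes = [xb.shape[0] for xb, yb in loader]
print(len(loader), batch_sizes)  # 3 batches: the last one holds the 2 leftover samples
```

With 40,000 training samples and `batch_size = 10000`, the same rule gives 4 training batches. The 10,000-sample validation and test sets each fit in a single batch.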

### 1.4. Visualize sample images

In [None]:
# Get a batch of training images
images, labels = next(iter(train_loader))

# Plot first 10 images
ncols = 5
nrows = 2

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*1.5, nrows*1.5))
axes = axes.flatten()

for i, ax in enumerate(axes):

    # Undo Normalize((0.5,), (0.5,)) to map pixel values back to [0, 1]
    img = images[i].cpu() * 0.5 + 0.5
    img = img.numpy().squeeze()
    ax.set_title(class_names[labels[i].item()])
    ax.imshow(img, cmap='gray')
    ax.axis('off')

plt.tight_layout()
plt.show()

## 2. Build DNN classifier

### Task 1: Define model architecture

Build a fully connected neural network using `nn.Sequential` to classify CIFAR-10 images.

**Requirements:**
- Flatten the input images (32x32x1 = 1024 features)
- Use at least 2 hidden layers with ReLU activation
- Add dropout for regularization
- Output layer should have 10 units (one per class)

**Hints:**
- Use `nn.Flatten()` as the first layer to convert images to vectors
- Use `nn.Linear(in_features, out_features)` for fully connected layers
- Use `nn.ReLU()` for activation functions
- Use `nn.Dropout(p)` for regularization (e.g., p=0.2)
- Don't forget to move the model to the device with `.to(device)`
- `nn.CrossEntropyLoss` applies log-softmax internally and expects raw logits, so don't add an activation after the output layer
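The layer pattern from the hints can be sketched with toy dimensions. This is not the CIFAR-10 answer. It assumes a hypothetical 1x8x8 input and 3 classes, and the hidden sizes (32 and 16) are arbitrary. The sketch checks the output shape and shows that dropout is active only in training mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

toy_model = nn.Sequential(
    nn.Flatten(),          # (N, 1, 8, 8) -> (N, 64)
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(32, 16),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(16, 3),      # raw logits, no activation: CrossEntropyLoss handles that
)

x = torch.randn(5, 1, 8, 8)  # batch of 5 toy "images"
logits = toy_model(x)
print(logits.shape)  # one row of 3 logits per image

# In eval mode dropout is disabled, so repeated forward passes give identical outputs
toy_model.eval()
with torch.no_grad():
    out_a, out_b = toy_model(x), toy_model(x)
print(torch.equal(out_a, out_b))
```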

In [None]:
# TODO: Define your model architecture
input_size = 32 * 32 * 1  # Grayscale image flattened
num_classes = 10

model = nn.Sequential(
    # Add your layers here
).to(device)

print(model)

### Task 2: Define loss function and optimizer

**Hints:**
- Use `nn.CrossEntropyLoss()` for multi-class classification
- Use `optim.Adam(model.parameters(), lr=learning_rate)` as the optimizer
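`nn.CrossEntropyLoss` takes raw logits and integer class indices, not one-hot vectors. The following check uses hypothetical values to show that it equals the mean negative log of the softmax probability of the true class. This is why the model's output layer needs no activation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # 4 samples, 10 classes, raw scores
targets = torch.tensor([3, 0, 9, 1])   # integer class indices, not one-hot

criterion = nn.CrossEntropyLoss()
ce = criterion(logits, targets)

# The same value computed by hand: mean of -log softmax(logits)[true class]
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), targets].mean()

print(ce.item(), manual.item())
print(torch.allclose(ce, manual, atol=1e-6))
```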

In [None]:
# TODO: Define loss function and optimizer
criterion = None  # Replace with loss function
optimizer = None  # Replace with optimizer

## 3. Training

### Task 3: Write training function

Write a training loop that trains the model and tracks both training and validation metrics.

**Requirements:**
- Iterate over epochs
- For each epoch, iterate over batches in the training loader
- Track training loss and accuracy
- After each epoch's training phase, evaluate on the validation set (without gradient computation)
- Track validation loss and accuracy
- Return a history dictionary with all metrics

**Hints:**
- Use `model.train()` before training and `model.eval()` before validation
- Use `optimizer.zero_grad()` to clear gradients before each batch
- Use `loss.backward()` for backpropagation
- Use `optimizer.step()` to update weights
- Use `torch.no_grad()` context manager for validation
- Use `torch.max(outputs, 1)` to get predictions from logits: it returns `(values, indices)`, and the indices are the predicted classes
- Accuracy = 100 * correct / total
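The hints can be sketched as a single training step followed by a no-gradient evaluation, using hypothetical toy data and a one-layer model. The sketch shows the call order and the accuracy bookkeeping. It is not the full `train_model` function: it has no epochs, loaders, or history.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical toy batch: 8 samples with 4 features, 3 classes
xb = torch.randn(8, 4)
yb = torch.randint(0, 3, (8,))

toy_model = nn.Linear(4, 3)
toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = optim.Adam(toy_model.parameters(), lr=1e-2)

# --- Training step ---
toy_model.train()
toy_optimizer.zero_grad()          # clear gradients left over from the previous step
outputs = toy_model(xb)            # forward pass: logits of shape (8, 3)
loss = toy_criterion(outputs, yb)
loss.backward()                    # backpropagation fills .grad on each parameter
toy_optimizer.step()               # update the weights using those gradients

# torch.max over dim 1 returns (max values, indices); the indices are the predictions.
# These come from this step's forward pass, i.e. from the weights before the update.
_, predicted = torch.max(outputs, 1)
correct = (predicted == yb).sum().item()
train_accuracy = 100 * correct / yb.size(0)

# --- Evaluation step ---
toy_model.eval()
with torch.no_grad():              # no computation graph is built, so nothing is tracked
    val_outputs = toy_model(xb)
    val_loss = toy_criterion(val_outputs, yb).item()

print(f'loss {loss.item():.4f}, train acc {train_accuracy:.1f}%, val loss {val_loss:.4f}')
```

In `train_model`, the training block would run once per batch of `train_loader` and the evaluation block once per batch of `val_loader`. A common choice is to sum each batch's loss times its size and the correct counts across batches, then divide by the number of samples at the end of the epoch.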

In [None]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    epochs: int = 10,
    print_every: int = 1
) -> dict[str, list[float]]:
    '''Training loop for PyTorch classification model.
    
    TODO: Implement the training loop with:
    1. Training phase - iterate over train_loader batches
    2. Validation phase - evaluate on val_loader after each epoch
    3. Track and return history of train_loss, val_loss, train_accuracy, val_accuracy
    '''

    history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}

    for epoch in range(epochs):

        # TODO: Training phase
        # - Set model to training mode
        # - Loop over batches in train_loader
        # - For each batch: zero gradients, forward pass, compute loss, backward pass, update weights
        # - Track running loss and accuracy
        pass

        # TODO: Validation phase
        # - Set model to evaluation mode
        # - Use torch.no_grad() context
        # - Loop over batches in val_loader
        # - Compute loss and accuracy (no gradient computation needed)
        pass

        # TODO: Record metrics in history dict

        # TODO: Print progress

    print('\nTraining complete.')

    return history

### Task 4: Train the model

In [None]:
# TODO: Call your training function
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
    print_every=print_every
)

### Task 5: Plot learning curves

**Hints:**
- Create a 1x2 grid of subplots: loss on the left, accuracy on the right
- Plot both training and validation metrics on each subplot
- Add legends to distinguish the curves
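The following sketch implements a plotting helper under one assumption: `history` has the four keys that `train_model` initializes, with one value per epoch. The function name `plot_history` is hypothetical.

```python
import matplotlib.pyplot as plt


def plot_history(history: dict[str, list[float]]):
    '''Plot loss (left) and accuracy (right) from a train_model-style history dict.'''
    epoch_numbers = range(1, len(history['train_loss']) + 1)

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: training vs. validation loss
    ax_loss.plot(epoch_numbers, history['train_loss'], label='Training')
    ax_loss.plot(epoch_numbers, history['val_loss'], label='Validation')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Loss')
    ax_loss.legend()

    # Right panel: training vs. validation accuracy
    ax_acc.plot(epoch_numbers, history['train_accuracy'], label='Training')
    ax_acc.plot(epoch_numbers, history['val_accuracy'], label='Validation')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.set_title('Accuracy')
    ax_acc.legend()

    fig.tight_layout()
    return fig
```

After training, call it as `plot_history(history)` followed by `plt.show()`.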

In [None]:
# TODO: Plot learning curves
# - Left plot: training and validation loss over epochs
# - Right plot: training and validation accuracy over epochs

## 4. Evaluate model on test set

### Task 6: Write evaluation function

Write a function that evaluates the model on the test set and returns its accuracy, the predictions, and the true labels.

**Hints:**
- Set model to evaluation mode with `model.eval()`
- Use `torch.no_grad()` context
- Iterate over test_loader and accumulate predictions
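- Return the accuracy (as a percentage), the predictions, and the true labels as NumPy arrays; tensors on a GPU must be moved with `.cpu()` before calling `.numpy()`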
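The following sketch shows the accumulation pattern under illustrative assumptions. The "model" is `nn.Identity()` applied to hypothetical precomputed logits, so the predictions are known in advance and the accuracy can be checked by hand (4 of 5 correct). Per-batch results are collected in lists and joined with `torch.cat`:

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical logits for 5 samples and 3 classes; the argmax of each row is [0, 1, 2, 0, 1]
logits = torch.tensor([
    [2.0, 0.1, 0.3],
    [0.2, 1.5, 0.1],
    [0.1, 0.2, 3.0],
    [1.0, 0.5, 0.2],
    [0.3, 0.4, 0.1],
])
labels = torch.tensor([0, 1, 2, 1, 1])  # the fourth sample is misclassified

loader = DataLoader(TensorDataset(logits, labels), batch_size=2, shuffle=False)
stand_in_model = nn.Identity()  # returns its input unchanged, standing in for a trained model

stand_in_model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for xb, yb in loader:
        outputs = stand_in_model(xb)
        _, preds = torch.max(outputs, 1)
        all_preds.append(preds.cpu())   # move to CPU before any NumPy conversion
        all_labels.append(yb.cpu())

predictions = torch.cat(all_preds).numpy()
true_labels = torch.cat(all_labels).numpy()
accuracy = 100 * np.mean(predictions == true_labels)

print(predictions, true_labels, f'{accuracy:.1f}%')
```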

In [None]:
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader
) -> tuple[float, np.ndarray, np.ndarray]:
    '''Evaluate model on test set.
    
    TODO: Implement evaluation logic
    Returns: (accuracy, predictions, true_labels)
    '''

    # TODO: Implement evaluation
    pass


# TODO: Call your evaluation function and print test accuracy
# test_accuracy, predictions, true_labels = evaluate_model(model, test_loader)
# print(f'Test accuracy: {test_accuracy:.2f}%')