# ML Lab 07: Train a Neural Network (Solution)

This is the completed solution notebook with all cells filled in and expected outputs documented.
Use this as a reference if you get stuck on the main notebook.

---
## Section 1: Why scikit-learn Isn't Enough

We load CIFAR-10, flatten each image into a 3,072-value vector, and show that logistic regression trained on a 5,000-image subset reaches only ~40% accuracy.

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# Fix both random number generators up front so repeated runs start from the same state.
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cpu')
print(f"Using device: {device}")

# Convert to tensors, then shift and scale every channel by 0.5
# (the plotting cells undo this with `* 0.5 + 0.5`).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split first, then the test split; both are downloaded into ./data if missing.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Human-readable class names, indexed by the integer label.
CLASSES = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

print(
    f"Training images: {len(train_dataset)}",
    f"Test images:     {len(test_dataset)}",
    f"Image shape:     {train_dataset[0][0].shape}",
    f"Classes:         {CLASSES}",
    sep="\n",
)

In [None]:
# Preview the first 16 training images in a 2 x 8 grid, each titled with its class name.
fig, axes = plt.subplots(nrows=2, ncols=8, figsize=(16, 4))
for i, ax in enumerate(axes.flat):
    image, label = train_dataset[i]
    # Undo the 0.5 / 0.5 normalisation before display.
    image = 0.5 + 0.5 * image
    # Move the channel axis last for imshow.
    channels_last = image.permute(1, 2, 0)
    ax.imshow(channels_last.numpy())
    ax.set_title(CLASSES[label], fontsize=9)
    ax.axis('off')
plt.suptitle('CIFAR-10 Sample Images', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Baseline datasets use ToTensor only (no Normalize) and reuse the files already
# downloaded into ./data, hence download=False.
raw_transform = transforms.Compose([transforms.ToTensor()])
raw_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=raw_transform)
raw_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=raw_transform)

# Subset sizes for the baseline: the first n_train / n_test samples of each split.
n_train = 5000
n_test = 1000


def _flatten_subset(dataset, n_samples):
    """Return the first `n_samples` items of `dataset` as flat NumPy arrays.

    Each item is fetched exactly once, so the image transform runs once per
    sample instead of once for the pixels and again for the label.

    Args:
        dataset: Indexable dataset whose items are (image_tensor, label) pairs.
        n_samples: Number of leading samples to take; clamped to len(dataset).

    Returns:
        (pixels, labels): `pixels` has one row per sample with every image
        value flattened into that row; `labels` is a 1-D array of the labels.
    """
    n_samples = min(n_samples, len(dataset))
    samples = [dataset[i] for i in range(n_samples)]
    pixels = torch.stack([img for img, _ in samples]).reshape(n_samples, -1).numpy()
    labels = np.array([lbl for _, lbl in samples])
    return pixels, labels


X_train_flat, y_train_flat = _flatten_subset(raw_train, n_train)
X_test_flat, y_test_flat = _flatten_subset(raw_test, n_test)

print(f"Flattened shape: {X_train_flat.shape}")
print("Training logistic regression on raw pixels...")

clf = LogisticRegression(max_iter=1000, random_state=42, solver='saga')
clf.fit(X_train_flat, y_train_flat)

# Accuracy on the held-out subset; reused later as the reference line in the plots.
sklearn_acc = accuracy_score(y_test_flat, clf.predict(X_test_flat))
print(f"\nscikit-learn accuracy on CIFAR-10: {sklearn_acc:.3f}")
print(f"Conclusion: {sklearn_acc*100:.0f}% is better than random, but terrible.")
print("scikit-learn can't learn spatial patterns from raw pixels.")

---
## Section 2: Your First Neural Network

In [None]:
import torch.nn as nn


class FeedforwardNet(nn.Module):
    """Fully connected classifier: flatten, then Linear + ReLU per hidden width, then Linear.

    Args:
        input_dim: Number of values in one flattened sample. Defaults to
            3 * 32 * 32.
        hidden_dims: Output width of each hidden Linear layer, in order.
            Defaults to (512, 256).
        num_classes: Number of output scores per sample. Defaults to 10.

    With the default arguments the module tree, the parameter names
    (`layers.0`, `layers.2`, `layers.4`) and the order in which layers are
    constructed match the original hard-coded 3072 -> 512 -> 256 -> 10 stack.
    """

    def __init__(self, input_dim=3 * 32 * 32, hidden_dims=(512, 256), num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()

        stack = []
        width_in = input_dim
        for width_out in hidden_dims:
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.ReLU())
            width_in = width_out
        # Final projection has no activation: the output is raw class scores.
        stack.append(nn.Linear(width_in, num_classes))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten `x` and return the unnormalised class scores."""
        return self.layers(self.flatten(x))


# Build the network on the selected device, show its layers, and count its parameters.
model = FeedforwardNet().to(device)
print(model)
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"\nTotal parameters: {total_params:,}")

---
## Section 3: The Training Loop

In [None]:
from torch.utils.data import DataLoader
import time

# Mini-batches of 64; only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Fresh model, loss and plain SGD optimiser for this run.
model = FeedforwardNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

NUM_EPOCHS = 10

# Per-epoch metrics, plus the loss of every individual batch for the detailed plot.
history = {key: [] for key in ('train_loss', 'train_acc', 'test_acc', 'batch_losses')}

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    # Training pass: one optimiser step per batch.
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        history['batch_losses'].append(batch_loss)

        predicted = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Epoch loss is the mean of the per-batch losses.
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)

    # Evaluation pass: no gradients, no parameter updates.
    model.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).max(dim=1).indices
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_acc = test_correct / test_total
    history['test_acc'].append(test_acc)

    # Elapsed time is cumulative since training started, not per epoch.
    elapsed = time.time() - start_time
    print(" | ".join([
        f"Epoch {epoch+1:2d}/{NUM_EPOCHS}",
        f"Loss: {train_loss:.3f}",
        f"Train Acc: {train_acc*100:.1f}%",
        f"Test Acc: {test_acc*100:.1f}%",
        f"Time: {elapsed:.0f}s",
    ]))

print(f"\nFinal test accuracy: {history['test_acc'][-1]*100:.1f}%")

---
## Section 4: Watch It Learn

In [None]:
# Three panels: loss per batch, loss per epoch, and train/test accuracy per epoch.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# The raw trace is labelled so legend() always has an entry, even when the run is
# too short (<= `window` batches) for the smoothed curve below to be drawn.
axes[0].plot(history['batch_losses'], alpha=0.3, color='blue', linewidth=0.5, label='Per batch')
window = 50
if len(history['batch_losses']) > window:
    # Moving average over `window` batches. 'valid' mode yields len - window + 1
    # points, so the curve is drawn starting at batch index window - 1.
    smoothed = np.convolve(history['batch_losses'], np.ones(window)/window, mode='valid')
    axes[0].plot(range(window-1, len(history['batch_losses'])), smoothed, color='red', linewidth=2, label='Smoothed')
axes[0].set_xlabel('Batch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss (per batch)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Take the x-axis from what was actually recorded, so the plot still works if
# NUM_EPOCHS was edited after the training cell last ran.
epochs = range(1, len(history['train_loss']) + 1)
axes[1].plot(epochs, history['train_loss'], 'o-', linewidth=2, markersize=6, label='Train Loss')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Training Loss (per epoch)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

axes[2].plot(epochs, [a*100 for a in history['train_acc']], 'o-', linewidth=2, markersize=6, label='Train Acc')
axes[2].plot(epochs, [a*100 for a in history['test_acc']], 's-', linewidth=2, markersize=6, label='Test Acc')
# Dashed reference line: the logistic-regression baseline accuracy.
axes[2].axhline(y=sklearn_acc*100, color='gray', linestyle='--', alpha=0.5, label=f'sklearn ({sklearn_acc*100:.0f}%)')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Accuracy (%)')
axes[2].set_title('Train vs Test Accuracy')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Gap between final train and test accuracy, in percentage points.
gap = history['train_acc'][-1] - history['test_acc'][-1]
print(f"Train-Test accuracy gap: {gap*100:.1f}%")

In [None]:
# Visualise the model's predictions on the first test batch (up to 24 images).
model.eval()
fig, axes = plt.subplots(3, 8, figsize=(18, 7))

images, labels = next(iter(test_loader))
images = images.to(device)
labels = labels.to(device)

with torch.no_grad():
    outputs = model(images)
    predicted = outputs.max(dim=1).indices

# zip stops at the shorter of (grid cells, batch), so a batch smaller than the
# grid simply leaves the remaining cells untouched.
for ax, image_t, true_idx, pred_idx in zip(axes.flat, images, labels.tolist(), predicted.tolist()):
    # Undo the 0.5 / 0.5 normalisation and move the channel axis last for imshow.
    img = image_t.cpu() * 0.5 + 0.5
    ax.imshow(img.permute(1, 2, 0).numpy())
    correct = true_idx == pred_idx
    ax.set_title(
        f"P:{CLASSES[pred_idx]}\nT:{CLASSES[true_idx]}",
        fontsize=8,
        color='green' if correct else 'red',
    )
    ax.axis('off')

plt.suptitle('Predictions (green=correct, red=wrong)', fontsize=14)
plt.tight_layout()
plt.show()

---
## Section 5: Improve with a CNN

In [None]:
class SimpleCNN(nn.Module):
    """Two conv + ReLU + max-pool stages followed by a two-layer classifier.

    Args:
        in_channels: Number of channels in the input images. Defaults to 3.
        num_classes: Number of output scores per sample. Defaults to 10.
        image_size: Height and width of the (square) input images. Defaults
            to 32. Each of the two 2x2 max-pools halves the side length, so
            the classifier input width is derived from this value.

    Raises:
        ValueError: If `image_size` is too small to survive both pooling steps.

    With the default arguments the module tree and parameter names are the
    same as the original hard-coded 3-channel, 32x32, 10-class network.
    """

    def __init__(self, in_channels=3, num_classes=10, image_size=32):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size; only the pools shrink it.
        pooled_side = image_size // 2 // 2
        if pooled_side < 1:
            raise ValueError(f"image_size={image_size} is too small for two 2x2 max-pools")

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * pooled_side * pooled_side, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Run the convolutional feature extractor, then the classifier head."""
        return self.classifier(self.features(x))


# Build the CNN on the selected device, show its layers, and count its parameters.
cnn_model = SimpleCNN().to(device)
print(cnn_model)
cnn_params = 0
for param in cnn_model.parameters():
    cnn_params += param.numel()
print(f"\nCNN parameters: {cnn_params:,}")

In [None]:
# Fresh CNN trained with the same loss, optimiser settings and epoch count as the
# feedforward run, so the two histories are directly comparable.
cnn_model = SimpleCNN().to(device)
cnn_criterion = nn.CrossEntropyLoss()
cnn_optimizer = torch.optim.SGD(cnn_model.parameters(), lr=0.01)

cnn_history = {key: [] for key in ('train_loss', 'train_acc', 'test_acc')}

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    # Training pass: one optimiser step per batch.
    cnn_model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        cnn_optimizer.zero_grad()
        outputs = cnn_model(images)
        loss = cnn_criterion(outputs, labels)
        loss.backward()
        cnn_optimizer.step()

        running_loss += loss.item()
        predicted = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Epoch loss is the mean of the per-batch losses.
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total
    cnn_history['train_loss'].append(train_loss)
    cnn_history['train_acc'].append(train_acc)

    # Evaluation pass: no gradients, no parameter updates.
    cnn_model.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = cnn_model(images).max(dim=1).indices
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_acc = test_correct / test_total
    cnn_history['test_acc'].append(test_acc)

    # Elapsed time is cumulative since training started, not per epoch.
    elapsed = time.time() - start_time
    print(" | ".join([
        f"Epoch {epoch+1:2d}/{NUM_EPOCHS}",
        f"Loss: {train_loss:.3f}",
        f"Train Acc: {train_acc*100:.1f}%",
        f"Test Acc: {test_acc*100:.1f}%",
        f"Time: {elapsed:.0f}s",
    ]))

print("\nCNN training complete.")

In [None]:
# Side-by-side comparison of the feedforward net and the CNN:
# test accuracy on the left, training loss on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

epochs = range(1, NUM_EPOCHS + 1)

# (history dict, line/marker style, legend label) for each network, in plotting order.
model_runs = [
    (history, 'o-', 'Feedforward NN'),
    (cnn_history, 's-', 'CNN'),
]

for run, line_style, series_label in model_runs:
    axes[0].plot(epochs, [a*100 for a in run['test_acc']], line_style, linewidth=2, label=series_label)
# Dashed reference line: the logistic-regression baseline accuracy.
axes[0].axhline(y=sklearn_acc*100, color='gray', linestyle='--', alpha=0.7, label=f'sklearn LogReg ({sklearn_acc*100:.0f}%)')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Test Accuracy (%)')
axes[0].set_title('Test Accuracy: sklearn vs Feedforward vs CNN')

for run, line_style, series_label in model_runs:
    axes[1].plot(epochs, run['train_loss'], line_style, linewidth=2, label=series_label)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Training Loss')
axes[1].set_title('Training Loss: Feedforward vs CNN')

# Shared finishing touches for both panels.
for panel in axes:
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(
    "Final test accuracy comparison:",
    f"  scikit-learn LogReg:  {sklearn_acc*100:.1f}%",
    f"  Feedforward NN:       {history['test_acc'][-1]*100:.1f}%",
    f"  CNN:                  {cnn_history['test_acc'][-1]*100:.1f}%",
    sep="\n",
)

---
## Section 6: Checkpointing

In [None]:
import os

checkpoint_path = 'cnn_checkpoint.pt'

# Everything needed to resume: epoch counter, model weights, optimiser state,
# the latest metrics, and the full metric history.
checkpoint = dict(
    epoch=NUM_EPOCHS,
    model_state_dict=cnn_model.state_dict(),
    optimizer_state_dict=cnn_optimizer.state_dict(),
    train_loss=cnn_history['train_loss'][-1],
    test_acc=cnn_history['test_acc'][-1],
    history=cnn_history,
)

torch.save(checkpoint, checkpoint_path)

# Report the on-disk size in mebibytes alongside the saved metrics.
size_mb = os.path.getsize(checkpoint_path) / 1024 ** 2
print(f"Checkpoint saved: {size_mb:.2f} MB")
print(f"Epoch: {checkpoint['epoch']}, Test acc: {checkpoint['test_acc']*100:.1f}%")

In [None]:
# map_location places every stored tensor on the active device, so the restore
# also works when the checkpoint was written on a different device.
# NOTE(review): weights_only=False unpickles arbitrary Python objects. That is
# acceptable for a file this notebook just wrote itself, but never load an
# untrusted checkpoint this way.
loaded_checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

# Rebuild empty model/optimiser shells, then overwrite their state from the file.
restored_model = SimpleCNN().to(device)
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.01)

restored_model.load_state_dict(loaded_checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
resume_epoch = loaded_checkpoint['epoch']

print(f"Checkpoint loaded! Resuming from epoch {resume_epoch}.")

# Sanity check: the restored model must predict the same classes as the
# in-memory model on the first test batch.
cnn_model.eval()
restored_model.eval()

test_images, test_labels = next(iter(test_loader))
test_images = test_images.to(device)

with torch.no_grad():
    original_preds = cnn_model(test_images).argmax(dim=1)
    restored_preds = restored_model(test_images).argmax(dim=1)

match = (original_preds == restored_preds).all().item()
print(f"Predictions match: {match}")

In [None]:
print(f"Resuming training from epoch {resume_epoch}...")

# Build the loss once; it is loop-invariant and was previously re-created for every batch.
resume_criterion = nn.CrossEntropyLoss()

# Train for two more epochs, numbering them after the checkpointed epoch.
for epoch in range(resume_epoch, resume_epoch + 2):
    # Re-enter train mode at the top of each epoch, because the evaluation
    # pass below switches the model to eval mode.
    restored_model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        restored_optimizer.zero_grad()
        outputs = restored_model(images)
        loss = resume_criterion(outputs, labels)
        loss.backward()
        restored_optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # Evaluation pass on the test set: no gradients, no parameter updates.
    restored_model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            outs = restored_model(imgs)
            _, preds = outs.max(1)
            test_total += lbls.size(0)
            test_correct += preds.eq(lbls).sum().item()

    print(f"Epoch {epoch+1} | "
          f"Loss: {running_loss/len(train_loader):.3f} | "
          f"Train Acc: {100.*correct/total:.1f}% | "
          f"Test Acc: {100.*test_correct/test_total:.1f}%")

# Leave the model in train mode, as the original cell did.
restored_model.train()

print("\nTraining resumed seamlessly from the checkpoint.")

---
## Section 7: Experiment Tracking

In [None]:
import json
import pandas as pd


def _evaluate_accuracy(model, loader, eval_device):
    """Return the fraction of samples in `loader` that `model` classifies correctly."""
    n_correct = 0
    n_seen = 0
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(eval_device), labels.to(eval_device)
            predicted = model(images).max(1).indices
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()
    if n_seen == 0:
        raise ValueError("cannot compute accuracy: the loader yielded no samples")
    return n_correct / n_seen


def run_experiment(name, model_class, lr, epochs, train_loader, test_loader, run_device=None):
    """Train a fresh model with SGD and return one row of experiment metrics.

    Args:
        name: Label stored under 'experiment' in the result.
        model_class: Zero-argument callable that builds the nn.Module to train.
        lr: Learning rate passed to torch.optim.SGD.
        epochs: Number of full passes over `train_loader`.
        train_loader: Iterable of (images, labels) batches used for training
            and for the train-accuracy measurement.
        test_loader: Iterable of (images, labels) batches used for the
            test-accuracy measurement.
        run_device: Device to run on. When None (the default) the
            notebook-level `device` is used, as before.

    Returns:
        dict with keys 'experiment', 'architecture', 'lr', 'epochs', 'params',
        'train_acc', 'test_acc' (both rounded to 4 decimals) and
        'train_time_s' (rounded to 0.1 s).

    Raises:
        ValueError: If either loader yields no samples during evaluation.
    """
    target = device if run_device is None else run_device

    model = model_class().to(target)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    start = time.time()
    for _ in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(target), labels.to(target)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    # Stop the clock before evaluating: 'train_time_s' previously also counted
    # the two evaluation passes below.
    train_time_s = round(time.time() - start, 1)

    train_acc = _evaluate_accuracy(model, train_loader, target)
    test_acc = _evaluate_accuracy(model, test_loader, target)

    return {
        'experiment': name,
        'architecture': model_class.__name__,
        'lr': lr,
        'epochs': epochs,
        'params': sum(p.numel() for p in model.parameters()),
        'train_acc': round(train_acc, 4),
        'test_acc': round(test_acc, 4),
        'train_time_s': train_time_s,
    }


# One result dict per configuration, in the order the configurations are run.
experiments = []

# (experiment name, model class, learning rate, epochs)
configs = [
    ('ff_lr0.01_ep5', FeedforwardNet, 0.01, 5),
    ('ff_lr0.05_ep5', FeedforwardNet, 0.05, 5),
    ('cnn_lr0.01_ep5', SimpleCNN, 0.01, 5),
    ('cnn_lr0.05_ep5', SimpleCNN, 0.05, 5),
]

# Loop variables carry an `exp_` prefix so they do not rebind notebook-level
# names; `epochs` in particular is the range used as the x-axis by the plotting cells.
for exp_name, exp_model_class, exp_lr, exp_epochs in configs:
    print(f"Running: {exp_name}...")
    result = run_experiment(exp_name, exp_model_class, exp_lr, exp_epochs, train_loader, test_loader)
    experiments.append(result)
    print(f"  -> test_acc={result['test_acc']*100:.1f}% in {result['train_time_s']}s")

print(f"\nAll {len(experiments)} experiments complete!")

In [None]:
# Persist the raw results as JSON, then print them ranked by test accuracy.
experiments_path = 'experiment_log.json'
with open(experiments_path, 'w') as f:
    json.dump(experiments, f, indent=2)

df = pd.DataFrame(experiments).sort_values('test_acc', ascending=False)

banner = '=' * 80
report_columns = [
    'experiment', 'architecture', 'lr', 'epochs', 'params',
    'train_acc', 'test_acc', 'train_time_s',
]
print(banner)
print("EXPERIMENT RESULTS (sorted by test accuracy)")
print(banner)
print(df[report_columns].to_string(index=False))

# After the descending sort, the first row is the best run.
best_run = df.iloc[0]
print(f"\nBest: {best_run['experiment']} with {best_run['test_acc']*100:.1f}% test accuracy")

In [None]:
# Left: test accuracy per experiment. Right: train vs test accuracy per experiment.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

exp_names = [e['experiment'] for e in experiments]
train_pct = [e['train_acc']*100 for e in experiments]
test_pct = [e['test_acc']*100 for e in experiments]
x_pos = range(len(experiments))

# Experiments whose name contains 'ff' are drawn in blue, all others in orange.
colors = ['#2196F3' if 'ff' in exp_name else '#FF9800' for exp_name in exp_names]
axes[0].barh(x_pos, test_pct, color=colors)
axes[0].set_yticks(x_pos)
axes[0].set_yticklabels(exp_names)
axes[0].set_xlabel('Test Accuracy (%)')
axes[0].set_title('Test Accuracy by Experiment')
# Dashed reference line: the logistic-regression baseline accuracy.
axes[0].axvline(x=sklearn_acc*100, color='gray', linestyle='--', alpha=0.5, label=f'sklearn ({sklearn_acc*100:.0f}%)')
axes[0].legend()

# Grouped bars: train on the left of each tick, test on the right.
width = 0.35
left_edges = [x - width/2 for x in x_pos]
right_edges = [x + width/2 for x in x_pos]
axes[1].bar(left_edges, train_pct, width, label='Train Acc', alpha=0.8)
axes[1].bar(right_edges, test_pct, width, label='Test Acc', alpha=0.8)
axes[1].set_xticks(x_pos)
axes[1].set_xticklabels(exp_names, rotation=45, ha='right')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Train vs Test Accuracy (gap = overfitting)')
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
# Remove the files this notebook wrote; anything already missing is skipped silently.
for path in ('cnn_checkpoint.pt', 'experiment_log.json'):
    if not os.path.exists(path):
        continue
    os.remove(path)
    print(f"Cleaned up {path}")
print("Done!")

---
## Summary

| Concept | What You Learned |
|---------|------------------|
| **sklearn limits** | Flat feature vectors lose spatial structure -- bad for images |
| **Feedforward NN** | Stacked linear layers with ReLU activations; learns intermediate features |
| **Training loop** | forward -> loss -> backward -> step; the core of all deep learning |
| **Overfitting** | Train accuracy rises, test accuracy plateaus; watch the gap |
| **CNN** | Convolutional layers learn spatial patterns (edges, textures, shapes) |
| **Checkpointing** | Save model + optimizer state to resume training after failures |
| **Experiment tracking** | Log every run so you know what worked and what didn't |