# GPU Training

**Module 2.3, Lesson 2** | CourseAI

In this notebook you will move your MNIST training loop to GPU, measure the speedup, add device-aware checkpointing, and use mixed precision to squeeze even more speed out of the hardware.

**What you'll do:**
- Move the MNIST model and training loop to GPU, time it, and compare to CPU
- Add device-aware checkpointing: save during GPU training, load on CPU with `map_location`
- Add mixed precision (`autocast` + `GradScaler`) to the GPU training loop and compare speed
- Write a complete, portable training script with device detection, GPU, mixed precision, and device-aware checkpoints

**For each exercise, PREDICT the output before running the cell.**

**Estimated time:** 20-30 minutes.

**Requirements:** This notebook needs a GPU. In Colab: Runtime -> Change runtime type -> T4 GPU.

---

## Setup

Run this cell to import everything and configure the environment.

In [None]:
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Reproducible results
torch.manual_seed(42)

# For nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Verify GPU is available
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('WARNING: No GPU detected. Go to Runtime -> Change runtime type -> T4 GPU.')

## Shared Setup: Model and Data

The same MNISTClassifier and data loading you have used in previous lessons. We define them once and reuse across all exercises.

In [None]:
# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

In [None]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

print(f'Parameters: {sum(p.numel() for p in MNISTClassifier().parameters()):,}')

In [None]:
def train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Train for one epoch. Returns average loss."""
    model.train()
    running_loss = 0.0
    n_batches = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    return running_loss / n_batches


def evaluate(model, test_loader, device):
    """Evaluate model on test set. Returns (loss, accuracy)."""
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    running_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            n_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return running_loss / n_batches, 100.0 * correct / total

print('Helper functions defined.')

---

## Exercise 1: Move to GPU and Time It (Supported)

The simplest GPU training experiment: take the training loop you already know and add three lines of device placement. Then time both CPU and GPU training to measure the speedup.

**Before running, predict:** How much faster will GPU training be for this small MNIST model? 2x? 5x? 10x? Will it always be faster?

**Task:**
1. Train on **CPU** for 3 epochs and record the wall-clock time
2. Train on **GPU** for 3 epochs and record the wall-clock time
3. Compare the two times

The CPU version is provided. Fill in the GPU version; you need to add device placement.

In [None]:
# ===== CPU Training (provided) =====
print('=== Training on CPU ===')

cpu_device = torch.device('cpu')
cpu_model = MNISTClassifier().to(cpu_device)
cpu_optimizer = optim.Adam(cpu_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

cpu_start = time.time()

for epoch in range(3):
    train_loss = train_one_epoch(cpu_model, train_loader, cpu_optimizer, criterion, cpu_device)
    val_loss, val_acc = evaluate(cpu_model, test_loader, cpu_device)
    print(f'Epoch {epoch+1}/3 | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.2f}%')

cpu_time = time.time() - cpu_start
print(f'\nCPU training time: {cpu_time:.1f}s')

In [None]:
# ===== GPU Training (fill in the blanks) =====
print('=== Training on GPU ===')

# Step 1: Create the device
gpu_device = ____  # FILL IN: torch.device('cuda')

# Step 2: Create model and move to GPU
gpu_model = ____  # FILL IN: MNISTClassifier().to(gpu_device)
gpu_optimizer = optim.Adam(gpu_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

gpu_start = time.time()

for epoch in range(3):
    # Note: train_one_epoch already moves batches to the device.
    # Check the helper function above to see how.
    train_loss = train_one_epoch(gpu_model, train_loader, gpu_optimizer, criterion, gpu_device)
    val_loss, val_acc = evaluate(gpu_model, test_loader, gpu_device)
    print(f'Epoch {epoch+1}/3 | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.2f}%')

# Sync GPU before stopping timer (GPU operations are async)
torch.cuda.synchronize()
gpu_time = time.time() - gpu_start
print(f'\nGPU training time: {gpu_time:.1f}s')

In [None]:
# Compare the results
speedup = cpu_time / gpu_time

print('=' * 50)
print(f'{"Device":<10} {"Time":>10} {"Speedup":>10}')
print('-' * 50)
print(f'{"CPU":<10} {cpu_time:>9.1f}s {"1.0x":>10}')
print(f'{"GPU":<10} {gpu_time:>9.1f}s {speedup:>9.1f}x')
print('=' * 50)

if speedup > 1:
    print(f'\nGPU is {speedup:.1f}x faster.')
else:
    print(f'\nGPU is slower! Transfer overhead dominates for this model/data size.')
    print('This is expected for small models. GPU shines with larger models.')

<details>
<summary>💡 Solution</summary>

**The key insight:** Moving to GPU requires exactly two things: place the model on the device once, and place each batch on the device in the loop. Our helper function already handles the batch placement, so you only need to set up the device and model.

```python
gpu_device = torch.device('cuda')
gpu_model = MNISTClassifier().to(gpu_device)
```

**Key points:**
- The model moves to GPU once with `.to(gpu_device)`.
- Each batch of data moves to GPU inside the training loop. Our `train_one_epoch` helper already does this with `images.to(device)` and `labels.to(device)`.
- `torch.cuda.synchronize()` ensures all GPU operations finish before we stop the timer. Without it, the timer can stop while the GPU is still computing.
- For this MNIST classifier (about 235K parameters), you may see roughly a 3-5x speedup on a T4 GPU; the exact figure depends on batch size and data-loading overhead. Larger models typically see larger speedups.
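
The synchronize-before-timing point above can be wrapped in a small helper. This is a minimal sketch, assuming a hypothetical helper named `timed` (it is not part of the notebook). It synchronizes only when CUDA is available, so it also runs on a CPU-only machine.

```python
import time

import torch


def timed(fn, *args, **kwargs):
    """Run fn and return (result, elapsed_seconds).

    `timed` is a hypothetical helper for illustration. It waits for queued
    GPU work both before starting and before stopping the clock.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()  # finish earlier work so it is not counted
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()  # wait for kernels launched by fn
    return result, time.perf_counter() - start


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
a = torch.randn(256, 256, device=device)
product, seconds = timed(torch.matmul, a, a)
print(f'matmul on {device}: {seconds * 1000:.2f} ms')
```

The first call on a GPU usually includes one-time startup cost, so timing a second call tends to give a more representative number.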

</details>

---

## Exercise 2: Device-Aware Checkpointing (Supported)

You learned checkpointing in the Saving, Loading, and Checkpoints lesson. Now your tensors are on GPU. The challenge: save a checkpoint during GPU training, then load it as if you were on a CPU-only machine.

**Before running, predict:** If you save a model trained on GPU, then try to load it on a CPU-only machine without `map_location`, what happens?

**Task:**
1. Train on GPU for 5 epochs, saving a checkpoint at the end
2. Load the checkpoint using `map_location` to force everything to CPU
3. Verify the loaded model produces the same predictions as the GPU model (on CPU)

**Hints:**
- Save a checkpoint dict with keys: `model_state_dict`, `optimizer_state_dict`, `epoch`, `loss`
- Use `map_location=torch.device('cpu')` when loading to simulate a CPU-only machine
- Use `torch.allclose()` to verify outputs match

In [None]:
os.makedirs('saved_models', exist_ok=True)

# Train on GPU for 5 epochs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_2 = MNISTClassifier().to(device)
optimizer_2 = optim.Adam(model_2.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

print(f'Training on {device} for 5 epochs...')
for epoch in range(5):
    train_loss = train_one_epoch(model_2, train_loader, optimizer_2, criterion, device)
    val_loss, val_acc = evaluate(model_2, test_loader, device)
    print(f'Epoch {epoch+1}/5 | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.2f}%')

# TODO: Save a checkpoint to 'saved_models/gpu_checkpoint.pth'
# The checkpoint dict should have keys:
#   'model_state_dict', 'optimizer_state_dict', 'epoch', 'loss'



print(f'\nCheckpoint saved. File size: {os.path.getsize("saved_models/gpu_checkpoint.pth") / 1024:.1f} KB')

In [None]:
# Load the checkpoint on CPU (simulating a machine with no GPU)
cpu_device = torch.device('cpu')

# TODO: Load the checkpoint with map_location=cpu_device
# checkpoint = torch.load(...)


# Create a new model on CPU and load the state dict
loaded_model = MNISTClassifier().to(cpu_device)

# TODO: Load the model state dict from the checkpoint
# loaded_model.load_state_dict(...)


print(f'Loaded from epoch {checkpoint["epoch"] + 1}, loss: {checkpoint["loss"]:.4f}')
print(f'Model is on: {next(loaded_model.parameters()).device}')

In [None]:
# Verify: run both models on the same test batch and compare
test_images, test_labels = next(iter(test_loader))

# GPU model prediction
model_2.eval()
with torch.no_grad():
    gpu_outputs = model_2(test_images.to(device)).cpu()

# CPU model prediction
loaded_model.eval()
with torch.no_grad():
    cpu_outputs = loaded_model(test_images.to(cpu_device))

# Compare
match = torch.allclose(gpu_outputs, cpu_outputs, atol=1e-5)
print(f'Predictions match: {match}')

if match:
    print('Device-aware checkpointing works!')
    print('You saved on GPU and loaded on CPU, so the checkpoint is portable.')
else:
    print('Predictions do NOT match. Check your map_location setting.')

<details>
<summary>💡 Solution</summary>

**The key insight:** `map_location` remaps every tensor to the specified device during loading. Without it, `torch.load()` tries to put tensors back on `cuda:0`, which fails on CPU-only machines.

**Saving the checkpoint:**
```python
checkpoint = {
    'model_state_dict': model_2.state_dict(),
    'optimizer_state_dict': optimizer_2.state_dict(),
    'epoch': epoch,  # 0-indexed loop variable: 4 after the fifth epoch
    'loss': val_loss,
}
torch.save(checkpoint, 'saved_models/gpu_checkpoint.pth')
```

**Loading with map_location:**
```python
checkpoint = torch.load(
    'saved_models/gpu_checkpoint.pth',
    map_location=cpu_device,
    weights_only=False
)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
```

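**Why `weights_only=False`:** It tells `torch.load()` to use the full pickle loader, which can execute arbitrary code, so use it only for files you created or trust. This checkpoint holds only tensors and plain Python values (an epoch number and a loss float), which the restricted `weights_only=True` loader also accepts, so `weights_only=True` is the safer choice here and for pure state_dicts.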
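
To see the save and load round trip in isolation, here is a minimal sketch. It assumes a tiny `nn.Linear` stand-in instead of `MNISTClassifier`, a temporary file path, and a placeholder loss value, all chosen for illustration so the example runs on any machine.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Tiny stand-in model; the notebook's MNISTClassifier would work the same way.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(4, 2).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 4,
    'loss': 0.123,  # placeholder value for illustration
}
path = os.path.join(tempfile.mkdtemp(), 'demo_checkpoint.pth')
torch.save(checkpoint, path)

# map_location remaps every stored tensor to CPU while loading.
# weights_only=False matches the solution above; the file is our own.
loaded = torch.load(path, map_location=torch.device('cpu'), weights_only=False)

restored = nn.Linear(4, 2)
restored.load_state_dict(loaded['model_state_dict'])

same_weights = torch.equal(restored.weight, model.weight.detach().cpu())
print(f"epoch={loaded['epoch']} device={restored.weight.device} same={same_weights}")
```

Every tensor in the loaded dict lives on CPU regardless of where it was saved, which is what makes the file usable on a machine without a GPU.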

</details>

---

## Exercise 3: Add Mixed Precision (Supported)

Mixed precision runs most forward-pass operations (matrix multiplications in particular) in float16, which is faster and uses less memory, while keeping the model weights and precision-sensitive operations in float32. PyTorch handles this automatically with two tools:

- **`torch.amp.autocast(device_type='cuda')`**: wraps the forward pass, choosing float16 where safe
- **`torch.amp.GradScaler()`**: scales the loss up before backward (to prevent float16 underflow), then scales gradients back down before the optimizer step
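
The underflow that loss scaling guards against can be reproduced without a GPU. The sketch below is an illustration only: it uses Python's standard `struct` module to round numbers to float16, with a made-up gradient value of `1e-8` and a scale of 65536 (2**16). The helper name `to_float16` is hypothetical.

```python
import struct


def to_float16(value: float) -> float:
    """Round a Python float to the nearest float16 and convert it back."""
    return struct.unpack('<e', struct.pack('<e', value))[0]


tiny_grad = 1e-8   # illustrative gradient, below float16's smallest value
scale = 65536.0    # 2**16, an illustrative scale factor

unscaled = to_float16(tiny_grad)          # stored directly: underflows
scaled = to_float16(tiny_grad * scale)    # scaled first: representable
recovered = scaled / scale                # unscaled again in higher precision

print(f'stored directly:     {unscaled}')
print(f'scaled then divided: {recovered:.3e}')
```

Stored directly, the value becomes exactly zero and the update is lost. Scaled first, it survives the trip through float16 and comes back close to the original once divided by the same factor.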

**Before running, predict:** Will mixed precision noticeably speed up this small MNIST model? Will accuracy change?

**Task:**
1. Train on GPU for 3 epochs **without** mixed precision and time it
2. Train on GPU for 3 epochs **with** mixed precision and time it
3. Compare speed and final accuracy

The non-mixed-precision version is provided. Add mixed precision to the second version.

**Hints:**
- Create `scaler = torch.amp.GradScaler()` before the loop
- Wrap forward + loss in `with torch.amp.autocast(device_type='cuda'):`
- Replace `loss.backward()` with `scaler.scale(loss).backward()`
- Replace `optimizer.step()` with `scaler.step(optimizer)`
- Add `scaler.update()` after the step

In [None]:
# ===== Baseline: GPU without mixed precision (provided) =====
print('=== GPU Training (float32) ===')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_fp32 = MNISTClassifier().to(device)
optimizer_fp32 = optim.Adam(model_fp32.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

torch.cuda.synchronize()
fp32_start = time.time()

for epoch in range(3):
    train_loss = train_one_epoch(model_fp32, train_loader, optimizer_fp32, criterion, device)
    val_loss, val_acc = evaluate(model_fp32, test_loader, device)
    print(f'Epoch {epoch+1}/3 | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.2f}%')

torch.cuda.synchronize()
fp32_time = time.time() - fp32_start
fp32_acc = val_acc
print(f'\nfloat32 time: {fp32_time:.1f}s | Final acc: {fp32_acc:.2f}%')

In [None]:
# ===== Mixed Precision: GPU with autocast + GradScaler =====
print('=== GPU Training (mixed precision) ===')

model_amp = MNISTClassifier().to(device)
optimizer_amp = optim.Adam(model_amp.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# TODO: Create a GradScaler
# scaler = ...


torch.cuda.synchronize()
amp_start = time.time()

for epoch in range(3):
    model_amp.train()
    running_loss = 0.0
    n_batches = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # TODO: Wrap the forward pass and loss computation in autocast
        # with torch.amp.autocast(device_type='cuda'):
        outputs = model_amp(images)
        loss = criterion(outputs, labels)

        # TODO: Replace the standard backward/step with scaled versions
        optimizer_amp.zero_grad()
        loss.backward()       # Replace with: scaler.scale(loss).backward()
        optimizer_amp.step()  # Replace with: scaler.step(optimizer_amp)
        # Add: scaler.update()

        running_loss += loss.item()
        n_batches += 1

    avg_loss = running_loss / n_batches
    val_loss, val_acc = evaluate(model_amp, test_loader, device)
    print(f'Epoch {epoch+1}/3 | Train Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}%')

torch.cuda.synchronize()
amp_time = time.time() - amp_start
amp_acc = val_acc
print(f'\nMixed precision time: {amp_time:.1f}s | Final acc: {amp_acc:.2f}%')

In [None]:
# Compare results
amp_speedup = fp32_time / amp_time

print('=' * 55)
print(f'{"Method":<20} {"Time":>10} {"Accuracy":>10} {"Speedup":>10}')
print('-' * 55)
print(f'{"GPU (float32)":<20} {fp32_time:>9.1f}s {fp32_acc:>9.2f}% {"1.0x":>10}')
print(f'{"GPU (mixed prec.)":<20} {amp_time:>9.1f}s {amp_acc:>9.2f}% {amp_speedup:>9.1f}x')
print('=' * 55)

acc_diff = abs(amp_acc - fp32_acc)
print(f'\nAccuracy difference: {acc_diff:.2f} percentage points')

if acc_diff < 0.5:
    print('Accuracy is essentially the same, so mixed precision is a free speedup.')
else:
    print('Accuracy differs. Check your implementation.')

if amp_speedup > 1:
    print(f'Mixed precision is {amp_speedup:.1f}x faster.')
else:
    print('Mixed precision is not faster for this model size. The speedup is more noticeable with larger models.')

<details>
<summary>💡 Solution</summary>

**The key insight:** Mixed precision changes only four lines in your training loop. `autocast` picks float16 where safe (forward pass), and `GradScaler` prevents the tiny float16 gradients from underflowing to zero.

```python
# Create GradScaler
scaler = torch.amp.GradScaler()

# Inside the training loop:
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)

    # Wrap forward + loss in autocast
    with torch.amp.autocast(device_type='cuda'):
        outputs = model_amp(images)
        loss = criterion(outputs, labels)

    # Scaled backward + step
    optimizer_amp.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer_amp)
    scaler.update()
```

**Four lines changed:**
1. `scaler = torch.amp.GradScaler()`: creates the scaler
2. `with torch.amp.autocast(device_type='cuda'):`: wraps the forward pass in auto-casting
3. `scaler.scale(loss).backward()`: scales the loss up to prevent float16 underflow
4. `scaler.step(optimizer_amp)` + `scaler.update()`: unscales gradients, steps the optimizer, adjusts the scale

**On MNIST:** The speedup may be modest (1.0-1.3x) because the model is small. On larger models (millions of parameters), mixed precision can give a 1.5-3x speedup. It also reduces GPU memory use, by up to roughly half for activations; the weights themselves stay in float32.
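
To check what `autocast` actually changes, you can inspect dtypes directly. A minimal sketch, with one assumption made so it runs without a GPU: on CPU it uses bfloat16 rather than float16. The layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
# Assumption for this sketch: bfloat16 on CPU so the example runs anywhere.
low_dtype = torch.float16 if device_type == 'cuda' else torch.bfloat16

layer = nn.Linear(256, 128).to(device_type)
x = torch.randn(64, 256, device=device_type)

with torch.autocast(device_type=device_type, dtype=low_dtype):
    y = layer(x)

# Parameters keep 4 bytes per element; the layer output uses 2.
print(layer.weight.dtype, layer.weight.element_size())
print(y.dtype, y.element_size())
```

The weights are untouched, while the activation produced inside the `autocast` block takes half the bytes per element. That is where the memory saving comes from.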

</details>

---

## Exercise 4: Complete Portable Training Script (Independent)

Put everything together: write a complete training script from scratch that:

1. **Detects the device**: GPU if available, CPU if not
2. **Trains on GPU** with the device-aware training loop
3. **Uses mixed precision** if on GPU (skip if on CPU)
4. **Checkpoints** with device portability (`map_location`)
5. **Tracks the best model** by validation accuracy
6. **Times the training** and reports results

This is the production-ready pattern you carry forward to every future project. Write it from memory; the building blocks are all in the exercises above.

**Specifications:**
- Train for 10 epochs
- Save best model checkpoint to `'saved_models/portable_best.pth'`
- Print timing, device, and final accuracy when done
- The script should work correctly on both GPU and CPU machines

In [None]:
# YOUR CODE HERE
# Write the complete portable training script.
#
# Structure:
#   1. Device detection
#   2. Data loading (reuse train_loader and test_loader)
#   3. Model + optimizer + criterion
#   4. Mixed precision setup (only if GPU)
#   5. Training loop with:
#      - Device-aware batch placement
#      - Mixed precision forward/backward (if GPU)
#      - Validation after each epoch
#      - Best model checkpointing
#   6. Report: device, time, best accuracy




In [None]:
# Verify: load the best checkpoint on CPU and evaluate
cpu_device = torch.device('cpu')
verify_model = MNISTClassifier().to(cpu_device)

checkpoint = torch.load(
    'saved_models/portable_best.pth',
    map_location=cpu_device,
    weights_only=False
)
verify_model.load_state_dict(checkpoint['model_state_dict'])

_, verify_acc = evaluate(verify_model, test_loader, cpu_device)
print(f'Loaded checkpoint from epoch {checkpoint["epoch"] + 1}')
print(f'Accuracy on CPU: {verify_acc:.2f}%')
print('\nPortable checkpoint verified: saved on GPU, loaded on CPU.')

<details>
<summary>💡 Solution</summary>

**The key insight:** This is the portable pattern: detect the device once, conditionally enable mixed precision, and use `map_location` for checkpoint portability. One script runs everywhere.

```python
import time
import os

os.makedirs('saved_models', exist_ok=True)

# 1. Device detection
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'
print(f'Training on: {device} | Mixed precision: {use_amp}')

# 2. Model, optimizer, criterion
model = MNISTClassifier().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 3. Mixed precision setup (only if GPU)
scaler = torch.amp.GradScaler() if use_amp else None

# 4. Training loop
num_epochs = 10
best_acc = 0.0

start_time = time.time()

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        if use_amp:
            with torch.amp.autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    avg_loss = running_loss / n_batches
    val_loss, val_acc = evaluate(model, test_loader, device)

    # Checkpoint best model (strict improvement only)
    is_best = val_acc > best_acc
    if is_best:
        best_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'accuracy': val_acc,
        }, 'saved_models/portable_best.pth')

    # Mark only the epochs that were actually saved as the new best
    print(f'Epoch {epoch+1:2d}/{num_epochs} | '
          f'Loss: {avg_loss:.4f} | Acc: {val_acc:.2f}%'
          f'{" <- best" if is_best else ""}')

if device.type == 'cuda':
    torch.cuda.synchronize()
elapsed = time.time() - start_time

print(f'\nTraining complete on {device}')
print(f'Time: {elapsed:.1f}s')
print(f'Best accuracy: {best_acc:.2f}%')
```

**The portable pattern:**
1. Detect device once at the top
2. Set `use_amp = device.type == 'cuda'` so mixed precision runs only on GPU
3. Create GradScaler conditionally
4. Branch the forward/backward in the inner loop based on `use_amp`
5. Always use `map_location=device` when loading checkpoints
6. Call `torch.cuda.synchronize()` before reading the timer (only if on GPU)

This exact pattern works on any machine, GPU or CPU, with no code changes.

</details>

---

## Key Takeaways

1. **GPU training = same loop + 3 lines of device placement.** Move the model once, move each batch inside the loop.
2. **`torch.cuda.synchronize()`** before reading the timer, because GPU operations are asynchronous.
3. **`map_location=device`** when loading checkpoints makes them portable across machines.
4. **Mixed precision** (`autocast` + `GradScaler`) uses float16 for speed and float32 for precision. Four lines change.
5. **GPU wins at scale.** Small models may see little or no benefit. Larger models can see 3-10x speedups.
6. **The portable pattern** detects the device and conditionally uses mixed precision â€” one script runs everywhere.

**Clean up** (optional): delete the `saved_models/` directory when you are done experimenting.

In [None]:
# Optional: clean up saved files
# import shutil
# shutil.rmtree('saved_models', ignore_errors=True)
# print('Cleaned up saved_models/')