# Saving and Loading Models

In this notebook, you'll practice saving and loading PyTorch models: the essential skill that lets you keep trained models, resume interrupted training, and implement early stopping.

**What you'll do:**
- Save and load a trained MNIST model, verifying predictions match
- Add checkpointing to a training loop
- Simulate a training crash and resume from a checkpoint
- Implement the full early stopping pattern

**For each exercise, PREDICT the output before running the cell.**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os

# Reproducible results
torch.manual_seed(42)

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# For nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

## Setup: Train an MNIST Model

First we need a trained model to work with. This is the same MNISTClassifier architecture from your PyTorch training lesson: Flatten → Linear(784, 256) → ReLU → Linear(256, 128) → ReLU → Linear(128, 10). We'll train it for 3 epochs so its weights are no longer random but training still finishes quickly.

In [None]:
# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

In [None]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

print(f'Parameters: {sum(p.numel() for p in MNISTClassifier().parameters()):,}')

In [None]:
def train_one_epoch(model, train_loader, optimizer, criterion):
    """Train for one epoch. Returns average loss."""
    model.train()
    running_loss = 0.0
    n_batches = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    return running_loss / n_batches


def evaluate(model, test_loader):
    """Evaluate model on test set. Returns (loss, accuracy)."""
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    running_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            n_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return running_loss / n_batches, 100.0 * correct / total

In [None]:
# Train the model for 3 epochs
model = MNISTClassifier().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, test_loader)
    print(f'Epoch {epoch+1}/3 | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

print('\nModel trained. Ready for exercises.')

---

## Exercise 1: Save and Load a Trained Model (Supported)

The simplest form of saving: capture a model's learned weights (`state_dict`) to a file, then load them into a new model instance.
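To see what that dictionary actually holds, here is a minimal standalone sketch that rebuilds the same layer layout and lists the `state_dict()` entries. The class is repeated only so the snippet runs on its own; it is not a new model.

```python
import torch.nn as nn

class MNISTClassifier(nn.Module):
    # Same layer layout as the notebook's model, repeated so this cell is self-contained
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

sd = MNISTClassifier().state_dict()
# Each entry maps a "<submodule>.<parameter>" name to a tensor
for name, tensor in sd.items():
    print(f'{name:12s} {tuple(tensor.shape)}')
```

`Flatten` and `ReLU` have no learnable parameters, so they contribute no entries; only the three `Linear` layers' weights and biases appear.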

**Before running, predict:** After saving and loading the state dict into a brand-new model, will `torch.allclose()` return True or False? Will the file size be larger or smaller than 1 MB?

**Task:**
1. Save `model.state_dict()` to a file
2. Create a brand-new `MNISTClassifier` instance (with random weights)
3. Load the saved state dict into it
4. Run both models on the same test batch
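5. Use `torch.allclose()` to verify the outputs match (it compares within a small floating-point tolerance; with identical weights the outputs should be identical)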

Fill in the blanks below.

In [None]:
# Create a directory for saved models
os.makedirs('saved_models', exist_ok=True)

# Step 1: Save the trained model's state dict
save_path = 'saved_models/mnist_classifier.pth'
____  # FILL IN: save model.state_dict() to save_path using torch.save()

# Print file size
file_size = os.path.getsize(save_path)
print(f'Saved model to {save_path}')
print(f'File size: {file_size / 1024:.1f} KB')

In [None]:
# Step 2: Create a new model with random weights
loaded_model = MNISTClassifier().to(device)

# Step 3: Load the saved state dict into the new model
____  # FILL IN: load the state dict from save_path and load it into loaded_model
# Hint: torch.load() returns the state dict, then use loaded_model.load_state_dict()

print('State dict loaded successfully.')

In [None]:
# Step 4: Get a test batch and run both models
test_images, test_labels = next(iter(test_loader))
test_images = test_images.to(device)

model.eval()
loaded_model.eval()

with torch.no_grad():
    original_outputs = model(test_images)
    loaded_outputs = loaded_model(test_images)

# Step 5: Verify they match
match = torch.allclose(original_outputs, loaded_outputs)

if match:
    print('Predictions match!')
else:
    print('Predictions DO NOT match. Something went wrong.')

# Show a few predictions side by side
print(f'\nFirst 5 predictions (original): {original_outputs.argmax(dim=1)[:5].tolist()}')
print(f'First 5 predictions (loaded):   {loaded_outputs.argmax(dim=1)[:5].tolist()}')

<details>
<summary>💡 Solution</summary>

**Why this works:** `state_dict()` captures every learned parameter as a plain dictionary. `torch.save()` serializes it, and `torch.load()` + `load_state_dict()` restore the exact same values, so the outputs match.

```python
# Step 1: Save
torch.save(model.state_dict(), save_path)

# Step 3: Load
state_dict = torch.load(save_path, map_location=device, weights_only=True)
loaded_model.load_state_dict(state_dict)
```

**Key points:**
- `model.state_dict()` returns a dictionary mapping each parameter (and buffer) name, such as `'fc1.weight'`, to its tensor.
- `torch.save()` serializes any picklable Python object (dicts, tensors, etc.) to a file.
- `torch.load()` deserializes it back. `map_location=device` ensures the tensors are placed on the right device.
- `weights_only=True` is a security best practice: it restricts unpickling to tensors and basic Python types, which prevents loading arbitrary Python objects.
- `.pth` is the conventional file extension for PyTorch saved files.
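If you want to confirm the round trip without touching disk, a small sketch like the following saves a state dict into an in-memory `io.BytesIO` buffer, loads it back with `weights_only=True`, and checks for *exact* equality with `torch.equal()`, which is stricter than the tolerance-based `torch.allclose()`. It assumes a throwaway `nn.Sequential` stand-in rather than the notebook's `model`.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_model():
    # Throwaway stand-in for MNISTClassifier
    return nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

src, dst = make_model(), make_model()     # dst starts with different random weights

buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)      # serialize into memory instead of a file
n_bytes = len(buffer.getvalue())
buffer.seek(0)                            # rewind before reading

state = torch.load(buffer, map_location='cpu', weights_only=True)
dst.load_state_dict(state)

# Raw float32 parameter data: 4 bytes per element
raw_bytes = sum(p.numel() * p.element_size() for p in src.parameters())

# Compare every saved tensor, then the outputs on the same input
exact = all(torch.equal(a, b) for a, b in zip(src.state_dict().values(), dst.state_dict().values()))
x = torch.randn(4, 1, 28, 28)
with torch.no_grad():
    same_out = torch.equal(src(x), dst(x))

print(f'raw parameter bytes: {raw_bytes:,} | serialized bytes: {n_bytes:,}')
print(f'weights identical: {exact} | outputs identical: {same_out}')
```

The serialized size is the raw parameter bytes plus a little container overhead, which is also a quick way to sanity-check the file-size prediction.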

</details>

---

## Exercise 2: Add Checkpointing to a Training Loop (Supported)

Saving just the model weights is fine for deployment, but during training you need **checkpoints**: snapshots that include everything needed to resume (model weights, optimizer state, epoch number, and loss).

**Task:**
1. Add checkpoint saving every 5 epochs (`if (epoch+1) % 5 == 0`)
2. Track the best validation loss seen so far
3. Save the best model whenever validation loss improves
4. Each checkpoint should be a dict with keys: `'model_state_dict'`, `'optimizer_state_dict'`, `'epoch'`, `'loss'` (store the validation loss under `'loss'`)

**Hints:**
- Initialize `best_loss = float('inf')` before the loop
- Check `if val_loss < best_loss:` after each epoch
- Use different filenames for periodic vs best checkpoints

In [None]:
# Fresh model and optimizer
model_2 = MNISTClassifier().to(device)
optimizer_2 = optim.Adam(model_2.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

n_epochs = 15
best_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train_one_epoch(model_2, train_loader, optimizer_2, criterion)
    val_loss, val_acc = evaluate(model_2, test_loader)

    print(f'Epoch {epoch+1}/{n_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%', end='')

    # TODO: Save checkpoint every 5 epochs
    # The checkpoint dict should have keys:
    #   'model_state_dict', 'optimizer_state_dict', 'epoch', 'loss'
    # Save to: f'saved_models/checkpoint_epoch_{epoch+1}.pth'

    # TODO: Save best model when val_loss improves
    # Save to: 'saved_models/best_model.pth'
    # Print ' <- best!' when saving

    print()  # newline

print(f'\nBest validation loss: {best_loss:.4f}')

# Show what files we created
print('\nSaved files:')
for f in sorted(os.listdir('saved_models')):
    size = os.path.getsize(f'saved_models/{f}')
    print(f'  {f} ({size / 1024:.1f} KB)')

<details>
<summary>💡 Solution</summary>

**The key insight:** A checkpoint is just a dictionary with everything needed to resume. The two separate saves (periodic + best) serve different purposes: periodic checkpoints let you go back to earlier points in training, while the best checkpoint is the one you deploy.

```python
    # Save checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        checkpoint = {
            'model_state_dict': model_2.state_dict(),
            'optimizer_state_dict': optimizer_2.state_dict(),
            'epoch': epoch,
            'loss': val_loss,
        }
        torch.save(checkpoint, f'saved_models/checkpoint_epoch_{epoch+1}.pth')
        print(' [checkpoint saved]', end='')

    # Save best model when val_loss improves
    if val_loss < best_loss:
        best_loss = val_loss
        checkpoint = {
            'model_state_dict': model_2.state_dict(),
            'optimizer_state_dict': optimizer_2.state_dict(),
            'epoch': epoch,
            'loss': val_loss,
        }
        torch.save(checkpoint, 'saved_models/best_model.pth')
        print(' <- best!', end='')
```

**Why save the optimizer state?** Adam keeps per-parameter running averages of the gradients (the momentum term) and of the squared gradients. If you load just the model and create a fresh optimizer, those averages reset, which can cause a jump in the loss curve. Saving the optimizer state lets training continue smoothly.
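To see what "optimizer state" means concretely, this sketch inspects `optimizer.state_dict()` for a tiny stand-in model (an assumed `nn.Linear(4, 2)`, not the notebook's classifier) before and after a single Adam step.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 2)   # stand-in model with two parameter tensors (weight, bias)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Adam allocates its per-parameter buffers lazily, on the first step()
before = optimizer.state_dict()

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
after = optimizer.state_dict()

print('keys:', list(after.keys()))
print('state before first step:', before['state'])
# 'state' is keyed by parameter index (0 = weight, 1 = bias)
for idx, buffers in after['state'].items():
    print(idx, sorted(buffers.keys()), tuple(buffers['exp_avg'].shape))
print('hyperparameters:', after['param_groups'][0]['lr'], after['param_groups'][0]['betas'])
```

`exp_avg` and `exp_avg_sq` are Adam's first- and second-moment estimates (one tensor per parameter, with that parameter's shape), and `step` feeds the bias correction. A freshly created optimizer has none of them.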

**Why save the epoch number?** So you know where you left off. When you resume, you start from `checkpoint['epoch'] + 1`, not from 0.

</details>

---

## Exercise 3: Simulate a Training Crash and Resume (Supported)

The real test of checkpointing: can you actually resume training and get a continuous loss curve? In this exercise you'll simulate a crash by training for 10 epochs, saving a checkpoint, then creating a completely fresh model and optimizer, loading the checkpoint, and continuing for 10 more epochs.

**Before running, predict:** Will the loss curve be smooth across the crash point, or will there be a visible jump? What would happen if you loaded only the model weights but NOT the optimizer state?

**Task:**
1. Train a fresh model for 10 epochs, saving a checkpoint at the end
2. Create a brand-new model and optimizer (simulating a fresh Python session)
3. Load the checkpoint and restore model weights, optimizer state, and epoch number
4. Print the epoch and loss from the checkpoint to confirm it loaded
5. Continue training for 10 more epochs
6. Plot the full 20-epoch loss curve and verify there's no discontinuity

**Hint:** The key to a smooth loss curve is loading *both* `model_state_dict` and `optimizer_state_dict`.

In [None]:
# ===== Phase 1: Train for 10 epochs =====
print('=== Phase 1: Training for 10 epochs ===')

model_3 = MNISTClassifier().to(device)
optimizer_3 = optim.Adam(model_3.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

losses_phase1 = []

for epoch in range(10):
    train_loss = train_one_epoch(model_3, train_loader, optimizer_3, criterion)
    val_loss, val_acc = evaluate(model_3, test_loader)
    losses_phase1.append(val_loss)
    print(f'Epoch {epoch+1}/10 | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

# Save checkpoint at the end of phase 1
checkpoint = {
    'model_state_dict': model_3.state_dict(),
    'optimizer_state_dict': optimizer_3.state_dict(),
    'epoch': 9,  # 0-indexed, so epoch 9 = 10th epoch
    'loss': losses_phase1[-1],
}
torch.save(checkpoint, 'saved_models/crash_checkpoint.pth')
print(f'\nCheckpoint saved after epoch 10.')

In [None]:
# ===== Phase 2: Simulate crash, then start fresh and resume =====
print('=== Phase 2: Simulating crash (creating fresh model and optimizer) ===')

# Create completely fresh model and optimizer (as if Python restarted)
model_resumed = MNISTClassifier().to(device)
optimizer_resumed = optim.Adam(model_resumed.parameters(), lr=1e-3)

# TODO: Load the checkpoint from 'saved_models/crash_checkpoint.pth'
# checkpoint = ...

# TODO: Restore model weights and optimizer state
# model_resumed.load_state_dict(...)
# optimizer_resumed.load_state_dict(...)

# TODO: Get the starting epoch and loss from the checkpoint
# start_epoch = ...
# print(f'Resumed from epoch {start_epoch + 1}, loss: {checkpoint["loss"]:.4f}')

# Continue training for 10 more epochs
losses_phase2 = []

for epoch in range(start_epoch + 1, start_epoch + 11):
    train_loss = train_one_epoch(model_resumed, train_loader, optimizer_resumed, criterion)
    val_loss, val_acc = evaluate(model_resumed, test_loader)
    losses_phase2.append(val_loss)
    print(f'Epoch {epoch+1}/20 | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

print('\nTraining complete.')

In [None]:
# Plot the full 20-epoch loss curve
all_losses = losses_phase1 + losses_phase2
epochs = range(1, 21)

plt.figure(figsize=(10, 5))
plt.plot(list(epochs)[:10], losses_phase1, 'o-', label='Phase 1 (before crash)', linewidth=2)
plt.plot(list(epochs)[10:], losses_phase2, 's-', label='Phase 2 (after resume)', linewidth=2)
plt.axvline(x=10.5, color='red', linestyle='--', alpha=0.7, label='Crash / Resume point')
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.title('Loss Curve Across Training Crash')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

# Check for discontinuity
gap = abs(losses_phase2[0] - losses_phase1[-1])
print(f'Loss before crash:  {losses_phase1[-1]:.4f}')
print(f'Loss after resume:  {losses_phase2[0]:.4f}')
print(f'Gap: {gap:.4f}')

if gap < 0.05:
    print('Loss curve is continuous: checkpoint resume worked!')
else:
    print('Significant gap detected â€” did you load the optimizer state?')

<details>
<summary>💡 Solution</summary>

**The key insight:** A smooth loss curve requires restoring both the model AND the optimizer. Without the optimizer state, Adam's moment estimates reset and the loss can jump.

```python
# Load the checkpoint
checkpoint = torch.load('saved_models/crash_checkpoint.pth', map_location=device, weights_only=True)

# Restore model weights and optimizer state
model_resumed.load_state_dict(checkpoint['model_state_dict'])
optimizer_resumed.load_state_dict(checkpoint['optimizer_state_dict'])

# Get the starting epoch
start_epoch = checkpoint['epoch']
print(f'Resumed from epoch {start_epoch + 1}, loss: {checkpoint["loss"]:.4f}')
```

**Why the optimizer state matters:** Try commenting out `optimizer_resumed.load_state_dict(...)` and watch the gap at the resume point (between epochs 10 and 11). Adam maintains per-parameter momentum and variance estimates. Without them, the optimizer "forgets" how it was navigating the loss landscape and effectively starts over, which can cause a visible discontinuity.
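The effect can be isolated without a full training run. This sketch uses a toy `nn.Linear` regression on random data (chosen only for illustration): it takes a few Adam steps, snapshots model and optimizer the way a checkpoint would, then compares one further step from (a) the original run, (b) a resume that restores both state dicts, and (c) a resume with a fresh optimizer. `copy.deepcopy` stands in for `torch.save`/`torch.load` here, because `state_dict()` returns references to live tensors.

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
loss_fn = nn.MSELoss()
# Six small random batches: five before the "crash", one for the step after resuming
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(6)]

def adam_step(model, opt, batch):
    x, y = batch
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()

def same_params(m1, m2):
    return all(torch.equal(p1, p2) for p1, p2 in zip(m1.parameters(), m2.parameters()))

# (a) Original run: five steps so Adam accumulates moment estimates
model_a = nn.Linear(4, 1)
opt_a = optim.Adam(model_a.parameters(), lr=1e-2)
for batch in batches[:5]:
    adam_step(model_a, opt_a, batch)

# Snapshot, as a checkpoint would (deepcopy so later steps can't mutate it)
ckpt = {
    'model_state_dict': copy.deepcopy(model_a.state_dict()),
    'optimizer_state_dict': copy.deepcopy(opt_a.state_dict()),
}

# (b) Full resume: restore model AND optimizer
model_b = nn.Linear(4, 1)
opt_b = optim.Adam(model_b.parameters(), lr=1e-2)
model_b.load_state_dict(ckpt['model_state_dict'])
opt_b.load_state_dict(ckpt['optimizer_state_dict'])

# (c) Incomplete resume: restore the model only, keep a fresh optimizer
model_c = nn.Linear(4, 1)
opt_c = optim.Adam(model_c.parameters(), lr=1e-2)
model_c.load_state_dict(ckpt['model_state_dict'])

# One more step on the same batch for all three
for m, o in [(model_a, opt_a), (model_b, opt_b), (model_c, opt_c)]:
    adam_step(m, o, batches[5])

same_b = same_params(model_a, model_b)
same_c = same_params(model_a, model_c)
print(f'full resume matches original:       {same_b}')
print(f'model-only resume matches original: {same_c}')
```

Expect (b) to track the original exactly and (c) to diverge: a fresh Adam's first bias-corrected step moves each parameter by roughly `lr` in the direction of the gradient's sign, ignoring the accumulated history.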

**The resume pattern:**
1. Create fresh model and optimizer (same architecture and hyperparameters)
2. `torch.load()` the checkpoint dict
3. `model.load_state_dict(checkpoint['model_state_dict'])`
4. `optimizer.load_state_dict(checkpoint['optimizer_state_dict'])`
5. `start_epoch = checkpoint['epoch']`
6. Resume the loop from `start_epoch + 1`
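The steps above can be wrapped into two small helpers so a script restarts cleanly whether or not a checkpoint exists. This is a sketch: `save_checkpoint` and `resume_or_start` are hypothetical names (not PyTorch APIs), and a toy `nn.Linear` stands in for `MNISTClassifier`.

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

def save_checkpoint(path, model, optimizer, epoch, loss):
    """Bundle everything needed to resume (same keys as the notebook's checkpoints)."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,  # 0-indexed epoch that just finished
        'loss': loss,
    }, path)

def resume_or_start(path, model, optimizer, device='cpu'):
    """Restore from `path` if it exists; return the first epoch index to run."""
    if not os.path.exists(path):
        return 0
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'checkpoint.pth')

    # First session: no checkpoint yet, so training starts at epoch 0
    model = nn.Linear(4, 1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    first_start = resume_or_start(path, model, optimizer)
    save_checkpoint(path, model, optimizer, epoch=4, loss=0.123)  # pretend 5 epochs finished

    # "New session": fresh objects with the same architecture and hyperparameters
    model = nn.Linear(4, 1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    start_epoch = resume_or_start(path, model, optimizer)
    remaining = list(range(start_epoch, 10))  # epochs still to run out of 10

print(first_start, start_epoch, remaining)
```

Note the convention: this helper already adds 1, so the loop is `range(start_epoch, n_epochs)`, whereas the solution above keeps `start_epoch = checkpoint['epoch']` and loops from `start_epoch + 1`. Either works as long as you pick one.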

</details>

---

## Exercise 4: Implement Early Stopping (Independent)

Training too long can cause overfitting. Early stopping watches the validation loss: if it stops improving for a set number of epochs (the **patience**), training stops and the best model is restored.

**Task:** Implement early stopping from scratch:
1. Track the best validation loss and a patience counter (`patience=5`)
2. When validation loss improves, save the best model and reset the counter
3. When validation loss does not improve, increment the counter
4. If the counter reaches `patience`, stop training early
5. After training ends (naturally or early), restore the best model
6. Print when patience runs out

You have all the building blocks from the previous exercises. This is about putting them together into a complete pattern.

In [None]:
# Fresh model
model_4 = MNISTClassifier().to(device)
optimizer_4 = optim.Adam(model_4.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Early stopping config
max_epochs = 50
patience = 5

# TODO: Initialize tracking variables
# best_val_loss = ...
# patience_counter = ...
# train_losses = []
# val_losses = []

for epoch in range(max_epochs):
    train_loss = train_one_epoch(model_4, train_loader, optimizer_4, criterion)
    val_loss, val_acc = evaluate(model_4, test_loader)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # TODO: Check if validation loss improved
    #   - If improved: save the model, reset patience counter, update best_val_loss
    #   - If not improved: increment patience counter

    print(f'Epoch {epoch+1}/{max_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Patience: {patience_counter}/{patience}')

    # TODO: Check if patience has run out; if so, print a message and break

# TODO: Restore the best model
# (load the saved best model state dict back into model_4)

# Final evaluation
final_loss, final_acc = evaluate(model_4, test_loader)
print(f'\nBest model restored. Val Loss: {final_loss:.4f} | Val Acc: {final_acc:.2f}%')

In [None]:
# Plot training curves
fig, ax = plt.subplots(figsize=(10, 5))
epochs_range = range(1, len(train_losses) + 1)
# Best epoch = 1-indexed position of the lowest validation loss
# (correct whether or not early stopping actually triggered)
best_epoch = val_losses.index(min(val_losses)) + 1

ax.plot(epochs_range, train_losses, 'o-', label='Train Loss', linewidth=2, markersize=4)
ax.plot(epochs_range, val_losses, 's-', label='Val Loss', linewidth=2, markersize=4)
ax.axvline(x=best_epoch, color='green', linestyle='--', alpha=0.7, label='Best model epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Early Stopping: Training vs Validation Loss')
ax.legend()
ax.grid(alpha=0.3)
plt.show()

print(f'Training stopped at epoch {len(train_losses)} (max was {max_epochs})')
print(f'Best model was from epoch {best_epoch}')

<details>
<summary>💡 Solution</summary>

**The key insight:** Early stopping combines three patterns you already know: tracking best loss, saving checkpoints, and restoring the best model. The patience counter is the only new piece â€” it counts how many epochs have passed without improvement.

```python
# Initialize tracking variables
best_val_loss = float('inf')
patience_counter = 0
train_losses = []
val_losses = []

for epoch in range(max_epochs):
    train_loss = train_one_epoch(model_4, train_loader, optimizer_4, criterion)
    val_loss, val_acc = evaluate(model_4, test_loader)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Check if validation loss improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model_4.state_dict(), 'saved_models/early_stop_best.pth')
    else:
        patience_counter += 1

    print(f'Epoch {epoch+1}/{max_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Patience: {patience_counter}/{patience}')

    # Check if patience has run out
    if patience_counter >= patience:
        print(f'\nEarly stopping triggered! No improvement for {patience} epochs.')
        break

# Restore the best model
best_state = torch.load('saved_models/early_stop_best.pth', map_location=device, weights_only=True)
model_4.load_state_dict(best_state)
```

**The early stopping pattern:**
1. Set `best_val_loss = float('inf')` and `patience_counter = 0`
2. Each epoch: if val loss improved -> save model, reset counter. If not -> increment counter.
3. If counter reaches patience -> break.
4. After the loop -> restore the saved best model.
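The counter logic in the pattern above can be packaged into a small reusable object. This is a sketch: `EarlyStopping` is a hypothetical helper (not a PyTorch class), `min_delta` is an extra knob not used in the notebook, and the validation losses are made-up numbers for illustration.

```python
class EarlyStopping:
    """Hypothetical helper wrapping the patience-counter logic."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # extra: minimum drop that counts as an improvement
        self.best_loss = float('inf')
        self.best_epoch = None
        self.counter = 0

    def update(self, val_loss, epoch):
        """Record one epoch's validation loss. Returns (improved, should_stop)."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience


# Synthetic validation losses (made up for illustration)
val_losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.34, 0.38, 0.39, 0.40, 0.41]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    improved, should_stop = stopper.update(loss, epoch)
    # In a real loop, `improved` is where you would torch.save() the model
    if should_stop:
        stopped_at = epoch
        break

print(f'stopped at epoch {stopped_at}, best epoch {stopper.best_epoch} (loss {stopper.best_loss})')
```

Here the dip at epoch 6 resets the counter, so training continues until three non-improving epochs in a row (7, 8, 9) trigger the stop.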

**Why restore the best model?** When training stops, the current model is the one from the *last* epoch, which is no better than the best one (if early stopping triggered, it is `patience` epochs past it). You want the model from the epoch with the lowest validation loss.

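**Patience tuning:** Too low (1-2) and you stop at normal fluctuations. Too high (20+) and training keeps running (and overfitting) long after the best epoch, wasting compute; the restored best model itself is unaffected. 3-10 is typical.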

</details>

---

## Key Takeaways

1. **`model.state_dict()`** is a dictionary of learned parameters. Save it with `torch.save()`, load it with `torch.load()` + `model.load_state_dict()`.
2. **Checkpoints** bundle model weights, optimizer state, epoch, and loss into one dict â€” everything needed to resume training exactly where you left off.
3. **Optimizer state matters.** Without it, resumed training can show a visible jump in the loss curve because Adam's moment estimates reset.
4. **Early stopping** prevents overfitting by tracking validation loss and stopping when it stops improving. Always restore the best model at the end.
5. **The `.pth` convention** is standard for PyTorch saved files. Use `weights_only=True` when loading for security.

**Clean up** (optional): delete the `saved_models/` directory when you're done experimenting.

In [None]:
# Optional: clean up saved files
# import shutil
# shutil.rmtree('saved_models', ignore_errors=True)
# print('Cleaned up saved_models/')