# The Training Loop: Putting It All Together

In this notebook, you'll go from understanding the individual pieces (loss functions, optimizers, gradients) to wiring them into a complete PyTorch training loop.

**What you'll do:**
- Verify that PyTorch's built-in loss matches your manual calculation
- Verify that the optimizer does the same thing as a manual parameter update
- Write a complete training loop from scratch
- Compare SGD vs Adam convergence
- Diagnose a common training bug (missing `zero_grad`)
- (Stretch) Train a 2-layer network on nonlinear data

**Predict-first methodology:** Before running each cell, predict what you think the output will be. This forces active engagement and builds real intuition.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# For nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

print('Setup complete.')

---

## Exercise 1: Verify nn.MSELoss Matches Manual MSE

**Type: GUIDED** — Follow along, fill in the marked lines.

Before using `nn.MSELoss` in a training loop, you should know exactly what it computes. The formula is:

$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2$$

Your task: compute MSE manually, then verify it matches `nn.MSELoss`.

In [None]:
# Some fake predictions and targets
y_hat = torch.tensor([2.5, 0.0, 1.0, 3.5])
y = torch.tensor([3.0, -0.5, 0.0, 2.0])

# --- Manual MSE ---
# Step 1: Compute the differences
diffs = y_hat - y
print(f'Differences: {diffs}')

# Step 2: Square them
squared = diffs ** 2
print(f'Squared:     {squared}')

# Step 3: Take the mean
manual_mse = squared.mean()
print(f'Manual MSE:  {manual_mse.item():.6f}')

# --- PyTorch MSE ---
criterion = nn.MSELoss()
pytorch_mse = criterion(y_hat, y)
print(f'PyTorch MSE: {pytorch_mse.item():.6f}')

# --- Verify they match ---
match = torch.allclose(manual_mse, pytorch_mse)
print(f'\nMatch: {match}')

if match:
    print('nn.MSELoss is doing exactly what you think it does. No magic.')

**What you just proved:** with its default `reduction='mean'`, `nn.MSELoss` is literally `((y_hat - y)**2).mean()`. There's no hidden complexity. When you see it in a training loop, you know exactly what it computes.
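`nn.MSELoss` also accepts a `reduction` argument, and the default `'mean'` is the case verified above. A small sketch reusing the same tensors, showing that `'sum'` drops the division by n and `'none'` returns the per-element squared errors:

```python
import torch
import torch.nn as nn

# Same tensors as the exercise above
y_hat = torch.tensor([2.5, 0.0, 1.0, 3.5])
y = torch.tensor([3.0, -0.5, 0.0, 2.0])

per_element = nn.MSELoss(reduction='none')(y_hat, y)  # squared error per entry, no reduction
total = nn.MSELoss(reduction='sum')(y_hat, y)          # sum of squared errors
mean = nn.MSELoss()(y_hat, y)                          # default reduction='mean'

print(per_element)   # tensor([0.2500, 0.2500, 1.0000, 2.2500])
print(total.item())  # 3.75
print(mean.item())   # 0.9375
print(torch.allclose(mean, total / y.numel()))  # True
```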

---

## Exercise 2: Verify optimizer.step() Matches Manual Update

**Type: GUIDED** — Follow along, compare the results.

Plain SGD (`optim.SGD` with its defaults: no momentum, no weight decay) with learning rate `lr` does exactly one thing per parameter:

$$p \leftarrow p - \text{lr} \times \frac{\partial L}{\partial p}$$

Let's prove that `optimizer.step()` does the same thing as doing it by hand.
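It also helps to know where the gradient itself comes from. For the single example used below (`x = 2`, `y = 5`, `MSELoss` over one element), the loss is $L = (wx + b - y)^2$, so $\partial L/\partial w = 2(wx + b - y)\,x$ and $\partial L/\partial b = 2(wx + b - y)$. A small sketch checking that autograd produces these numbers, assuming the same `torch.manual_seed(0)` initialization as the cell below (`model_check` is just an illustrative name):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # same init as the manual-update cell
model_check = nn.Linear(1, 1)
x = torch.tensor([[2.0]])
y = torch.tensor([[5.0]])

loss = nn.MSELoss()(model_check(x), y)
loss.backward()

# Hand-derived gradients of L = (w*x + b - y)^2 for this single sample
w, b = model_check.weight.item(), model_check.bias.item()
residual = w * 2.0 + b - 5.0
grad_w_by_hand = 2 * residual * 2.0
grad_b_by_hand = 2 * residual

print(f'weight grad: autograd {model_check.weight.grad.item():.6f}, by hand {grad_w_by_hand:.6f}')
print(f'bias grad:   autograd {model_check.bias.grad.item():.6f}, by hand {grad_b_by_hand:.6f}')
```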

In [None]:
lr = 0.1

# --- Method A: Manual update ---
torch.manual_seed(0)
model_manual = nn.Linear(1, 1)

# Record initial parameter values
w_before = model_manual.weight.data.clone()
b_before = model_manual.bias.data.clone()
print(f'Initial weight: {w_before.item():.6f}')
print(f'Initial bias:   {b_before.item():.6f}')

# Forward pass to get gradients
x = torch.tensor([[2.0]])
y = torch.tensor([[5.0]])
pred = model_manual(x)
loss = nn.MSELoss()(pred, y)
loss.backward()

print(f'\nWeight grad: {model_manual.weight.grad.item():.6f}')
print(f'Bias grad:   {model_manual.bias.grad.item():.6f}')

# Manual update: p = p - lr * grad
with torch.no_grad():
    model_manual.weight -= lr * model_manual.weight.grad
    model_manual.bias -= lr * model_manual.bias.grad

w_manual = model_manual.weight.data.clone()
b_manual = model_manual.bias.data.clone()
print(f'\nAfter manual update:')
print(f'  weight: {w_manual.item():.6f}')
print(f'  bias:   {b_manual.item():.6f}')

In [None]:
# --- Method B: optimizer.step() ---
torch.manual_seed(0)  # Same init as above
model_optim = nn.Linear(1, 1)
optimizer = optim.SGD(model_optim.parameters(), lr=lr)

# Same forward pass
x = torch.tensor([[2.0]])
y = torch.tensor([[5.0]])
pred = model_optim(x)
loss = nn.MSELoss()(pred, y)
loss.backward()

# Optimizer update
optimizer.step()

w_optim = model_optim.weight.data.clone()
b_optim = model_optim.bias.data.clone()
print(f'After optimizer.step():')
print(f'  weight: {w_optim.item():.6f}')
print(f'  bias:   {b_optim.item():.6f}')

# --- Compare ---
print(f'\nWeights match: {torch.allclose(w_manual, w_optim)}')
print(f'Biases match:  {torch.allclose(b_manual, b_optim)}')
print('\noptimizer.step() is just p -= lr * p.grad for every parameter. No magic.')

**What you just proved:** `optimizer.step()` with SGD does the exact same subtraction you'd do by hand. It just does it for every parameter in the model at once. The optimizer is a convenience, not a mystery.
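"For every parameter in the model at once" can be made concrete. A minimal sketch using a hypothetical helper `manual_sgd_step` (not a PyTorch API) that loops over `model.parameters()` and should match `optim.SGD` on a small two-layer network. It assumes plain SGD (default `momentum=0`, `weight_decay=0`):

```python
import torch
import torch.nn as nn
import torch.optim as optim

def manual_sgd_step(params, lr):
    """Hypothetical helper: plain SGD, p <- p - lr * p.grad, for every parameter."""
    with torch.no_grad():
        for p in params:
            if p.grad is not None:
                p -= lr * p.grad

def make_net():
    torch.manual_seed(0)  # identical init for both copies
    return nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))

torch.manual_seed(1)
x = torch.randn(16, 1)
y = 3 * x - 2

net_manual, net_optim = make_net(), make_net()
opt = optim.SGD(net_optim.parameters(), lr=0.1)
criterion = nn.MSELoss()

# One step by hand, over all four parameter tensors (two weights, two biases)
criterion(net_manual(x), y).backward()
manual_sgd_step(net_manual.parameters(), lr=0.1)

# One step with the optimizer
opt.zero_grad()
criterion(net_optim(x), y).backward()
opt.step()

same = all(torch.allclose(a, b) for a, b in zip(net_manual.parameters(), net_optim.parameters()))
print(f'All parameter tensors match: {same}')
```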

---

## Exercise 3: Train Linear Regression in PyTorch

**Type: SUPPORTED** — Template provided, you fill in the training loop.

Now wire the pieces together. You'll train a model to learn `y = 3x - 2` from noisy data.

The training loop pattern:
```
for each epoch:
    1. Forward pass:  predictions = model(x)
    2. Compute loss:  loss = criterion(predictions, y)
    3. Zero grads:    optimizer.zero_grad()
    4. Backward:      loss.backward()
    5. Update:        optimizer.step()
```
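Because the targets are noisy, no training loop will land on exactly `3.0` and `-2.0`: the best possible MSE fit for this particular data set is the ordinary least-squares solution. A minimal sketch, assuming the same seed and data generation as the next cell, that computes it in closed form with `torch.linalg.lstsq`. These are the values gradient descent is heading toward:

```python
import torch

# Same data generation as the next cell
torch.manual_seed(42)
n_samples = 100
X = torch.randn(n_samples, 1) * 3
y = 3 * X - 2 + torch.randn(n_samples, 1) * 0.5

# Design matrix with columns [x, 1], so the solution vector is [weight, bias]
A = torch.cat([X, torch.ones_like(X)], dim=1)
solution = torch.linalg.lstsq(A, y).solution  # shape (2, 1)

w_best, b_best = solution[0].item(), solution[1].item()
best_mse = ((A @ solution - y) ** 2).mean().item()
print(f'Least-squares fit: weight = {w_best:.4f}, bias = {b_best:.4f}, MSE = {best_mse:.4f}')
```

The remaining MSE is roughly the noise variance (0.5² = 0.25), which is a useful floor to keep in mind when reading the loss curve.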

In [None]:
# Generate data: y = 3x - 2 + noise
torch.manual_seed(42)
n_samples = 100

X = torch.randn(n_samples, 1) * 3  # Random x values
y = 3 * X - 2 + torch.randn(n_samples, 1) * 0.5  # y = 3x - 2 + noise

# Visualize the data
plt.scatter(X.numpy(), y.numpy(), alpha=0.5, s=20)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Training Data: y = 3x - 2 + noise')
plt.grid(alpha=0.3)
plt.show()

In [None]:
# --- Your training loop ---

# Model: a single linear layer (1 input -> 1 output)
model = nn.Linear(1, 1)

# Loss function
criterion = nn.MSELoss()

# Optimizer: SGD with learning rate 0.01
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Record initial parameters
print(f'Before training:')
print(f'  weight = {model.weight.item():.4f} (target: 3.0)')
print(f'  bias   = {model.bias.item():.4f} (target: -2.0)')

# Training
n_epochs = 100
losses = []

for epoch in range(n_epochs):
    # 1. Forward pass
    predictions = model(X)
    
    # 2. Compute loss
    loss = criterion(predictions, y)
    
    # 3. Zero gradients (critical! you'll see why in Exercise 5)
    optimizer.zero_grad()
    
    # 4. Backward pass
    loss.backward()
    
    # 5. Update parameters
    optimizer.step()
    
    # Record loss
    losses.append(loss.item())
    
    # Print every 20 epochs
    if (epoch + 1) % 20 == 0:
        print(f'  Epoch {epoch+1:3d}: loss = {loss.item():.4f}')

print(f'\nAfter training:')
print(f'  weight = {model.weight.item():.4f} (target: 3.0)')
print(f'  bias   = {model.bias.item():.4f} (target: -2.0)')

In [None]:
# Plot the results
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curve
axes[0].plot(losses, linewidth=2, color='#ff6b6b')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('MSE Loss')
axes[0].set_title('Training Loss')
axes[0].grid(alpha=0.3)

# Final fit
x_line = torch.linspace(X.min().item(), X.max().item(), 100).unsqueeze(1)  # plain floats for start/end
with torch.no_grad():
    y_pred = model(x_line)

axes[1].scatter(X.numpy(), y.numpy(), alpha=0.4, s=20, label='Data')
axes[1].plot(x_line.numpy(), y_pred.numpy(), color='#ff6b6b', linewidth=2,
             label=f'Learned: y = {model.weight.item():.2f}x + ({model.bias.item():.2f})')
axes[1].plot(x_line.numpy(), 3 * x_line.numpy() - 2, '--', color='#51cf66', linewidth=2,
             label='True: y = 3x - 2', alpha=0.7)
axes[1].set_xlabel('x')
axes[1].set_ylabel('y')
axes[1].set_title('Learned Fit vs True Function')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

---

## Exercise 4: Swap SGD for Adam

**Type: SUPPORTED** — One-line change, observe the difference.

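Adam adapts the step size for each parameter using two running averages: one of its gradients (momentum) and one of its squared gradients. In practice, this often means faster convergence, but its `lr` is not on the same scale as SGD's: Adam's per-step change to each parameter is typically on the order of `lr`, whereas SGD's step is `lr` times the gradient.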
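The momentum and squared-gradient averages can be written out by hand, in the same spirit as Exercise 2. A sketch of the standard Adam update (default `betas=(0.9, 0.999)`, `eps=1e-8`, no weight decay), using a hypothetical helper `manual_adam_step` and comparing it with `optim.Adam` over a few steps. Small floating-point differences are expected, so the comparison uses a tolerance:

```python
import torch
import torch.nn as nn
import torch.optim as optim

def manual_adam_step(params, state, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    """Hypothetical helper: one textbook Adam update for every parameter."""
    beta1, beta2 = betas
    state['t'] += 1
    t = state['t']
    with torch.no_grad():
        for p, m, v in zip(params, state['m'], state['v']):
            g = p.grad
            m.mul_(beta1).add_(g, alpha=1 - beta1)         # running mean of gradients
            v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # running mean of squared gradients
            m_hat = m / (1 - beta1 ** t)                   # bias corrections for the
            v_hat = v / (1 - beta2 ** t)                   # zero-initialized averages
            p -= lr * m_hat / (v_hat.sqrt() + eps)

torch.manual_seed(0)
X_demo = torch.randn(32, 1) * 3
y_demo = 3 * X_demo - 2

torch.manual_seed(1)
lin_manual = nn.Linear(1, 1)
torch.manual_seed(1)
lin_torch = nn.Linear(1, 1)  # identical init

params = list(lin_manual.parameters())
state = {'t': 0,
         'm': [torch.zeros_like(p) for p in params],
         'v': [torch.zeros_like(p) for p in params]}
opt = optim.Adam(lin_torch.parameters(), lr=0.01)
criterion = nn.MSELoss()

for step in range(5):
    for p in params:
        p.grad = None  # reset, like zero_grad()
    criterion(lin_manual(X_demo), y_demo).backward()
    manual_adam_step(params, state, lr=0.01)

    opt.zero_grad()
    criterion(lin_torch(X_demo), y_demo).backward()
    opt.step()

match = all(torch.allclose(a, b, atol=1e-6)
            for a, b in zip(lin_manual.parameters(), lin_torch.parameters()))
print(f'Manual Adam matches optim.Adam after 5 steps: {match}')
```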

Your task: run the same training as Exercise 3, but swap `optim.SGD` for `optim.Adam`. Compare the loss curves.
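Before running the comparison, predict how far each optimizer moves a parameter on its first step. SGD's step is `lr * grad`, so it grows with the gradient; Adam's first step is about `lr` in size whatever the gradient's scale, because it divides by the root of the squared-gradient average. A small sketch with a hypothetical helper `first_step_size` that measures the first update of a single scalar parameter given a hand-set gradient:

```python
import torch

def first_step_size(optimizer_cls, grad_value, lr=0.01):
    """Hypothetical helper: magnitude of the first update for one scalar parameter."""
    p = torch.nn.Parameter(torch.zeros(1))
    opt = optimizer_cls([p], lr=lr)
    p.grad = torch.full_like(p, grad_value)  # pretend backward() produced this gradient
    opt.step()
    return p.abs().item()

for g in (0.1, 10.0, 1000.0):
    sgd_step = first_step_size(torch.optim.SGD, g)
    adam_step = first_step_size(torch.optim.Adam, g)
    print(f'grad = {g:7.1f}   SGD step = {sgd_step:8.4f}   Adam step = {adam_step:.4f}')
```

On Exercise 3's data, `nn.Linear(1, 1)` starts its weight somewhere in [-1, 1] and the target is 3. At `lr=0.01`, Adam therefore needs a couple of hundred steps or more to get there, while SGD's large early gradients move it much faster.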

In [None]:
# --- SGD training ---
torch.manual_seed(42)
model_sgd = nn.Linear(1, 1)
optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=0.01)
criterion = nn.MSELoss()

losses_sgd = []
for epoch in range(100):
    pred = model_sgd(X)
    loss = criterion(pred, y)
    optimizer_sgd.zero_grad()
    loss.backward()
    optimizer_sgd.step()
    losses_sgd.append(loss.item())

# --- Adam training ---
torch.manual_seed(42)
model_adam = nn.Linear(1, 1)
optimizer_adam = optim.Adam(model_adam.parameters(), lr=0.01)  # <-- The only change

losses_adam = []
for epoch in range(100):
    pred = model_adam(X)
    loss = criterion(pred, y)
    optimizer_adam.zero_grad()
    loss.backward()
    optimizer_adam.step()
    losses_adam.append(loss.item())

print(f'SGD  — Final loss: {losses_sgd[-1]:.4f}, weight: {model_sgd.weight.item():.4f}, bias: {model_sgd.bias.item():.4f}')
print(f'Adam — Final loss: {losses_adam[-1]:.4f}, weight: {model_adam.weight.item():.4f}, bias: {model_adam.bias.item():.4f}')

In [None]:
# Compare loss curves side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(losses_sgd, linewidth=2, color='#ff6b6b', label='SGD')
axes[0].plot(losses_adam, linewidth=2, color='#4ecdc4', label='Adam')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('MSE Loss')
axes[0].set_title('Loss Curves: SGD vs Adam')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Zoomed in on early epochs
axes[1].plot(losses_sgd[:30], linewidth=2, color='#ff6b6b', label='SGD')
axes[1].plot(losses_adam[:30], linewidth=2, color='#4ecdc4', label='Adam')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('MSE Loss')
axes[1].set_title('First 30 Epochs (Zoomed)')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

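print('\nNotice: with the same lr=0.01, SGD pulls well ahead of Adam on this problem.')
print('Adam moves each parameter by roughly lr per step, so taking the weight from its')
print("initial value in [-1, 1] up to 3 needs hundreds of steps, while SGD's step is lr * grad")
print('and the early gradients here are large. The same lr means different things to each optimizer.')
print('Adam is a popular default because it is less sensitive to gradient scale and needs less lr tuning, not because it always wins.')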

---

## Exercise 5: Diagnose the Accumulation Bug

**Type: SUPPORTED** — Predict, observe, fix.

The code below has a training loop with a **common bug**. Before running it:

1. Read the code carefully
2. **Predict** what will go wrong
3. Run it and see if you were right
4. Fix the bug

In [None]:
# --- BUGGY training loop ---
# Read this carefully. What's wrong?

torch.manual_seed(42)
model_buggy = nn.Linear(1, 1)
optimizer_buggy = optim.SGD(model_buggy.parameters(), lr=0.01)
criterion = nn.MSELoss()

losses_buggy = []
grad_norms = []  # Track gradient magnitude

for epoch in range(100):
    pred = model_buggy(X)
    loss = criterion(pred, y)
    
    # BUG: Where's optimizer.zero_grad()?
    
    loss.backward()
    
    # Record gradient magnitude before stepping
    grad_norm = model_buggy.weight.grad.norm().item()
    grad_norms.append(grad_norm)
    
    optimizer_buggy.step()
    losses_buggy.append(loss.item())

print(f'Final loss: {losses_buggy[-1]:.4f}')
print(f'Weight: {model_buggy.weight.item():.4f} (target: 3.0)')
print(f'Bias:   {model_buggy.bias.item():.4f} (target: -2.0)')

In [None]:
# Visualize the bug
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss comparison
axes[0].plot(losses_sgd, linewidth=2, color='#51cf66', label='Correct (with zero_grad)')
axes[0].plot(losses_buggy, linewidth=2, color='#ff6b6b', label='Buggy (no zero_grad)')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('MSE Loss')
axes[0].set_title('Loss: Correct vs Buggy')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Gradient norms
axes[1].plot(grad_norms, linewidth=2, color='#ff6b6b')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Gradient Norm')
axes[1].set_title('Gradient Magnitude Over Time (no zero_grad)')
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print('\nWithout zero_grad(), gradients ACCUMULATE across epochs.')
print('Each .backward() ADDS to the existing .grad instead of replacing it,')
print('so every step uses the sum of ALL past gradients, including stale ones.')
print('The parameters overshoot the target and the loss oscillates instead of settling')
print('(with a larger lr, it can diverge outright).')

In [None]:
# --- FIXED training loop ---
# Add the missing line and verify it works

torch.manual_seed(42)
model_fixed = nn.Linear(1, 1)
optimizer_fixed = optim.SGD(model_fixed.parameters(), lr=0.01)
criterion = nn.MSELoss()

losses_fixed = []

for epoch in range(100):
    pred = model_fixed(X)
    loss = criterion(pred, y)
    
    optimizer_fixed.zero_grad()  # <-- THE FIX: zero gradients before backward
    loss.backward()
    optimizer_fixed.step()
    
    losses_fixed.append(loss.item())

print(f'Fixed — Final loss: {losses_fixed[-1]:.4f}')
print(f'Weight: {model_fixed.weight.item():.4f} (target: 3.0)')
print(f'Bias:   {model_fixed.bias.item():.4f} (target: -2.0)')
print(f'\nLesson: always call optimizer.zero_grad() before loss.backward().')

**Why does PyTorch accumulate gradients by default?**

It seems like a footgun, but it's actually useful. When a batch is too large to fit in memory, you can split it into smaller micro-batches, call `.backward()` on each, and let the gradients accumulate, then call `optimizer.step()` once. With a mean-reduced loss, divide each micro-batch loss by the number of micro-batches so the accumulated gradient matches the full-batch gradient. This is called **gradient accumulation**, and it's a real technique used in LLM training.
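A minimal sketch of gradient accumulation, assuming a `mean`-reduced loss and equal-sized micro-batches: each micro-batch loss is divided by the number of micro-batches, so the summed gradients equal the full-batch gradient. The data and model names here are illustrative only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
X_acc = torch.randn(64, 1)
y_acc = 3 * X_acc - 2
criterion = nn.MSELoss()

torch.manual_seed(1)
model_full = nn.Linear(1, 1)
torch.manual_seed(1)
model_accum = nn.Linear(1, 1)  # identical init

# Full batch: one backward over all 64 samples
criterion(model_full(X_acc), y_acc).backward()

# Accumulated: 4 micro-batches of 16, no zero_grad() in between
n_micro = 4
for X_mb, y_mb in zip(X_acc.chunk(n_micro), y_acc.chunk(n_micro)):
    loss = criterion(model_accum(X_mb), y_mb) / n_micro  # scale so the sum equals the full-batch mean
    loss.backward()                                      # adds into .grad

for p_full, p_acc in zip(model_full.parameters(), model_accum.parameters()):
    print(torch.allclose(p_full.grad, p_acc.grad, atol=1e-6))
```

Dropping the `/ n_micro` would give gradients `n_micro` times too large, which behaves like silently multiplying the learning rate.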

For standard training loops though: always `zero_grad()` first.
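A tiny check of the accumulate-by-default behavior on a single scalar, followed by the `zero_grad()` that resets it. Depending on the PyTorch version, `zero_grad()` either fills `.grad` with zeros or (the default since PyTorch 2.0) sets it to `None`, so the sketch accepts either:

```python
import torch

w_demo = torch.tensor(2.0, requires_grad=True)
opt = torch.optim.SGD([w_demo], lr=0.1)

grads_seen = []
for i in range(3):
    loss = 3.0 * w_demo  # d(loss)/dw = 3
    loss.backward()      # ADDS 3 to w_demo.grad each time
    grads_seen.append(w_demo.grad.item())
    print(f'after backward #{i + 1}: w_demo.grad = {w_demo.grad.item()}')

opt.zero_grad()
print(f'after zero_grad(): w_demo.grad = {w_demo.grad}')
```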

---

## Exercise 6 (Stretch): Train a 2-Layer Network on Nonlinear Data

**Type: INDEPENDENT** — Minimal guidance. You have all the pieces.

A single `nn.Linear` can only learn straight lines. For `y = x^2`, you need at least one hidden layer with a nonlinear activation.
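The hidden layer alone is not enough; the activation is what matters. Without it, `Linear(32, 1)` applied after `Linear(1, 32)` computes $W_2(W_1x + b_1) + b_2 = (W_2W_1)x + (W_2b_1 + b_2)$, which is just another straight line. A sketch that builds the equivalent single layer from the two layers' weights and checks that both compute the same function:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
stacked = nn.Sequential(nn.Linear(1, 32), nn.Linear(32, 1))  # no activation in between
first, second = stacked[0], stacked[1]

# Collapse into one Linear(1, 1): W = W2 @ W1, b = W2 @ b1 + b2
collapsed = nn.Linear(1, 1)
with torch.no_grad():
    collapsed.weight.copy_(second.weight @ first.weight)            # (1, 32) @ (32, 1) -> (1, 1)
    collapsed.bias.copy_(second.weight @ first.bias + second.bias)  # (1, 32) @ (32,) -> (1,), plus (1,)

x_check = torch.linspace(-3, 3, 50).unsqueeze(1)
with torch.no_grad():
    same_function = torch.allclose(stacked(x_check), collapsed(x_check), atol=1e-5)
print(f'Two stacked Linear layers == one Linear layer: {same_function}')
```

Inserting `nn.ReLU()` between the two layers breaks this identity, which is what lets the network below fit a curve.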

Your task:
1. Generate data from `y = x^2 + noise`
2. Build a 2-layer network: `Linear(1, 32) -> ReLU -> Linear(32, 1)`
3. Train it with the same loop pattern
4. Plot the learned curve vs the true curve

In [None]:
# --- Generate nonlinear data ---
torch.manual_seed(42)
X_nl = torch.linspace(-3, 3, 200).unsqueeze(1)  # 200 points from -3 to 3
y_nl = X_nl ** 2 + torch.randn(200, 1) * 0.3    # y = x^2 + noise

plt.scatter(X_nl.numpy(), y_nl.numpy(), alpha=0.4, s=15)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Nonlinear Data: y = x² + noise')
plt.grid(alpha=0.3)
plt.show()

In [None]:
# --- Build and train your network ---

# Model: 2-layer network with ReLU
net = nn.Sequential(
    nn.Linear(1, 32),
    nn.ReLU(),
    nn.Linear(32, 1)
)

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

# Training loop
n_epochs = 500
losses_nl = []

for epoch in range(n_epochs):
    pred = net(X_nl)
    loss = criterion(pred, y_nl)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    losses_nl.append(loss.item())
    
    if (epoch + 1) % 100 == 0:
        print(f'  Epoch {epoch+1}: loss = {loss.item():.4f}')

print(f'\nFinal loss: {losses_nl[-1]:.4f}')

In [None]:
# --- Visualize the result ---
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curve
axes[0].plot(losses_nl, linewidth=2, color='#ff6b6b')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('MSE Loss')
axes[0].set_title('Training Loss')
axes[0].grid(alpha=0.3)

# Learned curve vs true curve
x_plot = torch.linspace(-3, 3, 300).unsqueeze(1)
with torch.no_grad():
    y_learned = net(x_plot)

axes[1].scatter(X_nl.numpy(), y_nl.numpy(), alpha=0.3, s=10, label='Data')
axes[1].plot(x_plot.numpy(), x_plot.numpy() ** 2, '--', color='#51cf66',
             linewidth=2, label='True: y = x²')
axes[1].plot(x_plot.numpy(), y_learned.numpy(), color='#ff6b6b',
             linewidth=2, label='Learned curve')
axes[1].set_xlabel('x')
axes[1].set_ylabel('y')
axes[1].set_title('Learned Curve vs True Function')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print('A linear model cannot learn y = x². But Linear + ReLU + Linear can approximate it.')
print('Each ReLU unit adds at most one kink, so the learned curve is piecewise linear;')
print('the nonlinearity is what allows the network to bend at all.')
print('This is why neural networks use nonlinear activations between layers.')

---

## Key Takeaways

1. **`nn.MSELoss` is just `((y_hat - y)**2).mean()`** (with the default `reduction='mean'`) — no hidden complexity
2. **For plain SGD, `optimizer.step()` is just `p -= lr * p.grad`** for every parameter — the optimizer is a convenience, not magic
3. **The training loop is always the same pattern:** forward -> loss -> zero_grad -> backward -> step
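4. **The same `lr` means different things to SGD and Adam** — Adam moves each parameter by roughly `lr` per step regardless of gradient scale, while SGD's step grows with the gradient; Adam is a robust default that often needs less tuning, but it is not automatically faster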
5. **Always call `zero_grad()` before `backward()`** — PyTorch accumulates gradients by default, which is useful for gradient accumulation but a bug if you forget
6. **Nonlinear activations (ReLU) are essential** — without them, stacking linear layers just gives you another linear layer