# 06 — Training Loop Anatomy (from zero, explained)
**Goal:** demystify the training loop by building it piece-by-piece with a tiny neural net on a toy classification problem.

You’ll learn:
- What *tensors*, *parameters*, *gradients*, *loss*, *optimizer*, and *backprop* actually **do** (in plain terms).
- How to structure a **Dataset** and use a **DataLoader** (batches, shuffling, workers).
- The difference between `model.train()` and `model.eval()`.
- Why we call `optimizer.zero_grad()` and then `loss.backward()` and `optimizer.step()`.
- How to **debug shapes** when things break.
- How to add **mixed precision** (AMP), **validation**, **checkpoints**, an **LR scheduler**, and **gradient clipping**.
- How to make runs **reproducible** (seeds).


## 0) Imports, device, and reproducibility
**Analogy (Frontend):** Think of this as wiring up your build pipeline and setting your env (dev vs prod).
- **Device** = where tensors live (`cpu` or `cuda`).  
- **Seed** = makes randomness repeatable so results are comparable while you learn.
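
To see what a seed buys you, here is a minimal sketch. The helper name `seeded_draw` is made up for this illustration. It reseeds Python's and PyTorch's random number generators and draws a few numbers; the same seed should give the same draws.

```python
import random
import torch

def seeded_draw(seed: int):
    """Reseed both RNGs, then draw one Python float and one small tensor."""
    random.seed(seed)
    torch.manual_seed(seed)
    return random.random(), torch.randn(3)

py_a, t_a = seeded_draw(42)
py_b, t_b = seeded_draw(42)   # same seed -> same numbers
py_c, t_c = seeded_draw(7)    # different seed -> different numbers

print("same seed matches:     ", py_a == py_b and torch.equal(t_a, t_b))
print("different seed matches:", torch.equal(t_a, t_c))
```

This only covers CPU draws. GPU kernels can add their own run-to-run variation, which is what the determinism flags in the next cell are about.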


In [None]:
import math, random, os, time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

# Device selection
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

# Reproducibility (good defaults while learning)
SEED = 42
random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(False)  # keep some speed; flip to True if you need strict determinism
torch.backends.cudnn.benchmark = True      # autotunes conv kernels for speed (only matters for conv layers); set False for repeatable GPU runs


## 1) A tiny **Dataset** (toy classification)
We’ll synthesize 2D points from two classes (think: dots on a plane) so there’s nothing to download.

**Key terms:**
- **Dataset**: an object that knows how to return one `(input, target)` pair by index.
- **DataLoader**: batches + shuffles data and optionally uses background workers.
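
The two terms above can be sketched with a deliberately tiny example. `TenPoints` is a hypothetical dataset invented for this illustration (10 samples, 2 features each), and `num_workers=0` is used so it runs anywhere.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class TenPoints(Dataset):
    """Hypothetical dataset: 10 samples with 2 features each, labels alternate 0/1."""
    def __init__(self):
        self.x = torch.arange(20, dtype=torch.float32).reshape(10, 2)
        self.y = torch.arange(10) % 2

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, i):
        return self.x[i], self.y[i]

ds = TenPoints()
x0, y0 = ds[0]  # Dataset: one (input, target) pair by index
print("one sample:", x0, y0)

# DataLoader: groups samples into batches; the last batch may be smaller.
loader = DataLoader(ds, batch_size=4, shuffle=False, num_workers=0)
batch_shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in loader]
print(batch_shapes)
```

With 10 samples and `batch_size=4`, the loader yields two full batches and one leftover batch of 2.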


In [None]:
class ToyBlobs(Dataset):
    """Two Gaussian blobs in 2D, labeled 0/1."""
    def __init__(self, n_per_class=800, gap=2.5, std=1.0):
        # class 0 centered at (-gap, 0), class 1 at (+gap, 0)
        c0 = torch.randn(n_per_class, 2) * std + torch.tensor([-gap, 0.0])
        c1 = torch.randn(n_per_class, 2) * std + torch.tensor([+gap, 0.0])
        x = torch.cat([c0, c1], dim=0)
        y = torch.cat([torch.zeros(n_per_class, dtype=torch.long),
                       torch.ones(n_per_class,  dtype=torch.long)], dim=0)
        # shuffle once here
        idx = torch.randperm(x.size(0))
        self.x, self.y = x[idx], y[idx]

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, i):
        return self.x[i], self.y[i]

train_ds = ToyBlobs(n_per_class=1000)
val_ds   = ToyBlobs(n_per_class=200)

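# DataLoaders: batch_size sets samples per step (and so memory use); pin_memory speeds up CPU-to-GPU copies.
# If worker processes fail to start in your notebook, set num_workers=0.
pin = (device == 'cuda')  # pinning is only useful when a GPU is present
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=2, pin_memory=pin)
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False, num_workers=2, pin_memory=pin)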

len(train_ds), len(val_ds)


> **Shape Debug Tip:** Everywhere, keep an eye on `tensor.shape`.
If something breaks, print shapes through your pipeline until you find the mismatch—just like inspecting props through a component tree.
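
One way to follow the tip above is to walk a model layer by layer and record each output shape. This is a minimal sketch: `trace_shapes` is a helper invented here, and it assumes the model is a plain `nn.Sequential`.

```python
import torch
import torch.nn as nn

def trace_shapes(seq: nn.Sequential, x: torch.Tensor):
    """Run x through each child layer in order and record the output shapes."""
    shapes = [("input", tuple(x.shape))]
    for name, layer in seq.named_children():
        x = layer(x)
        shapes.append((f"{name}:{layer.__class__.__name__}", tuple(x.shape)))
    return shapes

net = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 2))

with torch.no_grad():
    shapes = trace_shapes(net, torch.zeros(5, 2))
for name, shape in shapes:
    print(f"{name:>10} -> {shape}")

# A wrong feature count (3 instead of 2) fails at the first Linear layer.
try:
    net(torch.zeros(5, 3))
    mismatch_caught = False
except RuntimeError as err:
    mismatch_caught = True
    print("shape error:", err)
```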


## 2) Define a tiny model (**nn.Module**)
**Analogy (Frontend):** A model is a component tree with learnable weights. The `forward()` method is your render function.

We’ll use a **2-layer MLP**:
- Input: 2 features (x & y)
- Hidden: 32 units (ReLU)
- Output: 2 logits (class scores)

**Terms:**
- **Parameters**: Tensors registered on the module as weights/biases. PyTorch tracks them, and `model.parameters()` hands them to the optimizer.
- **Logits**: Raw scores before softmax; `CrossEntropyLoss` expects logits.
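
To make "logits" concrete, here is a small sketch with made-up numbers. It turns logits into probabilities with softmax and checks that cross-entropy on raw logits matches the hand-computed negative log-probability of the correct class.

```python
import torch
import torch.nn.functional as F

# Made-up logits for 2 samples and 2 classes, with their true labels.
logits = torch.tensor([[2.0, 0.0],
                       [0.0, 3.0]])
targets = torch.tensor([0, 1])

probs = F.softmax(logits, dim=1)           # each row sums to 1
ce = F.cross_entropy(logits, targets)      # takes raw logits, not probabilities

# By hand: mean of -log(probability assigned to the correct class).
picked = probs[torch.arange(2), targets]
manual = -torch.log(picked).mean()

print("probs:", probs)
print("cross_entropy:", ce.item(), "| manual:", manual.item())
```

Because the loss applies log-softmax internally, passing already-softmaxed values into it would apply softmax twice.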


In [None]:
class MLP(nn.Module):
    def __init__(self, in_dim=2, hidden=32, out_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )
    def forward(self, x):
        # x: [batch, 2]
        return self.net(x)

model = MLP().to(device)
print(model)


## 3) Loss & Optimizer
- **Loss**: how wrong we are. We’ll use `CrossEntropyLoss` for 2-class classification.
- **Optimizer**: updates parameters to reduce loss using gradients (here: `AdamW`).

**The 3-step training micro-loop** (memorize this):
1. `optimizer.zero_grad()` → clear old gradients
2. `loss.backward()` → backprop: compute new gradients
3. `optimizer.step()` → move weights a tiny bit opposite the gradient
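
The three steps above can be watched on a single weight. This sketch uses plain `SGD` rather than `AdamW`, only because the arithmetic is easy to check by hand. The "model" is one scalar `w` with loss `w**2`, so the gradient is `2*w`.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# Without zero_grad(), a second backward() ADDS to the stored gradient.
(w ** 2).backward()
first = w.grad.item()         # 2*w = 4.0
(w ** 2).backward()
accumulated = w.grad.item()   # 4.0 + 4.0 = 8.0

# The proper micro-loop: clear, backprop, step.
opt.zero_grad()
(w ** 2).backward()
opt.step()                    # w <- w - lr * grad = 2.0 - 0.1 * 4.0
after_step = w.item()

print(first, accumulated, after_step)
```

The weight ends up at about 1.6, a small move toward the minimum at 0.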


In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)
scheduler = StepLR(optimizer, step_size=20, gamma=0.7)  # shrink LR every 20 epochs


## 4) Training and Validation loops (with AMP)
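- `model.train()` puts layers such as dropout and batchnorm into training behavior.
- `model.eval()` switches them to inference behavior (dropout off, batchnorm uses its running statistics). We also wrap validation in `torch.no_grad()` so no gradients are tracked, which saves time and memory.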
- **AMP** (automatic mixed precision) uses `autocast` + `GradScaler` to speed up and lower memory usage on GPUs.

We’ll collect: `loss`, `accuracy`. Accuracy is easy for classification: compare `argmax` of logits with labels.
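
The MLP in this notebook has no dropout or batchnorm, so the mode switch is invisible here. A minimal sketch with a standalone `nn.Dropout` layer (used only for illustration) shows what `train()` and `eval()` change, and what `torch.no_grad()` turns off.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)   # each element is either dropped (0) or scaled by 1/(1-p) = 2

drop.eval()
y_eval = drop(x)    # dropout is a no-op in eval mode

lin = nn.Linear(2, 2)
with torch.no_grad():
    out = lin(torch.zeros(1, 2))
tracks_grad = out.requires_grad   # no autograd graph is built under no_grad

print("train-mode values:", y_train.unique().tolist())
print("eval output unchanged:", torch.equal(y_eval, x))
print("requires_grad under no_grad:", tracks_grad)
```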


In [None]:
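# Note: recent PyTorch releases prefer the device-generic `torch.amp` API and may warn about this import.
from torch.cuda.amp import autocast, GradScaler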
scaler = GradScaler(enabled=(device=='cuda'))

def accuracy_from_logits(logits, y):
    pred = logits.argmax(dim=1)
    return (pred == y).float().mean().item()

def train_one_epoch(model, loader):
    model.train()
    total_loss, total_acc, n = 0.0, 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)

        with autocast(enabled=(device=='cuda')):
            logits = model(xb)
            loss = loss_fn(logits, yb)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # optional: gradient clipping to stabilize training.
        # unscale_ first so max_norm is compared against the true (unscaled) gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)
        scaler.update()

        bs = xb.size(0)
        total_loss += loss.item() * bs
        total_acc  += accuracy_from_logits(logits, yb) * bs
        n += bs
    return total_loss/n, total_acc/n

@torch.no_grad()
def validate(model, loader):
    model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        logits = model(xb)
        loss = loss_fn(logits, yb)
        bs = xb.size(0)
        total_loss += loss.item() * bs
        total_acc  += accuracy_from_logits(logits, yb) * bs
        n += bs
    return total_loss/n, total_acc/n


## 5) Run it
Watch loss ↓ and accuracy ↑ over epochs. If loss climbs or accuracy keeps dropping, suspect a bug or a learning rate that is too large.


In [None]:
EPOCHS = 40
best_val_acc, best_state = 0.0, None

for epoch in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader)
    va_loss, va_acc = validate(model, val_loader)
    scheduler.step()

    if va_acc > best_val_acc:
        best_val_acc = va_acc
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    print(f"epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | val loss {va_loss:.4f} acc {va_acc:.3f} | lr {scheduler.get_last_lr()[0]:.5f}")

# Save best checkpoint
if best_state is not None:
    torch.save(best_state, 'toy_mlp_best.pt')
    print('Saved best checkpoint -> toy_mlp_best.pt')


## 6) Load & use the checkpoint (inference)
**Inference** = forward pass only (no gradients). Here we’ll reload and score the validation set again.
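
Before running it on the real checkpoint, the save/load round trip can be sketched on its own. This example assumes a throwaway `nn.Linear` model and an in-memory `io.BytesIO` buffer in place of a file on disk. After loading the saved `state_dict`, a freshly built model should produce identical outputs.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)
src = nn.Linear(2, 2)

# Save only the weights (the state_dict), not the whole Python object.
buf = io.BytesIO()
torch.save(src.state_dict(), buf)
buf.seek(0)

# A new model starts with different random weights...
dst = nn.Linear(2, 2)
differs_before = not torch.equal(src.weight, dst.weight)

# ...until the saved weights are loaded into it.
dst.load_state_dict(torch.load(buf, map_location="cpu"))
dst.eval()

x = torch.randn(4, 2)
with torch.no_grad():
    same_outputs = torch.equal(src(x), dst(x))

print("differed before load:", differs_before)
print("identical after load:", same_outputs)
```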


In [None]:
@torch.no_grad()
def load_and_eval(path='toy_mlp_best.pt'):
    m = MLP().to(device)
    m.load_state_dict(torch.load(path, map_location=device))
    loss, acc = validate(m, val_loader)
    print(f'Reloaded model -> val loss {loss:.4f} acc {acc:.3f}')
    return m

_ = load_and_eval()


## 7) Common questions (quick answers)
- **Why zero gradients?** Gradients accumulate by default. If you don’t clear them, each `backward()` adds to the previous one.
- **Why `train()` vs `eval()`?** Some layers behave differently during training (dropout, batchnorm). Use the right mode.
- **Why AMP?** Faster training + less VRAM on GPU with near-identical accuracy for most models.
- **Why a scheduler?** Big learning rates help early, smaller ones help refine later. Schedulers automate that.
- **Why clip gradients?** Gradients occasionally explode in some models; capping their norm (5.0 here) keeps one bad batch from causing a huge update.
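
The scheduler and clipping answers above can be checked with numbers. This is a minimal sketch, assuming a throwaway `nn.Linear` model (not the MLP from this notebook) and hand-set gradients of 10.0 purely for illustration.

```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(2, 2)
opt = AdamW(model.parameters(), lr=1e-2)
sched = StepLR(opt, step_size=20, gamma=0.7)

# Scheduler: record the learning rate after each of 40 "epochs".
lrs = []
for epoch in range(40):
    opt.step()    # a real loop would run backward() first; no grads here, so nothing moves
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print("lr after epochs 1, 20, 40:", lrs[0], lrs[19], lrs[39])

# Clipping: fake large gradients, then cap their combined norm at 5.0.
for p in model.parameters():
    p.grad = torch.full_like(p, 10.0)
norm_before = float(torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0))
norm_after = float(torch.sqrt(sum((p.grad ** 2).sum() for p in model.parameters())))
print("grad norm before:", norm_before, "| after:", norm_after)
```

`clip_grad_norm_` returns the norm measured before clipping, which is handy to log when you want to see how often clipping actually kicks in.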

**Memorize this mini template:**
```python
for xb, yb in train_loader:
    with autocast():
        logits = model(xb)
        loss = loss_fn(logits, yb)
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer); torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    scaler.step(optimizer); scaler.update()
```
