# 08 — Debugging Shapes & Common Errors
**Goal:** make shape/typing errors boring. We’ll trigger the most common mistakes on purpose and fix them.

**Covers:**
- Reading PyTorch error messages
- Batch dimension vs feature dimension
- Wrong target dtype for `CrossEntropyLoss`
- Channel-first images (C×H×W) vs channel-last (H×W×C)
- Device mismatch (CPU vs CUDA)
- Gradients accumulating & forgetting `model.eval()`


## 0) Setup

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)


## 1) Batch vs feature dims
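**Rule of thumb:** dense layers expect shape `[batch, features]`. `nn.Linear(in_features, out_features)` multiplies over the **last** dimension, so that dimension must equal `in_features`. Below, `bad_x` fails because its last dimension is 32, not 16.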

In [None]:
x = torch.randn(32, 16)
bad_x = torch.randn(16, 32)
linear = nn.Linear(16, 4)
print('good:', linear(x).shape)
try:
    linear(bad_x)
except Exception as e:
    print('bad shape error ->', type(e).__name__, e)
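
When a model has several layers, it helps to see which layer the shape goes wrong in. Here is a minimal sketch that uses `register_forward_hook` to record each layer's input and output shapes. The toy model `net` and the helper `log_shapes` are hypothetical names invented for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only to illustrate shape logging.
net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

shape_log = []  # entries: (layer name, input shape, output shape)

def log_shapes(name):
    # Build a forward hook that records shapes for the named layer.
    def hook(module, inputs, output):
        shape_log.append((name, tuple(inputs[0].shape), tuple(output.shape)))
    return hook

handles = [layer.register_forward_hook(log_shapes(name))
           for name, layer in net.named_children()]

net(torch.randn(32, 16))  # [batch, features]
for name, in_shape, out_shape in shape_log:
    print(f'layer {name}: {in_shape} -> {out_shape}')

for h in handles:
    h.remove()  # detach the hooks once debugging is done
```

If a forward pass raises partway through, the last entry in `shape_log` shows the final layer that ran successfully, which narrows down where the mismatch is.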


## 2) CrossEntropyLoss target dtype
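`CrossEntropyLoss` expects **class indices**: an integer `LongTensor` of shape `[batch]`. Indices stored as floats raise an error, and so does an integer one-hot target. Watch out for **float** one-hot targets: in recent PyTorch releases (1.10 and later), a float target with the same shape as the logits is interpreted as class probabilities, so it does not raise and silently takes the soft-label path.

In [None]:
logits = torch.randn(8, 10)
y_ok = torch.randint(0, 10, (8,))   # class indices, dtype int64
y_bad = y_ok.float()                # same indices, wrong dtype
y_onehot = F.one_hot(y_ok, 10)      # int64 one-hot, shape [8, 10]
ce = nn.CrossEntropyLoss()
print('ok:', ce(logits, y_ok).item())
for name, target in [('float indices', y_bad), ('int one-hot', y_onehot)]:
    try:
        ce(logits, target)
    except Exception as e:
        print(f'wrong target ({name}) ->', type(e).__name__, e)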
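
If labels arrive one-hot encoded, one way to recover class indices is `argmax` over the class dimension. This is a minimal sketch with illustrative variable names, and it assumes each row contains exactly one 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)
y_idx = torch.randint(0, 10, (8,))         # class indices, int64
y_onehot = F.one_hot(y_idx, 10).float()    # shape [8, 10]

# Recover indices from one-hot rows: the position of the 1 in each row.
y_from_onehot = y_onehot.argmax(dim=1)     # shape [8], dtype int64

ce = nn.CrossEntropyLoss()
loss = ce(logits, y_from_onehot)
print(y_from_onehot.dtype, tuple(y_from_onehot.shape), loss.item())
```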


## 3) Images are channel-first (C×H×W) in PyTorch
**Transform fix:** `permute(2, 0, 1)` reorders H×W×C into C×H×W; `unsqueeze(0)` then adds the batch dimension, giving N×C×H×W.

In [None]:
img_hwc = torch.randn(64, 64, 3)
try:
    nn.Conv2d(3, 8, 3)(img_hwc)
except Exception as e:
    print('conv expects NCHW ->', type(e).__name__, e)
img_chw = img_hwc.permute(2,0,1)
img_nchw = img_chw.unsqueeze(0)
print('fixed:', nn.Conv2d(3, 8, 3)(img_nchw).shape)
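
The same idea extends to a whole batch stored channel-last. The sketch below assumes a hypothetical batch of `uint8` RGB images (the layout you typically get from NumPy or PIL). The division by 255 is an illustrative scaling choice, not a requirement of `Conv2d`.

```python
import torch
import torch.nn as nn

# Hypothetical batch of 5 RGB images in channel-last layout, uint8 pixels.
batch_nhwc = torch.randint(0, 256, (5, 64, 64, 3), dtype=torch.uint8)

# Reorder N×H×W×C -> N×C×H×W, then convert to float in [0, 1] for the conv layer.
batch_nchw = batch_nhwc.permute(0, 3, 1, 2).float() / 255.0

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
out = conv(batch_nchw)
print(tuple(batch_nchw.shape), '->', tuple(out.shape))
```

Note that `Conv2d` needs a floating-point input, so the dtype conversion matters as much as the permutation here.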


## 4) Device mismatch (CPU vs GPU)
Move **both** model and data to the same device.

In [None]:
m = nn.Linear(16, 4).to(device)
x_cpu = torch.randn(2,16)
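try:
    m(x_cpu)
    # Reached only when no error is raised, e.g. on a CPU-only machine.
    print('no mismatch: model and input are both on', device)
except Exception as e:
    print('device mismatch ->', type(e).__name__, e)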
x_dev = x_cpu.to(device)
print('fixed:', m(x_dev).shape)
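
Rather than tracking a global `device` variable everywhere, one option is to ask the model where it lives and move each batch there. This is a sketch: `model_device` is a hypothetical helper, and it assumes all parameters sit on a single device.

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(16, 4).to(device)

def model_device(module):
    # Hypothetical helper: device of the first parameter.
    # Assumes every parameter lives on the same device.
    return next(module.parameters()).device

x = torch.randn(2, 16)            # tensors are created on the CPU by default
x = x.to(model_device(model))     # no-op if x is already on that device
out = model(x)
print('model:', model_device(model), '| input:', x.device, '| output:', out.device)
```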


## 5) Accumulating grads & train/eval modes
Gradients **accumulate** by default; clear them each step. Use `model.eval()` during validation.

In [None]:
m = nn.Linear(8,1).to(device)
opt = torch.optim.SGD(m.parameters(), lr=0.1)
x = torch.randn(4,8, device=device)
y = torch.randn(4,1, device=device)
loss_fn = nn.MSELoss()
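for step in range(3):
    # opt.zero_grad()  # (commented out to show accumulation)
    pred = m(x)
    loss = loss_fn(pred, y)
    loss.backward()  # adds to whatever is already stored in .grad
    opt.step()
    gnorm = m.weight.grad.norm().item()
    print(f'step {step}: grad norm {gnorm:.4f} (summed over {step + 1} backward passes)')
print('Fix: call opt.zero_grad() at the start of each step, before loss.backward().')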
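
For comparison, here is a sketch of the corrected pattern: clear gradients at the start of each step, then switch to `model.eval()` and `torch.no_grad()` for validation. The model is a hypothetical one with a `Dropout` layer so that train and eval modes actually differ, and the training data is reused as "validation" data purely for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical model; Dropout makes train/eval behaviour differ.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x, y = torch.randn(4, 8), torch.randn(4, 1)

model.train()                     # dropout active during training
for step in range(3):
    opt.zero_grad()               # clear gradients left over from the previous step
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()

model.eval()                      # dropout disabled for validation
with torch.no_grad():             # do not build the autograd graph
    val_a = model(x)
    val_b = model(x)
print('eval outputs repeatable:', torch.equal(val_a, val_b))
print('requires_grad:', val_a.requires_grad)
```

`model.eval()` and `torch.no_grad()` do different jobs: the first changes layer behaviour (dropout, batch norm), while the second only turns off gradient tracking. Validation code usually wants both.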
