In [1]:
# !nvidia-smi

# MNIST — MLP warm-up

We’ll load MNIST, normalize it, define a small MLP, run a quick
sanity-check forward pass to verify shapes, and then train it.


In [2]:
# OPTIONAL: only run this if your torch/torchvision install is broken.
# For GPU on Kaggle (CUDA 12.1 wheels):
# !pip install --upgrade --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CPU-only:
# !pip install --upgrade --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu


## Imports & device

We’ll autodetect CUDA and fall back to CPU. The code works either way.


In [3]:
import torch
import torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("device:", device)


torch: 2.6.0+cu124
torchvision: 0.21.0+cu124
device: cpu


## Dataset & transforms

- `ToTensor()` → scales pixels to [0,1] with shape [1, 28, 28].
- `Normalize((0.1307,), (0.3081,))` → center/scale using MNIST stats.
  (Note the commas: single-element tuples.)


### Why do we normalize MNIST with `(0.1307,), (0.3081,)`?

After `ToTensor()`, MNIST images are scaled to `[0,1]`, but their distribution isn’t centered and doesn’t have unit variance:
- Mean pixel value is about **0.1307** (most pixels are dark background).
- Standard deviation is about **0.3081**.

**Why normalize?**
- Centering (subtracting the mean) makes neuron inputs hover around zero, which helps gradients flow and speeds up learning.
- Scaling (dividing by the std) puts features on a comparable scale, making optimization more stable and less sensitive to learning rates.

**Why those exact numbers?**
- They are the empirical mean and std of the MNIST training set computed over all pixels.
- Using dataset-specific stats is better than generic choices (like 0.5/0.5) because it matches the true data distribution.
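To see where such numbers come from, here is a small sketch that computes the mean and std over every pixel of a stack of uint8 images, after the same [0, 1] scaling that `ToTensor()` applies. The helper name `pixel_mean_std` is hypothetical, and the demo runs on a synthetic image stack so it needs no download.

```python
import torch

def pixel_mean_std(images: torch.Tensor) -> tuple[float, float]:
    """Hypothetical helper: mean and std over all pixels of a uint8 stack [N, H, W]."""
    x = images.float() / 255.0  # same [0, 1] scaling that ToTensor() applies
    return x.mean().item(), x.std().item()

# Synthetic stand-in for MNIST: 10 black 28x28 images, each with an 8x8 white square.
fake = torch.zeros(10, 28, 28, dtype=torch.uint8)
fake[:, 10:18, 10:18] = 255
mean, std = pixel_mean_std(fake)
print(f"mean={mean:.4f}, std={std:.4f}")  # mostly-dark images give a small mean, like MNIST
```

On the real data, `train_mnist.data` holds the raw uint8 training images before any transform, so `pixel_mean_std(train_mnist.data)` should come out close to the 0.1307 / 0.3081 used here.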

**Why the tuples, and why does the comma matter?**
- `Normalize` expects a *sequence* (list or tuple) with one value per channel.
- MNIST has 1 channel → we need one mean and one std → a 1-element tuple.
- In Python:
  - `(0.3081,)` → a tuple containing one float ✅
  - `(0.3081)` → just a float ❌ (the parentheses only group the expression)
- Without the comma you pass a plain float instead of the per-channel sequence `Normalize` documents. Current torchvision versions usually let a scalar broadcast over a single-channel image, so the slip can go unnoticed on MNIST, but it hides the one-value-per-channel intent and does not carry over to multi-channel data, where each channel needs its own value.
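A quick check of the comma rule in plain Python. The 3-channel tuple is only an illustrative value, showing that the sequence length should match the channel count.

```python
mnist_std = (0.3081,)       # 1-element tuple: one value for the single grayscale channel
not_a_tuple = (0.3081)      # parentheses alone only group; this is a plain float
rgb_std = (0.5, 0.5, 0.5)   # a 3-channel dataset would need three values (illustrative)

print(type(mnist_std).__name__, len(mnist_std))  # tuple 1
print(type(not_a_tuple).__name__)                # float
print(len(rgb_std))                              # 3
```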

**What if we skip normalization?**
- The model may still learn (MNIST is simple), but:
  - Training is slower.
  - Optimization is less stable.
  - Accuracy may plateau lower.
- For harder datasets (like CIFAR or ImageNet) and deeper networks, skipping normalization can make training considerably harder to get working.

**TL;DR**
Normalization with `(0.1307,), (0.3081,)` standardizes MNIST inputs to roughly zero mean and unit variance.  
The trailing comma makes those values 1-element tuples rather than plain floats, matching the one-value-per-channel sequence that `Normalize` expects.
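As a worked example, the sketch below applies the per-channel arithmetic by hand: `(x - mean) / std`, with the tuples reshaped to `[C, 1, 1]` so they broadcast over height and width. It is written to mirror what `Normalize` computes, not to replace it. A pure black pixel (0.0) and a pure white pixel (1.0) end up near -0.42 and 2.82.

```python
import torch

mean, std = (0.1307,), (0.3081,)

# A tiny [C=1, H=1, W=2] "image": one black and one white pixel, already in [0, 1].
x = torch.tensor([[[0.0, 1.0]]])

m = torch.tensor(mean).view(-1, 1, 1)  # shape [C, 1, 1]: one value per channel
s = torch.tensor(std).view(-1, 1, 1)
x_norm = (x - m) / s

print(x_norm.flatten().tolist())  # roughly [-0.4242, 2.8215]
```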


In [4]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

train_mnist = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_mnist = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# quick peek
x0, y0 = train_mnist[0]
print("one sample:", x0.shape, y0)  # torch.Size([1, 28, 28]) label_int


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.42MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.21MB/s]

one sample: torch.Size([1, 28, 28]) 5





## Model

A simple fully-connected classifier:
- Flatten 28×28 → 784
- Two hidden layers of 300 units each (784 → 300 → 300), each followed by LeakyReLU
- Output: 10 logits (no Softmax; CrossEntropyLoss expects logits)


In [5]:
import torch

# In PyTorch, p.numel() returns the number of elements (scalars) in the tensor p.

p = torch.randn(3, 4)   # shape [3,4]
print(p.numel())        # 12


12


In [6]:
model = torch.nn.Sequential(
    torch.nn.Linear(28*28, 300),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(300, 300),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(300, 10)  # logits
).to(device)

sum_params = sum(p.numel() for p in model.parameters())
print("model params:", sum_params)


model params: 328810
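The printed count can be checked by hand: each `Linear(n_in, n_out)` holds an `n_in × n_out` weight matrix plus `n_out` biases, and the LeakyReLU layers have no parameters.

```python
# Layer widths of the MLP above: 784 inputs -> 300 -> 300 -> 10 logits.
layer_sizes = [28 * 28, 300, 300, 10]

# weights (n_in * n_out) + biases (n_out) for each consecutive pair of widths
per_layer = [n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
print(per_layer)       # [235500, 90300, 3010]
print(sum(per_layer))  # 328810
```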


## Sanity check (single example)

Flatten to 784 features, forward once, confirm output shape [1, 10].


In [7]:
digit, cls = train_mnist[0]
digit = digit.to(device).view(1, 28*28)  # add batch dim = 1
with torch.no_grad():
    out = model(digit)
print("single forward shape:", out.shape)  # torch.Size([1, 10])


single forward shape: torch.Size([1, 10])


## Sanity check with dataset loop (first item only)

Iterate the dataset, move to device, flatten, run model, print shape, break.


In [8]:
for digit, cls in train_mnist:
    digit = digit.to(device).unsqueeze(0)      # [1, 28, 28] -> [1, 1, 28, 28]: add a batch dim
    digit = digit.view(digit.shape[0], 28*28)  # [B, 1, 28, 28] -> [B, 784]
    with torch.no_grad():
        print(model(digit).shape)  # expected: torch.Size([1, 10])
    break


torch.Size([1, 10])


### Why do we use `digit.view(digit.shape[0], 28*28)`?

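A batch of MNIST images from the `DataLoader` is a tensor of shape `[B, 1, 28, 28]` (a single item taken straight from the dataset is `[1, 28, 28]`, with no batch dimension):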
- `B` = batch size  
- `1` = number of channels (grayscale)  
- `28 × 28` = image height and width  

Our model starts with a `Linear(28*28, 300)` layer, which expects a
**flat vector of 784 features per image**, not a 2D grid.

The call

``` python
digit = digit.view(digit.shape[0], 28*28)
```

does two things:
1. Keeps the batch dimension (`digit.shape[0]`).
2. Flattens each `[1,28,28]` image into a single vector `[784]`.

So:
- Before: `[B, 1, 28, 28]`  
- After:  `[B, 784]`  

This reshaping step bridges the gap between image-shaped data and the
fully connected (dense) layers of our MLP.
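A shape-only sketch with random tensors (no MNIST needed) confirms the reshape and shows two equivalent alternatives, `torch.flatten(..., start_dim=1)` and `torch.nn.Flatten()`. It also shows why the batch dimension should exist first: on a single dataset item of shape `[1, 28, 28]`, `shape[0]` is the channel count, which only happens to give `[1, 784]` because MNIST has one channel.

```python
import torch

batch = torch.randn(64, 1, 28, 28)  # a fake DataLoader batch [B, 1, 28, 28]

flat_view = batch.view(batch.shape[0], 28 * 28)  # [B, 784]
flat_fn = torch.flatten(batch, start_dim=1)      # keep dim 0, merge the rest
flat_layer = torch.nn.Flatten()(batch)           # module form (start_dim=1 by default)

print(flat_view.shape)  # torch.Size([64, 784])
print(torch.equal(flat_view, flat_fn) and torch.equal(flat_view, flat_layer))  # True

# A single dataset item has no batch dim: add one before flattening.
item = torch.randn(1, 28, 28)
print(item.unsqueeze(0).view(1, 28 * 28).shape)  # torch.Size([1, 784])
```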


## Dataloaders

We’ll iterate in mini-batches for efficient training.


In [9]:
from torch.utils.data import DataLoader 

batch_size = 62 
train_dl = DataLoader(
    train_mnist, 
    batch_size=batch_size, 
    shuffle=True,
    num_workers=2, 
    pin_memory=(device=="cuda")
)

test_dl = DataLoader(
    test_mnist, 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2,
    pin_memory=(device=="cuda")
)

len(train_dl), len(test_dl)

(968, 162)
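The `(968, 162)` output follows from the dataset sizes (60,000 training and 10,000 test images) and `batch_size = 62`: the loader rounds up, so the last batch holds the leftover samples. A small sketch of that arithmetic; the helper name `batch_stats` is hypothetical, while `drop_last` is a real `DataLoader` option that discards the incomplete final batch instead.

```python
import math

def batch_stats(n_samples: int, batch_size: int, drop_last: bool = False) -> tuple[int, int]:
    """Hypothetical helper: (number of batches, size of the last batch) for a DataLoader."""
    if drop_last:
        return n_samples // batch_size, batch_size
    n_batches = math.ceil(n_samples / batch_size)
    last = n_samples - (n_batches - 1) * batch_size  # leftover samples in the final batch
    return n_batches, last

print(batch_stats(60_000, 62))                  # (968, 46)
print(batch_stats(10_000, 62))                  # (162, 18)
print(batch_stats(60_000, 62, drop_last=True))  # (967, 62)
```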

### Understanding `DataLoader` arguments

When we wrap our MNIST datasets in `DataLoader`, we specify a few
important options:

- **`batch_size=62`**  
  - How many samples to group together in one batch.  
  - Instead of returning a single `[1, 28, 28]` image, the loader
    returns `[62, 1, 28, 28]` tensors (the final batch can be smaller).  
  - Larger batch sizes improve GPU utilization and give smoother
    gradient estimates, but also use more memory.  
  - On CPU, smaller batches can be more practical to keep things fast
    and memory-efficient.

- **`shuffle=True` (for training)**  
  - Each epoch, the training data is shuffled.  
  - Prevents the model from seeing the same batches in the same order
    every epoch, so the updates do not depend on a fixed data order.  
  - Helps generalization because each mini-batch looks different each
    epoch.  
  - For evaluation (`test_dl`), we use `shuffle=False` so results are
    deterministic and ordered.

- **`num_workers=2`**  
  - Number of subprocesses used to load data in parallel.  
  - `0` means load in the main process (slower).  
  - On CPU or GPU, having a few workers (like 2–4) allows data to be
    prefetched while the model is training on the previous batch,
    keeping the pipeline efficient.  
  - On Kaggle, small values (like 2) are often safe.

- **`pin_memory=(device=="cuda")`**  
  - *Pinned (page-locked) memory* speeds up data transfer from CPU RAM
    to GPU memory.  
  - If `device=="cuda"`, we set `pin_memory=True` so each batch can be
    moved to the GPU more efficiently with `.to("cuda")` (especially
    with `non_blocking=True`).  
  - If `device=="cpu"`, the expression evaluates to `False`; pinning
    would bring no benefit there.

---

**Summary:**
- Training loader: `batch_size=62`, `shuffle=True`  
- Test loader: `batch_size=62`, `shuffle=False`  
- Use a few `num_workers` to overlap data loading with computation.  
- Enable `pin_memory` only when training on CUDA for faster CPU→GPU
  transfers.


### Loss & Optimizer

- **Loss:** `CrossEntropyLoss` compares the model’s **logits** to the
  ground-truth class indices. It internally applies `log_softmax`, so
  we **do not** put a `Softmax` layer in the model.
- **Optimizer:** `Adam` with a standard learning rate (1e-3) works well
  for this small MLP. It adapts per-parameter step sizes and usually
  converges faster than plain SGD.
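- **Seed:** `torch.manual_seed(42)` fixes PyTorch's global random number
  generator from this point on, which makes the shuffle order of
  `train_dl` reproducible. Because the model was created in an earlier
  cell, it does **not** affect the weight initialization; seed before
  building the model if you also want reproducible weights. (Adam itself
  uses no randomness.)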
- This works the same on **CPU or CUDA**; no device-specific changes are
  needed for defining the loss/optimizer.
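A small numerical check of the claim that `CrossEntropyLoss` works on raw logits: on made-up logits it matches `log_softmax` followed by `nll_loss`, while feeding it softmax probabilities instead (what an extra `Softmax` layer would do) gives a different, wrong value.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # 2 samples, 3 classes (made-up values)
targets = torch.tensor([0, 2])            # class indices, shape [B]

loss_fn = torch.nn.CrossEntropyLoss()
ce = loss_fn(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
double_softmax = loss_fn(logits.softmax(dim=1), targets)  # the mistake to avoid

print(torch.allclose(ce, manual))          # True
print(torch.allclose(ce, double_softmax))  # False
```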


In [10]:
torch.manual_seed(42)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(loss_fn)
print(optimizer)

CrossEntropyLoss()
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)
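Since the seed in this cell is set after the model already exists, it cannot reproduce the initial weights. A minimal sketch of the ordering that does: seed, then build. `build_mlp` is a hypothetical helper that recreates the same architecture.

```python
import torch

def build_mlp() -> torch.nn.Module:
    """Hypothetical helper: the same 784 -> 300 -> 300 -> 10 MLP as above."""
    return torch.nn.Sequential(
        torch.nn.Linear(28 * 28, 300), torch.nn.LeakyReLU(),
        torch.nn.Linear(300, 300), torch.nn.LeakyReLU(),
        torch.nn.Linear(300, 10),
    )

def same_weights(m1, m2):
    return all(torch.equal(p, q) for p, q in zip(m1.parameters(), m2.parameters()))

torch.manual_seed(42)
model_a = build_mlp()
torch.manual_seed(42)  # same seed again, *before* building
model_b = build_mlp()
model_c = build_mlp()  # built without re-seeding

print(same_weights(model_a, model_b))  # True
print(same_weights(model_a, model_c))  # False
```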


### Training Loop 
We train for a few epochs using mini-batches:
- Switch the model to `train()` mode (this changes how layers such as dropout and batch norm behave; our MLP has neither, but calling it is good practice)
- Move each batch to the selected `device` (CPU or CUDA)
- **Flatten** `[B, 1, 28, 28] → [B, 784]` before the first linear layer.
- Forward → compute **CrossEntropyLoss** on logits → backprop → optimizer step.
- Track running loss and accuracy for feedback.

1. `train_dl`
   * the `DataLoader` we created from `train_mnist`
   * each iteration returns one mini-batch of samples
   * with the default collate function, each batch is a pair `(batch_of_images, batch_of_labels)` that unpacks into two variables
2. `tqdm(train_dl, desc=...)`
   * `tqdm` wraps the `DataLoader` so we get a live progress bar while looping
   * the `desc` string is shown at the start of the bar
3. `for x, y in bar:`
   * on each iteration:
     - `x` = a batch of images, shape `[B, 1, 28, 28]`
     - `y` = the corresponding batch of labels, shape `[B]` (each entry is an integer 0-9 for the digit)

### Why do we flatten `x` but not `y`?

- **`x` (the images)**  
  - Each batch arrives from the `DataLoader` with shape `[B, 1, 28, 28]`  
    (`B` = batch size, `1` = grayscale channel, `28×28` = pixels).  
  - Our model begins with a linear layer `Linear(28*28, 300)`, which
    expects a **2D tensor** of shape `[B, 784]`.  
  - Therefore, we flatten each image with  
    `x = x.view(B, 28*28)` so every image becomes a 784-dimensional
    vector, while preserving the batch dimension.

- **`y` (the labels)**  
  - Labels are already a 1D tensor of integers with shape `[B]`
    (e.g. `[5, 0, 4, 1, 9, ...]`).  
  - `CrossEntropyLoss` expects predictions as `[B, num_classes]`
    (logits) and targets as `[B]` (integer class indices).  
  - Since `y` is already in the right format, we **don’t reshape it**.

**In short:**  
- Flatten `x` because the network needs flat input vectors.  
- Leave `y` as is because it already contains class indices in the
expected shape for the loss function.
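A shape check with a random stand-in batch (no MNIST required): flattening `x` produces the `[B, 784]` input the first layer needs, while `y` goes to the loss unchanged as a `[B]` tensor of `int64` class indices. The single `Linear` layer is only a placeholder for the real model.

```python
import torch

B = 8
x = torch.randn(B, 1, 28, 28)   # like a DataLoader image batch
y = torch.randint(0, 10, (B,))  # like a DataLoader label batch: int64, shape [B]

placeholder_model = torch.nn.Linear(28 * 28, 10)  # stand-in for the MLP
logits = placeholder_model(x.view(B, 28 * 28))    # flatten x: [B, 1, 28, 28] -> [B, 784]
loss = torch.nn.CrossEntropyLoss()(logits, y)     # y is used as-is

print(logits.shape, y.shape, y.dtype)  # torch.Size([8, 10]) torch.Size([8]) torch.int64
print(loss.shape)                      # torch.Size([]), i.e. a scalar (mean over the batch)
```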


### Tracking metrics during training

Inside the training loop we compute and accumulate statistics so that we
can monitor progress:

- **`running_loss += loss.item() * x.size(0)`**  
  - `loss.item()` gives the mean loss over the current batch as a Python float.  
  - We multiply by `x.size(0)` (the batch size) so that when we sum over
    all batches, each sample contributes equally.  
  - At the end of the epoch, dividing by `total` gives the **average
    loss per sample**.

- **`preds = logits.argmax(dim=1)`**  
  - From the model’s output `[B, 10]` (logits), we take the index of the
    maximum value along `dim=1` for each sample.  
  - This gives the predicted digit class for each image.

- **`correct += (preds == y).sum().item()`**  
  - Compares predictions to ground-truth labels `y`.  
  - Counts how many were correct in this batch, adds to the running
    total.

- **`total += y.numel()`**  
  - Increments by the number of samples in this batch.  
  - Used to compute overall averages.

- **`bar.set_postfix(...)`**  
  - Updates the live `tqdm` progress bar with the current batch loss and
    accuracy (`correct / total`).

After the loop:

- **`epoch_loss = running_loss / total`**  
  - Average loss per sample across the entire epoch.

- **`epoch_acc = correct / total`**  
  - Fraction of correctly classified samples across the entire epoch.

- **`print(...)`**  
  - Logs a summary line showing the final loss and accuracy for this
    epoch.


### Why do we use `argmax(dim=1)`?

The model outputs logits of shape `[B, 10]`:
- `B` = batch size  
- `10` = number of classes (digits 0–9 for MNIST)  

Each **row** corresponds to one sample, and each **column** corresponds
to the score for a particular class.

Example for a batch of size 4:

    logits =
    [[-1.2,  0.3,  2.1, ..., -0.7],   # sample 1
     [ 0.5, -0.2,  0.1, ..., -1.3],   # sample 2
     [ 2.9,  1.2, -0.4, ...,  0.0],   # sample 3
     [ 0.1,  0.4,  0.2, ..., -0.6]]   # sample 4

- If we used `argmax(dim=0)`, we would be taking the maximum **down the
  batch** for each class column. That compares different samples to each
  other, which is not meaningful for classification.

- If we use `argmax(dim=1)`, we take the maximum **across the 10 class
  scores** for each row (sample). That selects the predicted class for
  each sample independently.

So:

    preds = logits.argmax(dim=1)

returns a vector of shape `[B]` with one integer (0–9) per sample, e.g.:

    tensor([2, 0, 0, 1])   # predicted classes for 4 samples

**Summary:**  
We use `dim=1` because that dimension represents the class scores.
Choosing the max along `dim=1` gives the most likely class **for each
sample in the batch**.
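The same idea on a concrete tensor: a 4-class toy version of the example above (shortened from 10 columns so it fits), comparing the two choices of `dim`.

```python
import torch

# 4 samples x 4 classes: a shortened toy version of the [B, 10] example above.
logits = torch.tensor([
    [-1.2,  0.3,  2.1, -0.7],
    [ 0.5, -0.2,  0.1, -1.3],
    [ 2.9,  1.2, -0.4,  0.0],
    [ 0.1,  0.4,  0.2, -0.6],
])

preds = logits.argmax(dim=1)  # best class per row (sample): shape [B]
wrong = logits.argmax(dim=0)  # best sample per column (class): not a prediction

print(preds)  # tensor([2, 0, 0, 1])
print(wrong)  # tensor([2, 2, 0, 2])
```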


### Why do we use `.item()`?

In the training loop we update the counter of correct predictions:

    correct += (preds == y).sum().item()

Step by step:

1. **`preds == y`**  
   Compares predicted labels with true labels, giving a Boolean tensor of
   shape `[B]` (e.g. `[True, False, True, ...]`).

2. **`.sum()`**  
   Sums over the batch, counting how many predictions were correct.  
   The result is a **0-dimensional PyTorch tensor**, e.g. `tensor(48)`.

3. **`.item()`**  
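   Extracts the scalar value from the tensor and converts it into a plain
   Python number (`int` or `float`).  
   Without it, `correct += tensor` would not raise an error; it would
   silently turn `correct` into a tensor (on the GPU when training there).
   For `loss.item()` it matters even more: accumulating the loss tensor
   itself would keep every batch's autograd graph alive in memory.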

4. **`correct += ...`**  
   Accumulates the number of correct predictions from each batch into a
   running total.

**Summary:**  
We use `.item()` to convert a scalar PyTorch tensor into a plain Python
number, so the running counters stay ordinary Python values.
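A small sketch of the difference, using made-up predictions and labels: `.sum()` returns a 0-dimensional tensor, `.item()` turns it into a Python `int`, and adding the tensor directly makes the counter itself a tensor.

```python
import torch

preds = torch.tensor([2, 0, 0, 1])  # made-up predictions
y = torch.tensor([2, 0, 1, 1])      # made-up labels

n_correct = (preds == y).sum()      # 0-dim tensor
print(n_correct.dim(), n_correct)   # 0 tensor(3)

correct = 0
correct += n_correct.item()         # stays a plain Python int
print(type(correct).__name__, correct)  # int 3

tensor_counter = 0
tensor_counter += n_correct         # without .item(): the counter becomes a tensor
print(type(tensor_counter).__name__)    # Tensor
```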


### Why multiply by `x.size(0)` instead of dividing by batches?

In the training loop we accumulate loss like this:

    running_loss += loss.item() * x.size(0)
    ...
    epoch_loss = running_loss / total

- **`loss.item()`** is already the **average loss per sample** in the
  current batch (that is how `CrossEntropyLoss` behaves by default).
- Multiplying by `x.size(0)` (the batch size) converts it into the
  **total loss for that batch**.
- Summing across all batches and then dividing by the **total number of
  samples** gives the true mean loss per sample across the entire epoch.

Why not just average the batch means?

- If we simply did `(sum(losses) / num_batches)`, every batch would be
  weighted equally.
- This becomes incorrect if the **last batch is smaller** than the rest
  (common when dataset size is not divisible by batch size).
- In that case, the smaller batch would count just as much as a full
  batch, skewing the average.

**Summary:**  
By multiplying with `x.size(0)` and dividing by the total number of
samples, we ensure that **every sample contributes equally** to the
epoch loss, regardless of batch size.
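With made-up numbers (not from the run in this notebook), two full batches of 62 and a smaller final batch of 18, the gap between the two averages is easy to see:

```python
batch_sizes = [62, 62, 18]         # two full batches + a smaller last batch
batch_losses = [0.20, 0.20, 0.80]  # mean loss per batch, as loss.item() would report

naive = sum(batch_losses) / len(batch_losses)  # every batch weighted equally
weighted = sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(f"naive={naive:.4f}, per-sample={weighted:.4f}")  # naive=0.4000, per-sample=0.2761
```

The naive average lets the 18-sample batch pull the result up as much as a full batch would.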


### Why do we use `bar.set_postfix(...)`?

We wrapped the training loop with `tqdm` to get a progress bar:

    bar = tqdm(train_dl, desc=f"Epoch {epoch}/{epochs}")

The method `bar.set_postfix(...)` lets us show **live metrics** next to
the progress bar during training. In this case:

- **`loss=loss.item()`**  
  Displays the most recent batch’s loss (a scalar).  
  This gives a quick view of how the current batch is doing.

- **`acc=(correct / total)`**  
  Shows the running accuracy across all samples seen so far in this
  epoch.  
  It is updated continuously as more batches are processed.

**Why it’s useful:**  
- You don’t have to wait until the end of an epoch to see if training is
  going in the right direction.  
- You get immediate feedback on both loss and accuracy as the loop
  progresses.  
- Especially important for spotting problems early (e.g., loss stuck at
  the same value, accuracy not improving).

**Summary:**  
`bar.set_postfix(...)` enriches the progress bar with real-time loss and
accuracy, making training easier to monitor and debug.


In [11]:
from tqdm import tqdm 

epochs = 3 

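for epoch in range(epochs+1):  # note: runs epochs 0..3, i.e. 4 passes (as in the log below); range(epochs) would run exactly `epochs`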
    model.train() 

    running_loss = 0.
    correct = 0
    total = 0

    bar = tqdm(
        train_dl, 
        desc=f"Epoch {epoch}/{epochs}"
    )

    for x, y in bar:
        # move to device and flatten 
        x = x.to(device).view(x.shape[0], 28*28)
        y = y.to(device)

        # forward + loss 
        logits = model(x)
        loss = loss_fn(logits, y)

        # backprop + step 
        optimizer.zero_grad() 
        loss.backward()
        optimizer.step()

        # metrics 
        running_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1) 
        correct += (preds==y).sum().item() 
        total += y.numel() 

        bar.set_postfix(
            loss=loss.item(), 
            acc=(correct / total)
        )


    epoch_loss = running_loss / total 
    epoch_acc = correct / total 
    print(f"Epoch {epoch}: loss={epoch_loss:.4f}, acc={epoch_acc:.4f}")

Epoch 0/3: 100%|██████████| 968/968 [00:11<00:00, 84.42it/s, acc=0.934, loss=0.183]


Epoch 0: loss=0.2136, acc=0.9344


Epoch 1/3: 100%|██████████| 968/968 [00:11<00:00, 84.61it/s, acc=0.973, loss=0.111]


Epoch 1: loss=0.0887, acc=0.9725


Epoch 2/3: 100%|██████████| 968/968 [00:11<00:00, 83.86it/s, acc=0.98, loss=0.0196]


Epoch 2: loss=0.0635, acc=0.9803


Epoch 3/3: 100%|██████████| 968/968 [00:11<00:00, 84.61it/s, acc=0.984, loss=0.0595]

Epoch 3: loss=0.0505, acc=0.9835



