In [1]:
!nvidia-smi

Thu Sep  4 18:32:11 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   35C    P8              9W /   70W |       1MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                      

# MNIST — MLP warm-up

We’ll load MNIST, normalize it, define a small MLP, and run a quick
sanity-check forward pass to verify shapes before training.


In [2]:
# OPTIONAL: only run this if your torch/torchvision install is broken.
# For GPU on Kaggle (CUDA 12.1 wheels):
# !pip install --upgrade --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CPU-only:
# !pip install --upgrade --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu


## Imports & device

We’ll autodetect CUDA and fall back to CPU. The code works either way.


In [3]:
import torch
import torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("device:", device)


torch: 2.6.0+cu124
torchvision: 0.21.0+cu124
device: cuda


## Dataset & transforms

- `ToTensor()` → converts each PIL image to a float tensor of shape [1, 28, 28] and scales pixels from 0..255 to [0, 1].
- `Normalize((0.1307,), (0.3081,))` → center/scale using MNIST stats.
  (Note the commas: single-element tuples.)
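For a single pixel, the two transforms compose to one affine map. Here is a minimal pure-Python sketch, assuming raw MNIST pixels are 8-bit integers in 0..255. The helper `normalize_pixel` is hypothetical, not a torchvision function.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(raw: int, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mimic ToTensor() followed by Normalize() for one 8-bit pixel (hypothetical helper)."""
    scaled = raw / 255.0          # ToTensor(): uint8 0..255 -> float in [0, 1]
    return (scaled - mean) / std  # Normalize(): subtract mean, divide by std

print(round(normalize_pixel(0), 4))    # black background pixel → -0.4242
print(round(normalize_pixel(255), 4))  # white stroke pixel → 2.8215
```

So normalized inputs span roughly [-0.42, 2.82], not [-1, 1]. What matters is that, over the training set, their mean is about 0 and their std about 1.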


### Why do we normalize MNIST with `(0.1307,), (0.3081,)`?

After `ToTensor()`, MNIST images are scaled to `[0,1]`, but their distribution isn’t centered and doesn’t have unit variance:
- Mean pixel value is about **0.1307** (most pixels are dark background).
- Standard deviation is about **0.3081**.

**Why normalize?**
- Centering (subtracting the mean) makes neuron inputs hover around zero, which helps gradients flow and speeds up learning.
- Scaling (dividing by the std) puts features on a comparable scale, making optimization more stable and less sensitive to learning rates.
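Centering and scaling amount to `(x - mean) / std` applied to every value. Here is a minimal pure-Python sketch on made-up pixel values. The `standardize` helper is hypothetical; also note that `Normalize` itself uses fixed, precomputed dataset statistics and does not recompute them per image.

```python
import statistics

def standardize(values):
    """Subtract the mean and divide by the (population) std, as Normalize does per channel."""
    mu = statistics.fmean(values)
    sigma = statistics.pstdev(values)
    return [(v - mu) / sigma for v in values]

# Made-up "image": mostly dark background plus a few bright stroke pixels, already in [0, 1]
pixels = [0.0] * 12 + [0.8, 1.0, 1.0, 0.6]
z = standardize(pixels)

print(f"before: mean={statistics.fmean(pixels):.4f} std={statistics.pstdev(pixels):.4f}")
print(f"after:  mean={statistics.fmean(z):.4f} std={statistics.pstdev(z):.4f}")
```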

**Why those exact numbers?**
- They are the empirical mean and std of the MNIST training set computed over all pixels.
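- Dataset-specific stats are generally preferable to generic choices like 0.5/0.5. With 0.5/0.5, MNIST inputs would end up with a mean near -0.74 and a std near 0.62. The measured stats give approximately zero mean and unit variance.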
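Such statistics come from pooling every pixel of every training image (after scaling to [0, 1]) and taking one mean and one std. The pure-Python helper below is hypothetical and works on nested lists of toy data. With torchvision, you could compute the same numbers from the training set's raw `uint8` tensor, e.g. `train_mnist.data.float().div(255)` followed by `.mean()` and `.std()`. Those values should come out near 0.1307 and 0.3081.

```python
import statistics

def pixel_stats(images):
    """Mean and std over *all* pixels of all single-channel images, after scaling 0..255 -> [0, 1]."""
    flat = [px / 255.0 for img in images for row in img for px in row]
    return statistics.fmean(flat), statistics.pstdev(flat)

# Two made-up 2x2 "images" standing in for the 60,000 MNIST training digits
toy_images = [
    [[0, 0], [255, 0]],
    [[0, 128], [0, 0]],
]
mean, std = pixel_stats(toy_images)
print(f"mean={mean:.4f} std={std:.4f}")
```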

**Why the tuples — and why the comma is so important?**
- `Normalize` expects a *sequence* (list or tuple) with one value per channel.
- MNIST has 1 channel → we need one mean and one std → a 1-element tuple.
- In Python:
  - `(0.3081,)` → a tuple containing one float ✅
  - `(0.3081)` → just a float ❌
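- If you forget the comma, you pass a float instead of a tuple. Recent torchvision releases happen to accept a bare float for single-channel images because it broadcasts like a scalar. However, the documented type is a sequence, and older releases or TorchScript-compiled transforms may reject it. The tuple form keeps the one-value-per-channel intent explicit.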
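In Python, the comma creates a tuple, not the parentheses. You can verify this in plain Python:

```python
one_tuple = (0.3081,)   # trailing comma -> tuple with one element
just_float = (0.3081)   # parentheses only group; this is a plain float

print(type(one_tuple).__name__, len(one_tuple))  # → tuple 1
print(type(just_float).__name__)                 # → float

# Without parentheses, the comma alone still makes a tuple
also_tuple = 0.3081,
print(also_tuple == one_tuple)                   # → True
```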

**What if we skip normalization?**
- The model may still learn (MNIST is simple), but typically:
  - Training is slower.
  - Optimization is less stable.
  - Accuracy may plateau lower.
- On harder datasets (like CIFAR or ImageNet) and with deeper networks, skipping normalization can make training much harder or cause it to stall.

**TL;DR**
Normalizing with `(0.1307,), (0.3081,)` standardizes MNIST inputs to approximately zero mean and unit variance.  
The trailing comma makes those values one-element tuples rather than plain floats, matching the one-value-per-channel sequence that `Normalize` documents.


In [4]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

train_mnist = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_mnist = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# quick peek
x0, y0 = train_mnist[0]
print("one sample:", x0.shape, y0)  # torch.Size([1, 28, 28]) label_int



one sample: torch.Size([1, 28, 28]) 5





## Model

A simple fully-connected classifier:
- Flatten 28×28 → 784
- Two hidden layers of 300 units each (784 → 300 → 300), each followed by LeakyReLU
- Output: 10 logits (no Softmax; CrossEntropyLoss expects logits)
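The reason for omitting Softmax is that `torch.nn.CrossEntropyLoss` applies log-softmax internally and then takes the negative log-likelihood of the target class. The sketch below is a pure-Python version of that computation for a single example. The function name and the toy logits are hypothetical, and the real loss also averages over the batch by default.

```python
import math

def cross_entropy_from_logits(logits, target):
    """-log(softmax(logits)[target]), computed stably via log-sum-exp."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

logits = [2.0, 0.5, -1.0] + [0.0] * 7  # 10 raw scores, like one row of the model output
print(round(cross_entropy_from_logits(logits, target=0), 4))
```

If the model already applied Softmax and the probabilities were then fed to `CrossEntropyLoss`, softmax would effectively be applied twice. That flattens the predicted distribution and weakens the training signal.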


In [5]:
model = torch.nn.Sequential(
    torch.nn.Linear(28*28, 300),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(300, 300),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(300, 10)  # logits
).to(device)

sum_params = sum(p.numel() for p in model.parameters())
print("model params:", sum_params)


model params: 328810
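You can check the printed count by hand. Each `Linear(in, out)` layer holds an `out × in` weight matrix plus `out` biases, and LeakyReLU has no parameters. Here is a small arithmetic sketch (the `linear_params` helper is hypothetical):

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights (out x in) plus one bias per output unit."""
    return in_features * out_features + out_features

layers = [(28 * 28, 300), (300, 300), (300, 10)]
per_layer = [linear_params(i, o) for i, o in layers]
print(per_layer, sum(per_layer))  # → [235500, 90300, 3010] 328810
```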


## Sanity check (single example)

Flatten to 784 features, forward once, confirm output shape [1, 10].
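On a contiguous tensor, `view(1, 28*28)` moves no data. It reinterprets the [1, 28, 28] pixels in row-major (C-style) order, so pixel (row r, column c) becomes feature index `r*28 + c`. The pure-Python sketch below shows that layout; the `flatten_row_major` helper is hypothetical.

```python
H, W = 28, 28

def flatten_row_major(image):
    """Flatten an H x W nested list row by row, matching view()'s layout for a contiguous tensor."""
    return [px for row in image for px in row]

# Fill each pixel with its own (row, col) so we can see where it lands after flattening
image = [[(r, c) for c in range(W)] for r in range(H)]
flat = flatten_row_major(image)

print(len(flat))        # → 784
print(flat[3 * W + 5])  # → (3, 5)
```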


In [6]:
digit, cls = train_mnist[0]
digit = digit.to(device).view(1, 28*28)  # add batch dim = 1
with torch.no_grad():
    out = model(digit)
print("single forward shape:", out.shape)  # torch.Size([1, 10])


single forward shape: torch.Size([1, 10])


## Sanity check with dataset loop (first item only)

Iterate over the dataset, move the first item to the device, flatten it, run the model, print the output shape, and break.
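Each dataset item is un-batched, with shape [1, 28, 28], and its leading 1 is the *channel* dimension. For training you would normally wrap the dataset in `torch.utils.data.DataLoader`, which stacks items into [B, 1, 28, 28]. Flattening must then keep dim 0 (the batch) and merge the rest. The following is a shape-only sketch in plain Python; the `flatten_shape` helper is hypothetical.

```python
import math

def flatten_shape(shape, batched):
    """Shape after flattening everything except the batch dimension.

    If the input is un-batched (like a single dataset item), a batch dim of 1 is added first.
    """
    if not batched:
        shape = (1, *shape)
    batch, *rest = shape
    return (batch, math.prod(rest))

print(flatten_shape((1, 28, 28), batched=False))     # single dataset item → (1, 784)
print(flatten_shape((64, 1, 28, 28), batched=True))  # a batch of 64 → (64, 784)
```

For a 3-channel [3, 32, 32] image, treating `shape[0]` as the batch size would give 3 rows of 1024 features instead of one row of 3072. With MNIST, the single channel hides that mistake.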


In [7]:
for digit, cls in train_mnist:
    digit = digit.to(device)
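    # Dataset items are un-batched [1, 28, 28]; dim 0 is the channel, not a batch,
    # so reshape to an explicit batch of 1 with 784 features.
    digit = digit.view(1, 28*28)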
    with torch.no_grad():
        print(model(digit).shape)  # expected: torch.Size([1, 10])
    break


torch.Size([1, 10])
