
# PyTorch in ~10 Minutes: **MNIST with a Twist (Parity & Prime)** — JupyterLab Friendly

**What this tutorial does:**
- Train a tiny **multi‑task neural network** (a **CNN**, i.e. a convolutional network, well suited to 2D data such as images or digital elevation models) on an MNIST **subset** to predict:
  - **Parity** — even vs odd
  - **Prime** — prime vs non‑prime (2,3,5,7 are prime)
- Then **draw your own digit** in a JupyterLab‑friendly canvas and get instant predictions.


In [None]:

# ▶️ Install (run once). ipycanvas gives a JupyterLab-friendly drawing widget.
!pip -q install torch torchvision matplotlib scikit-learn pillow ipywidgets ipycanvas
# In classic Notebook you may need: jupyter nbextension enable --py widgetsnbextension



## 0) **Hyperparameters to tweak** (only ones you should touch during the demo)

Keep everything else stable. These are safe to play with live:
- `TRAIN_SIZE`, `VAL_SIZE` — subset sizes (speed vs accuracy). The training set is fed to the network to fit its weights; the validation set is an independent dataset used only to measure performance.
- `BATCH_SIZE` — memory vs gradient noise
- `EPOCHS` — number of full passes over the training subset; more epochs → better accuracy (to a point)
- `LR` — learning rate; too high explodes, too low crawls
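
To make the speed trade-off concrete, here is a small sketch of how these values translate into optimizer steps. It assumes the default values used in this notebook and a loader that keeps the final, smaller batch.

```python
import math

# Values mirror the defaults used in this notebook.
TRAIN_SIZE = 15000
BATCH_SIZE = 128
EPOCHS = 20

# The last batch may be smaller than BATCH_SIZE, hence the ceiling.
steps_per_epoch = math.ceil(TRAIN_SIZE / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
last_batch = TRAIN_SIZE - (steps_per_epoch - 1) * BATCH_SIZE

print(steps_per_epoch, total_steps, last_batch)
```

Doubling `BATCH_SIZE` roughly halves the number of weight updates per epoch, which is one reason the learning rate and batch size are usually tuned together.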


In [None]:

# 🔧 Hyperparameters
SEED       = 42      # random seed, so runs are reproducible
# How many images are used to TRAIN the model
TRAIN_SIZE = 15000   # out of 60k in the original training set; images are tiny, so the whole dataset is only tens of MB
# How many images are used to VALIDATE the model
VAL_SIZE   = 3000    # out of 10k in the original test set
BATCH_SIZE = 128
EPOCHS     = 20       # still quick
LR         = 1e-3



## 1) Imports & basic setup

This is boilerplate: libraries, device (CPU/GPU), and reproducibility.
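
As a quick illustration of what seeding buys you, the sketch below uses only the standard library. The helper `draw` is hypothetical and uses a private `random.Random` generator rather than the global seed calls made in this notebook.

```python
import random

SEED = 42

def draw(seed, k=5):
    """Return k pseudo-random digits from a generator seeded with `seed`."""
    rng = random.Random(seed)  # independent generator, does not touch global state
    return [rng.randint(0, 9) for _ in range(k)]

first = draw(SEED)
second = draw(SEED)      # same seed: same sequence
other = draw(SEED + 1)   # different seed: almost certainly a different sequence

print(first == second, first, other)
```

The same idea applies to NumPy and PyTorch, although seeding alone does not guarantee bit-identical results on every GPU.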


In [None]:

import time, random, math, io
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms

from sklearn.metrics import classification_report

# Widgets for the JupyterLab drawing pad
from ipywidgets import VBox, HBox, Button, Label, HTML
from ipycanvas import Canvas

# Reproducibility
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)



## 2) Download & prepare dataset **(what's happening here)**

- We download **MNIST** (28×28 grayscale). Each sample is a **digit** (0–9).
- We **relabel** each digit into **two targets**:
  - `parity` = 0 if even, 1 if odd
  - `prime`  = 1 if digit in {2,3,5,7} else 0  
- We make **train/validation** splits:
  - **Train**: used to update weights.
  - **Validation**: never used to update, only to **judge generalization**.
- For speed, we use **small subsets**.
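
The relabeling rule can be checked on its own before touching any images. This sketch rebuilds the labels for all ten digits using only the standard library; `PRIMES` and `table` are illustrative names.

```python
PRIMES = {2, 3, 5, 7}

def make_labels(y):
    """Map a digit 0-9 to (parity, prime) labels."""
    y = int(y)
    return y % 2, int(y in PRIMES)  # parity: 0 even / 1 odd; prime: 1 if prime

table = {d: make_labels(d) for d in range(10)}
for d, (parity, prime) in table.items():
    print(f"digit={d} parity={parity} prime={prime}")

n_odd = sum(p for p, _ in table.values())
n_prime = sum(q for _, q in table.values())
print("odd digits:", n_odd, "| prime digits:", n_prime)
```

Since only 4 of the 10 digits are prime, and assuming the digits are roughly balanced, a model that always answers "non-prime" would already score around 60% on that task. Keep that baseline in mind when reading accuracies.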


In [None]:

# Minimal transform: tensors in [0,1]
transform = transforms.ToTensor()

train_full = datasets.MNIST(root='./data', train=True,  download=True, transform=transform)
test_full  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Subsample for speed
rng = np.random.default_rng(SEED)
idx_train = rng.choice(len(train_full), size=TRAIN_SIZE, replace=False)
idx_val   = rng.choice(len(test_full),  size=VAL_SIZE,   replace=False)
train_ds = Subset(train_full, idx_train)
val_ds   = Subset(test_full,  idx_val)

def make_labels(y):
    y = int(y)
    parity = y % 2              # 0 even, 1 odd
    prime = 1 if y in {2,3,5,7} else 0
    return parity, prime

class TwistWrapper(Dataset):
    """Wrap MNIST to emit (image, parity, prime, digit)"""
    def __init__(self, base): self.base = base
    def __len__(self): return len(self.base)
    def __getitem__(self, i):
        x, y_digit = self.base[i]
        parity, prime = make_labels(y_digit)
        return x, torch.tensor(parity), torch.tensor(prime), y_digit

train_wrap = TwistWrapper(train_ds)
val_wrap   = TwistWrapper(val_ds)

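# num_workers=2 loads batches in background processes; set it to 0 if worker startup fails in your notebook (e.g. on Windows/macOS).
# pin_memory only helps when batches are copied to a GPU, so enable it only on CUDA.
use_pin = (device.type == 'cuda')
train_loader = DataLoader(train_wrap, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=use_pin)
val_loader   = DataLoader(val_wrap,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_pin)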

len(train_wrap), len(val_wrap)



### Quick peek at a few samples

Always sanity‑check: are shapes and labels what we expect?


In [None]:

xb, ypar, yprm, ydig = next(iter(train_loader))
fig, axes = plt.subplots(1, 6, figsize=(9,2))
for i, ax in enumerate(axes):
    ax.imshow(xb[i,0].numpy(), cmap='gray')
    ax.set_title(f"d={int(ydig[i])}\npar={int(ypar[i])},prm={int(yprm[i])}", fontsize=9)
    ax.axis('off')
plt.show()



## 3) Model architecture **(what & why)**

- A tiny **CNN encoder** extracts features from 28×28 images.
- Two small **heads** (linear layers) make task‑specific predictions:
  - **Parity head** → 2 classes (even/odd)
  - **Prime head**  → 2 classes (non‑prime/prime)
- Loss = **CE(parity)** + **CE(prime)** so both tasks learn together.
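
To see where the model's size comes from, the following sketch counts parameters by hand. It assumes the layer sizes used in the model cell of this section (3×3 convolutions with 16 and 32 channels, two 2×2 max-pools, and heads with a 32-unit hidden layer), with biases enabled; the helper names are illustrative.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a k x k convolution."""
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

side = 28
side //= 2            # first 2x2 max-pool: 28 -> 14
side //= 2            # second 2x2 max-pool: 14 -> 7
flat = 32 * side * side   # features entering each head

encoder = conv_params(1, 16, 3) + conv_params(16, 32, 3)
head = linear_params(flat, 32) + linear_params(32, 2)
total = encoder + 2 * head    # one shared encoder, two heads

print("flattened features:", flat)
print("encoder:", encoder, "| per head:", head, "| total:", total)
```

Almost all parameters sit in the first linear layer of each head, not in the shared convolutional encoder. The total should agree with the count printed by the model cell.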


In [None]:

class TinyTwistNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),               # 28->14
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),               # 14->7
        )
        self.head_parity = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32*7*7, 32), nn.ReLU(),
            nn.Linear(32, 2)
        )
        self.head_prime = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32*7*7, 32), nn.ReLU(),
            nn.Linear(32, 2)
        )
    def forward(self, x):
        z = self.encoder(x)
        return self.head_parity(z), self.head_prime(z)

model = TinyTwistNet().to(device)
opt = optim.Adam(model.parameters(), lr=LR)
crit = nn.CrossEntropyLoss()

print('Total parameters:', sum(p.numel() for p in model.parameters()))



## 4) Training loop **(how it learns)**

For each mini‑batch:
1. Forward pass → **two sets of logits** (parity, prime)  
2. Compute two **cross‑entropy** losses → **sum**  
3. Backprop & optimizer step

We track **accuracy** for both tasks on **train** and **validation**.
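
Step 2 can be worked through by hand for a single sample. The sketch below is a plain-Python version of cross-entropy on raw logits, with made-up logit values; it is meant to show the arithmetic, not to replace `nn.CrossEntropyLoss`.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one sample: log-sum-exp of the logits minus the target logit."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

# Hypothetical outputs of the two heads for one image of the digit 2.
logits_parity, y_parity = [2.0, 0.0], 0   # fairly confident "even", which is correct
logits_prime,  y_prime  = [0.0, 0.0], 1   # undecided, true label is "prime"

loss_parity = cross_entropy(logits_parity, y_parity)
loss_prime = cross_entropy(logits_prime, y_prime)
loss = loss_parity + loss_prime           # both tasks contribute equally

print(f"parity={loss_parity:.4f} prime={loss_prime:.4f} total={loss:.4f}")
```

An undecided head costs `log(2)` on a two-class task, so a summed loss near `2 * log(2)` means neither head has learned anything yet.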


In [None]:

def step_acc(logits, y):
    return (logits.argmax(1) == y).float().mean().item()

def run_epoch(loader, train=True):
    model.train(train)  # train=True enables training mode, train=False is equivalent to model.eval()
    total, n = 0.0, 0
    acc_p, acc_q = 0.0, 0.0
    for xb, ypar, yprm, _ in loader:
        xb, ypar, yprm = xb.to(device), ypar.to(device), yprm.to(device)
        with torch.set_grad_enabled(train):
            lp, lq = model(xb)
            loss = crit(lp, ypar) + crit(lq, yprm)
            if train:
                opt.zero_grad(); loss.backward(); opt.step()
        bsz = xb.size(0); n += bsz; total += loss.item()*bsz
        acc_p += step_acc(lp, ypar)*bsz
        acc_q += step_acc(lq, yprm)*bsz
    return total/n, acc_p/n, acc_q/n

history = {'train':[], 'val':[]}
t0 = time.time()
for epoch in range(1, EPOCHS+1):
    tr_loss, tr_p, tr_q = run_epoch(train_loader, True)
    va_loss, va_p, va_q = run_epoch(val_loader, False)
    history['train'].append((tr_loss, tr_p, tr_q))
    history['val'].append((va_loss, va_p, va_q))
    print(f"Epoch {epoch:02d} | loss T/V: {tr_loss:.3f}/{va_loss:.3f} | acc parity T/V: {tr_p:.3f}/{va_p:.3f} | acc prime T/V: {tr_q:.3f}/{va_q:.3f}")
print(f"Training time: {time.time()-t0:.1f}s")



## 5) Quick curves (did it converge?)


In [None]:

tr_p = [x[1] for x in history['train']]; va_p = [x[1] for x in history['val']]
tr_q = [x[2] for x in history['train']]; va_q = [x[2] for x in history['val']]

plt.figure(figsize=(6,3.5))
plt.plot(tr_p, label='train parity'); plt.plot(va_p, label='val parity')
plt.xlabel('epoch'); plt.ylabel('accuracy'); plt.title('Parity accuracy'); plt.legend(); plt.show()

plt.figure(figsize=(6,3.5))
plt.plot(tr_q, label='train prime'); plt.plot(va_q, label='val prime')
plt.xlabel('epoch'); plt.ylabel('accuracy'); plt.title('Prime accuracy'); plt.legend(); plt.show()



## 6) Validation report (numbers to show)

We compute a compact **classification report** for each task on the validation subset.
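
For readers unfamiliar with the columns of that report, here is a minimal sketch of how precision, recall and F1 are derived for one class. The labels are toy values and `binary_report` is a hypothetical helper, not part of scikit-learn.

```python
def binary_report(y_true, y_pred, positive=1):
    """Precision, recall and F1 for the `positive` class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Toy example: 1 = prime, 0 = non-prime.
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 0, 0, 1, 0, 1]

precision, recall, f1 = binary_report(y_true, y_pred)
print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
```

Precision asks "of the digits called prime, how many really were?", while recall asks "of the prime digits, how many were found?".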


In [None]:

model.eval()
par_true, par_pred, prm_true, prm_pred = [], [], [], []
with torch.no_grad():
    for xb, ypar, yprm, _ in val_loader:
        lp, lq = model(xb.to(device))
        par_true.append(ypar); prm_true.append(yprm)
        par_pred.append(lp.cpu().argmax(1)); prm_pred.append(lq.cpu().argmax(1))
par_true = torch.cat(par_true).numpy(); par_pred = torch.cat(par_pred).numpy()
prm_true = torch.cat(prm_true).numpy(); prm_pred = torch.cat(prm_pred).numpy()

print("=== Parity (0=even, 1=odd) ===\n", classification_report(par_true, par_pred, digits=4))
print("=== Prime (0=non‑prime, 1=prime) ===\n", classification_report(prm_true, prm_pred, digits=4))



## 7) **Draw your own digit** (JupyterLab‑friendly)

Use the canvas below. Click‑drag to draw, **Clear** to reset, then **Predict**.  
We downsample to 28×28 like MNIST and run the model.
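
The canvas has black ink on a white background, whereas MNIST has bright digits on a dark background, so the pixels need to be converted to grayscale and inverted. A per-pixel sketch of that mapping, using the same luma weights as the widget cell (the function name is illustrative):

```python
def to_mnist_intensity(r, g, b):
    """Map one RGB canvas pixel (0-255 per channel) to an MNIST-style intensity in [0, 1]."""
    gray = (0.299 * r + 0.587 * g + 0.114 * b) / 255.0  # weighted grayscale conversion
    return 1.0 - gray  # invert: black ink on white paper becomes a bright digit on dark

white_paper = to_mnist_intensity(255, 255, 255)  # background, close to 0
black_ink = to_mnist_intensity(0, 0, 0)          # stroke, exactly 1
grey_edge = to_mnist_intensity(128, 128, 128)    # anti-aliased edge, about 0.5

print(round(white_paper, 3), round(black_ink, 3), round(grey_edge, 3))
```

MNIST digits are centered in the frame, so drawing a large, centered digit tends to give the model an input closer to what it saw during training.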


In [None]:
from ipywidgets import VBox, HBox, Button, HTML, IntSlider, Layout
from ipycanvas import Canvas
import numpy as np, torch
from PIL import Image

# --- Canvas config (internal vs displayed size) ---
nx, ny = 128, 128            # internal resolution used for downsampling to 28×28
display_px = 220             # visual size on screen (shrink/grow here)

canvas = Canvas(
    width=nx, height=ny, sync_image_data=True,
    layout=Layout(width=f'{display_px}px', height=f'{display_px}px', border='1px solid #ccc')
)
canvas.fill_style = 'white'; canvas.fill_rect(0,0,nx,ny)
canvas.stroke_style = 'black'

brush = IntSlider(value=10, min=3, max=28, step=1, description='Brush', continuous_update=True)
canvas.line_width = brush.value

def _set_brush(change):
    canvas.line_width = change['new']
brush.observe(_set_brush, names='value')

# --- Drawing handlers ---
is_drawing, last = False, None
def handle_mousedown(x, y):
    global is_drawing, last
    is_drawing, last = True, (x, y)

def handle_mousemove(x, y):
    global is_drawing, last
    if not is_drawing: return
    lx, ly = last if last else (x, y)
    canvas.begin_path(); canvas.move_to(lx, ly); canvas.line_to(x, y); canvas.stroke()
    last = (x, y)

def handle_mouseup(x, y):
    global is_drawing, last
    is_drawing, last = False, None

canvas.on_mouse_down(handle_mousedown)
canvas.on_mouse_move(handle_mousemove)
canvas.on_mouse_up(handle_mouseup)

# --- Buttons & helpers ---
btn_clear = Button(description='Clear', button_style='warning')
btn_pred  = Button(description='Predict', button_style='success')
out_lbl   = HTML(value='Draw a digit, then click <b>Predict</b>.')

def clear_canvas(_):
    canvas.fill_style = 'white'; canvas.fill_rect(0,0,nx,ny)
btn_clear.on_click(clear_canvas)

def preprocess_canvas_to_tensor():
    data = np.asarray(canvas.get_image_data())     # (H,W,4) uint8
    rgb  = data[..., :3].astype(np.float32)
    gray = (0.299*rgb[...,0] + 0.587*rgb[...,1] + 0.114*rgb[...,2]) / 255.0
    img  = Image.fromarray((gray*255).astype(np.uint8)).resize((28,28), Image.BILINEAR)
    arr  = np.asarray(img).astype(np.float32)/255.0
    x28  = 1.0 - arr                               # white digit on black like MNIST
    return torch.from_numpy(x28[None, None, :, :]).float().to(device)

def predict_canvas(_):
    with torch.no_grad():
        x = preprocess_canvas_to_tensor()
        lp, lq = model(x)
        par = lp.argmax(1).item(); prm = lq.argmax(1).item()
    out_lbl.value = f"<b>Prediction:</b> Parity = {'odd' if par==1 else 'even'} | Prime = {'prime' if prm==1 else 'non-prime'}"

btn_pred.on_click(predict_canvas)

VBox([canvas, HBox([btn_clear, btn_pred, brush]), out_lbl])



### (Optional) Load a 28×28 PNG

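If you prefer an image file: any image works, since it is converted to grayscale and resized to **28×28**. The model expects a **white digit on black** like MNIST, so a dark digit on a light background is inverted automatically.


In [None]:

def predict_image(path):
    img = Image.open(path).convert('L').resize((28,28))
    arr = np.array(img).astype(np.float32)/255.0
    # Heuristic: a mostly bright image means a light background, so invert it to match MNIST
    x28 = 1.0 - arr if arr.mean() > 0.5 else arr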
    x_tensor = torch.from_numpy(x28[None,None,:,:]).float().to(device)
    with torch.no_grad():
        lp, lq = model(x_tensor)
        pred_par = lp.argmax(1).item()
        pred_prm = lq.argmax(1).item()
    print(f"Parity: {'odd' if pred_par==1 else 'even'} | Prime: {'prime' if pred_prm==1 else 'non‑prime'}")

# predict_image('my_digit_28x28.png')



## 8) Wrap‑up (talk track)

- **Model cell** (section 3): shared encoder + two heads.  
- **Dataset cell** (section 2): dataset download + relabel to parity/prime.  
- **Hyperparameters cell** (section 0): hyperparameters are isolated so you can **safely tweak** them live.  
- **Training cell**: shows the **whole learning loop** is a few lines.  
- **Widgets cell**: audience **draws new data** → instant parity/prime.  

**Extensions (if time)**: add a third head (multiple of 3), log confusion matrices, or export to TorchScript.
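
For the third-head extension, the labeling side might look like the sketch below. The name `make_labels3` is hypothetical, and the model would also need a third head and a third loss term, which are not shown here.

```python
def make_labels3(y):
    """Map a digit 0-9 to (parity, prime, multiple-of-3) labels."""
    y = int(y)
    parity = y % 2                   # 0 even, 1 odd
    prime = int(y in {2, 3, 5, 7})   # 1 if prime
    mult3 = int(y % 3 == 0)          # 1 for 0, 3, 6, 9 (0 counts as a multiple of 3)
    return parity, prime, mult3

for d in range(10):
    print(d, make_labels3(d))
```

Whether 0 should count as a multiple of 3 is a labeling decision; the sketch follows the mathematical convention that it does.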
