# PyTorch + MNIST (Single Digit, Solid) — Robust Preproc + Augmentation + JupyterLab Canvas

**Goal:** Train a small but **reliable** digit recognizer (0–9) that behaves well on tricky **7/9** and hand-drawn inputs.

**What makes this version sturdier**
- **Data augmentation** (rotations, shifts, shear) during training → more invariant features.  
- **BatchNorm + Dropout** in the CNN → stable + less overfit.  
- **Preprocessing for drawings** that mimics MNIST's centering: resize longest side to **20**, then **pad to 28×28**, and **center by mass**.  
- Consistent **normalization** (MNIST mean/std) for both training and your canvas input.

In [None]:
# ▶️ Install (run once). ipycanvas = JupyterLab-friendly drawing widget.
!pip -q install torch torchvision matplotlib pillow ipywidgets ipycanvas
# In classic Notebook you may need: jupyter nbextension enable --py widgetsnbextension

## 0) Hyperparameters (safe knobs for the demo)

In [None]:
# Training knobs. SEED feeds the Python, NumPy and PyTorch RNGs as well as the
# generator that samples the train/validation subsets.
SEED       = 42
TRAIN_SIZE = 40000  # images sampled from the MNIST train split (capped at the split size)
VAL_SIZE   = 5000   # images sampled from the MNIST test split (capped at the split size)
BATCH_SIZE = 128    # mini-batch size used by both data loaders
EPOCHS     = 8      # number of passes over the training subset
LR         = 1e-3   # learning rate handed to Adam

# Canvas UI
CANVAS_INTERNAL = 224  # width and height passed to the Canvas widget
DISPLAY_PX      = 240  # on-screen size of the widget, used in its CSS layout (px)
BRUSH_INIT      = 14   # initial value of the brush slider (canvas line width)

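## 1) Imports & basic setup — what happens here

We import the libraries, seed the Python, NumPy and PyTorch random number generators with `SEED` so runs are repeatable, and select the GPU when one is available (CPU otherwise).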

In [None]:
import time, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from ipywidgets import VBox, HBox, Button, HTML, IntSlider, Layout
from ipycanvas import Canvas

# Seed every random number generator the notebook touches so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

## 2) Dataset download + **augmentation** + normalization

In [None]:
# Mean / std used to standardise MNIST pixels; the same pair is reused for canvas input.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

# Training pipeline: random rotation / shift / shear (empty area filled with black),
# then tensor conversion and standardisation.
train_tfms = transforms.Compose([
    transforms.RandomAffine(
        degrees=20, translate=(0.1, 0.1), shear=10, fill=0
    ),
    transforms.ToTensor(),
    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
])

# Evaluation pipeline: no augmentation, only tensor conversion and standardisation.
val_tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
])

train_full = datasets.MNIST(root='./data', train=True, download=True, transform=train_tfms)
test_full  = datasets.MNIST(root='./data', train=False, download=True, transform=val_tfms)

# Draw fixed-size random subsets without replacement. Note that the "validation"
# subset is sampled from the MNIST *test* split, not carved out of the train split.
rng = np.random.default_rng(SEED)
idx_train = rng.choice(len(train_full), size=min(TRAIN_SIZE, len(train_full)), replace=False)
idx_val   = rng.choice(len(test_full),  size=min(VAL_SIZE, len(test_full)), replace=False)

train_ds = Subset(train_full, idx_train)
val_ds   = Subset(test_full,  idx_val)

# Pinned host memory only helps host-to-GPU copies; on a CPU-only machine requesting
# it just triggers a warning, so enable it only when training on CUDA.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=(device.type == 'cuda'))
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=2, pin_memory=(device.type == 'cuda'))

# Last expression of the cell: shows the two subset sizes.
len(train_ds), len(val_ds)

In [None]:
# Show the first six images of the training subset, loaded without augmentation
# or normalisation so they appear as plain MNIST digits.
xb, yb = next(iter(DataLoader(
    Subset(
        datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor()),
        idx_train,
    ),
    batch_size=6,
)))
fig, axes = plt.subplots(1, 6, figsize=(9,2))
for i, ax in enumerate(axes):
    ax.imshow(xb[i,0].numpy(), cmap='gray')
    ax.set_title(int(yb[i]))  # label of the i-th sample as the panel title
    ax.axis('off')
plt.show()

## 3) Model architecture — small CNN, but sturdier

In [None]:
class SolidDigitCNN(nn.Module):
    """Two-stage convolutional classifier for 1x28x28 digit images.

    ``features`` holds two Conv-BatchNorm-ReLU-MaxPool stages (1->32->64
    channels, each pool halving the spatial size); ``classifier`` flattens the
    64x7x7 feature map and maps it through a 128-unit hidden layer to 10 logits,
    with dropout (p=0.2) before each linear layer.
    """

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """Return the layers of one 3x3 conv stage followed by 2x2 max pooling."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they run, stage by stage.
        self.features = nn.Sequential(
            *self._conv_stage(1, 32),
            *self._conv_stage(32, 64),
        )
        flat_dim = 64*7*7  # 64 channels on a 7x7 map after two 2x2 pools of 28x28
        head_layers = [
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(flat_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the 10 class logits for a batch of images ``x``."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps)

# Instantiate the network on the selected device; `model`, `opt` and `crit` are
# module-level objects that the training and prediction code below refer to by name.
model = SolidDigitCNN().to(device)
opt = optim.Adam(model.parameters(), lr=LR)
crit = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
# Total number of parameter elements in the model.
print('Parameters:', sum(p.numel() for p in model.parameters()))

## 4) Training loop — forward → loss → backward → step

In [None]:
def run_epoch(loader, train=True):
    """Run one full pass over ``loader`` with the global model.

    When ``train`` is true the model is put in training mode and the global
    optimizer takes a step after every batch; otherwise the model is put in
    eval mode and gradients are disabled.

    Returns:
        ``(mean_loss, accuracy)`` where both are averaged per sample over the
        whole pass.
    """
    if train:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    hits = 0
    seen = 0
    with torch.set_grad_enabled(train):
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            batch_loss = crit(outputs, targets)

            if train:
                opt.zero_grad()
                batch_loss.backward()
                opt.step()

            batch_n = inputs.size(0)
            # The criterion returns a batch mean, so re-weight by batch size
            # to get a correct per-sample average at the end.
            loss_sum += batch_loss.item()*batch_n
            hits += (outputs.argmax(1) == targets).sum().item()
            seen += batch_n
    return loss_sum/seen, hits/seen

# Per-epoch accuracy curves. Losses are returned by run_epoch but not stored here.
history = {'train_acc':[], 'val_acc':[]}
t0 = time.time()  # wall-clock start, used for the timing line printed after the loop
for epoch in range(1, EPOCHS+1):
    # One optimisation pass over the training loader, then a gradient-free
    # evaluation pass over the validation loader.
    tr_loss, tr_acc = run_epoch(train_loader, True)
    va_loss, va_acc = run_epoch(val_loader,   False)
    history['train_acc'].append(tr_acc); history['val_acc'].append(va_acc)
    print(f"Epoch {epoch:02d} | train acc={tr_acc*100:.2f}% | val acc={va_acc*100:.2f}%")
print(f"Training time: {time.time()-t0:.1f}s")

In [None]:
# Validation accuracy per epoch, on a fixed 0..1 axis so runs are comparable.
plt.figure(figsize=(5.2,3.2))
plt.plot(history['val_acc'])
plt.xlabel('epoch')
plt.ylabel('val acc')
plt.title('Validation accuracy')
plt.ylim(0,1)
plt.show()

## 6) Robust preprocessing for hand-drawn digits (MNIST-like)

In [None]:
from scipy import ndimage as ndi

def preprocess_to_mnist(x_gray01, ink_threshold=0.1):
    """Convert a grayscale drawing into an MNIST-style, normalised input tensor.

    Steps, following the MNIST recipe (size-normalise the digit into a 20 px
    box, then centre it by centre of mass in a 28x28 field):

    1. Invert so that ink is bright on a dark background.
    2. Crop to the bounding box of the ink, so the digit itself (not the whole
       canvas) is what gets scaled.
    3. Resize so the longest side of the crop is 20 px, keeping aspect ratio.
    4. Zero-pad and cut out the 28x28 window centred on the ink's centre of mass.
    5. Standardise with ``MNIST_MEAN`` / ``MNIST_STD``.

    Args:
        x_gray01: 2-D array of grayscale values in [0, 1] with dark ink on a
            light background.
        ink_threshold: Inverted-intensity level above which a pixel counts as
            ink when locating the bounding box.

    Returns:
        A float tensor of shape (1, 1, 28, 28). A drawing with no ink yields
        the normalised all-background image.
    """
    # Clip so tiny float overshoots cannot wrap around in the uint8 cast below.
    inv = np.clip(1.0 - np.asarray(x_gray01, dtype=np.float32), 0.0, 1.0)

    ink = inv > ink_threshold
    if ink.any():
        rows = np.flatnonzero(ink.any(axis=1))
        cols = np.flatnonzero(ink.any(axis=0))
        crop = inv[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
        h, w = crop.shape
        scale = 20.0 / max(h, w)
        # round() instead of int(): float error must not turn 20 into 19.
        new_w = max(1, int(round(w * scale)))
        new_h = max(1, int(round(h * scale)))
        pil = Image.fromarray((crop * 255).astype(np.uint8)).resize(
            (new_w, new_h), Image.BILINEAR
        )
        small = np.asarray(pil, dtype=np.float32) / 255.0
    else:
        # Nothing drawn: fall back to an empty 20x20 patch.
        small = np.zeros((20, 20), dtype=np.float32)

    # 14 px of padding on every side guarantees a 28x28 window fits around any
    # centre of mass that lies inside the resized digit.
    pad = np.pad(small, ((14, 14), (14, 14)), mode='constant')
    if pad.sum() > 0:
        cy, cx = ndi.center_of_mass(pad)
    else:
        cy, cx = pad.shape[0] / 2, pad.shape[1] / 2

    # A 28-pixel window starting at y0 is centred at y0 + 13.5 in pixel
    # coordinates, so that is the offset that puts the centre of mass in the middle.
    y0 = int(round(float(cy) - 13.5))
    x0 = int(round(float(cx) - 13.5))
    # Defensive clamp; keeps the slice inside the padded image in all cases.
    y0 = min(max(y0, 0), pad.shape[0] - 28)
    x0 = min(max(x0, 0), pad.shape[1] - 28)
    digit28 = np.ascontiguousarray(pad[y0:y0 + 28, x0:x0 + 28])

    tens = torch.from_numpy(digit28[None, None, :, :]).float()
    # Same standardisation as the training / validation transforms.
    tens = (tens - MNIST_MEAN) / MNIST_STD
    return tens

## 7) Draw **one digit** (JupyterLab-friendly) and predict

In [None]:
# Square drawing surface; sync_image_data=True lets Python read the pixels back later.
nx = ny = CANVAS_INTERNAL
canvas = Canvas(
    width=nx,
    height=ny,
    sync_image_data=True,
    layout=Layout(
        width=f'{DISPLAY_PX}px',
        height=f'{DISPLAY_PX}px',
        border='1px solid #ccc',
    ),
)

# Start from a white background and draw with black strokes.
canvas.fill_style = 'white'
canvas.fill_rect(0,0,nx,ny)
canvas.stroke_style = 'black'

# Brush-width slider; the canvas line width follows its value.
brush = IntSlider(value=BRUSH_INIT, min=6, max=24, step=1, description='Brush')
canvas.line_width = brush.value


def _set_brush(ch):
    """Apply the slider's new value as the canvas line width."""
    canvas.line_width = ch['new']


brush.observe(_set_brush, names='value')

# Stroke state shared by the mouse callbacks: whether the button is currently
# held down on the canvas, and the previous pointer position of the stroke.
is_drawing, last = False, None


def on_down(x, y):
    """Begin a stroke at the pressed position."""
    global is_drawing, last
    is_drawing, last = True, (x, y)


def on_move(x, y):
    """Extend the current stroke with a line segment to the new position."""
    global is_drawing, last
    if not is_drawing:
        return
    lx, ly = last if last else (x, y)
    canvas.begin_path()
    canvas.move_to(lx, ly)
    canvas.line_to(x, y)
    canvas.stroke()
    last = (x, y)


def on_up(x, y):
    """End the current stroke and forget the last position."""
    global is_drawing, last
    is_drawing, last = False, None


canvas.on_mouse_down(on_down)
canvas.on_mouse_move(on_move)
canvas.on_mouse_up(on_up)
# Also end the stroke when the pointer leaves the canvas. Otherwise a button
# release outside the widget is never seen, and drawing would resume (with a
# stray connecting line) as soon as the pointer comes back.
canvas.on_mouse_out(on_up)

# Status line plus the two action buttons shown under the canvas.
out_lbl   = HTML(value='Draw one digit, then click <b>Predict</b>.')
btn_clear = Button(description='Clear', button_style='warning')
btn_pred  = Button(description='Predict', button_style='success')


def clear_canvas(_):
    """Wipe the drawing by repainting the whole canvas white."""
    canvas.fill_style = 'white'
    canvas.fill_rect(0,0,nx,ny)


btn_clear.on_click(clear_canvas)

def predict_canvas(_):
    """Classify the current canvas drawing and show the result in ``out_lbl``.

    Reads the canvas pixels, converts them to grayscale in [0, 1], runs the
    MNIST-style preprocessing and reports the arg-max class together with its
    softmax probability. If the canvas pixels are not available yet, or nothing
    has been drawn, a hint is shown instead of a prediction.
    """
    try:
        data = np.asarray(canvas.get_image_data())
    except RuntimeError:
        # ipycanvas raises when no image data has been synced to the kernel yet;
        # surface that in the UI instead of losing it inside the widget callback.
        out_lbl.value = 'Canvas image is not available yet. Draw a digit, then click <b>Predict</b> again.'
        return

    # Drop any alpha channel and convert RGB to luma, scaled to [0, 1].
    rgb  = data[..., :3].astype(np.float32)
    gray = (0.299*rgb[...,0] + 0.587*rgb[...,1] + 0.114*rgb[...,2]) / 255.0

    # The background is white (1.0); without any dark pixel there is no digit,
    # and a "prediction" on an empty image would be meaningless.
    if gray.min() > 0.9:
        out_lbl.value = 'The canvas looks empty. Draw one digit, then click <b>Predict</b>.'
        return

    x = preprocess_to_mnist(gray)

    # Make sure dropout is off and BatchNorm uses its running statistics,
    # whatever mode the model was left in by earlier cells.
    model.eval()
    with torch.no_grad():
        logits = model(x.to(device))
        pred = logits.argmax(1).item()
        conf = torch.softmax(logits, dim=1).max(1).values.item()
    out_lbl.value = f"<b>Prediction:</b> {pred}  (confidence {conf:.2f})"


btn_pred.on_click(predict_canvas)

# Last expression of the cell: renders the canvas, the controls and the status line.
VBox([canvas, HBox([btn_clear, btn_pred, brush]), out_lbl])

## 8) Wrap-up (talk track)

- **Augmentations** toughen the model (rot/shift/shear).  
- **BatchNorm + Dropout** stabilize and regularize.  
- **MNIST-like centering** of your drawing (resize→pad→center of mass) reduces confusions like **7 vs 9**.  
- Same **normalization** for training and inference keeps behavior consistent.