
# 10‑Minute PyTorch: MNIST with a Twist — Parity & Prime

**Idea:** train a tiny CNN on MNIST but **not** to predict the digit itself.  
Instead, predict two fun labels:
- **Parity**: even vs odd (0/1)
- **Prime**: prime vs non‑prime (2,3,5,7 are prime digits)

**Why this is fun**
- You can **draw your own digit** in a canvas cell and see the model decide parity/prime instantly.
- It shows multi‑task learning with a shared encoder and two small heads.


In [None]:

# Install (run once)
!pip -q install torch torchvision matplotlib scikit-learn pillow



## 1) Imports & configuration


In [None]:

import time, random, math, io
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms

from sklearn.metrics import classification_report, confusion_matrix

# Seed Python's, NumPy's and PyTorch's random number generators with one value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Subset sizes (number of samples) and optimisation hyperparameters.
TRAIN_SIZE = 15000
VAL_SIZE = 3000
BATCH_SIZE = 128
EPOCHS = 4
LR = 1e-3



## 2) Data: MNIST → parity & prime labels


In [None]:

# Every MNIST image is passed through ToTensor before it reaches the model.
transform = transforms.ToTensor()

# Download (if needed) the official train and test splits into ./data.
train_full = datasets.MNIST(root='./data', train=True,  download=True, transform=transform)
test_full  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

import numpy as np
# Sample fixed-size index sets without replacement. The generator is seeded
# from the notebook-wide SEED (instead of a second hard-coded 42) so that
# changing SEED in the configuration cell also changes the sampled subsets.
rng = np.random.default_rng(SEED)
idx_train = rng.choice(len(train_full), size=TRAIN_SIZE, replace=False)
idx_val   = rng.choice(len(test_full),  size=VAL_SIZE,   replace=False)

# The "validation" subset is drawn from the MNIST test split (train=False).
train_ds = Subset(train_full, idx_train)
val_ds   = Subset(test_full,  idx_val)

def make_labels(y):
    """Map a digit label to its ``(parity, prime)`` targets.

    ``y`` is first converted with ``int``. Parity is ``y % 2`` (0 = even,
    1 = odd); prime is 1 when the digit is one of 2, 3, 5, 7 and 0 otherwise.
    """
    digit = int(y)
    is_prime = digit in (2, 3, 5, 7)
    return digit % 2, int(is_prime)

class TwistWrapper(Dataset):
    """Dataset view that adds parity and prime targets to every sample.

    ``base`` is any indexable dataset whose items unpack as
    ``(image, digit)``.
    """

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        """Return ``(image, parity, prime, digit)`` for sample ``i``.

        ``parity`` and ``prime`` are scalar ``torch.long`` tensors; ``image``
        and ``digit`` are passed through exactly as ``base`` returns them.
        """
        image, digit = self.base[i]
        targets = make_labels(digit)
        parity_t, prime_t = (torch.tensor(v, dtype=torch.long) for v in targets)
        return image, parity_t, prime_t, digit

# Wrap both subsets so each sample also carries its parity and prime targets.
train_wrap = TwistWrapper(train_ds)
val_wrap   = TwistWrapper(val_ds)

# Pinned host memory is only requested when batches will be copied to a CUDA
# device; asking for it on a CPU-only machine has no benefit and makes the
# DataLoader emit a warning.
PIN_MEMORY = device.type == 'cuda'

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_wrap, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_wrap,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

# Cell result: the sizes of the two wrapped datasets.
len(train_wrap), len(val_wrap)



## 3) Model: shared CNN encoder + two heads


In [None]:

class TinyTwistNet(nn.Module):
    """Small CNN with one shared encoder and two 2-way classification heads.

    ``forward`` returns ``(parity_logits, prime_logits)``. Each head's first
    linear layer takes ``32*7*7`` flattened features, which is what the
    encoder (two size-preserving 3x3 convolutions, each followed by a 2x2 max
    pool, ending in 32 channels) yields for a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they appear in the network.
        stage_one = [nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        stage_two = [nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        self.encoder = nn.Sequential(*stage_one, *stage_two)
        self.head_parity = self._make_head()
        self.head_prime = self._make_head()

    @staticmethod
    def _make_head():
        """Build one head: flatten, hidden layer of width 32, then 2 logits."""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        features = self.encoder(x)
        parity_logits = self.head_parity(features)
        prime_logits = self.head_prime(features)
        return parity_logits, prime_logits

# Instantiate the network and move its parameters to the selected device.
model = TinyTwistNet().to(device)
# A single Adam optimizer updates every parameter of the model (the encoder and
# both heads) with learning rate LR.
opt = optim.Adam(model.parameters(), lr=LR)
# Cross-entropy criterion; run_epoch applies it to each head's logits and sums
# the two losses.
crit = nn.CrossEntropyLoss()

# Cell result: total number of scalar parameters in the model.
sum(p.numel() for p in model.parameters())



## 4) Train (multi‑task)


In [None]:

def step_metrics(logits, ytrue):
    """Return the batch accuracy as a Python float.

    The predicted class is the argmax of ``logits`` along dim 1; the accuracy
    is the fraction of predictions that equal ``ytrue``.
    """
    hits = logits.argmax(dim=1).eq(ytrue)
    return hits.float().mean().item()

def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return ``(loss, parity_acc, prime_acc)``.

    With ``train=True`` the model is put in training mode and the optimizer
    steps after every batch; otherwise the model is put in eval mode and
    gradient tracking is disabled. The loss is the sum of the parity and prime
    cross-entropy terms. All three returned values are averages over the
    samples seen, weighted by batch size.

    Relies on the module-level ``model``, ``opt``, ``crit`` and ``device``.
    """
    model.train(mode=bool(train))

    seen = 0
    loss_sum = 0.0
    parity_sum = 0.0
    prime_sum = 0.0

    for images, y_parity, y_prime, _digit in loader:
        images = images.to(device)
        y_parity = y_parity.to(device)
        y_prime = y_prime.to(device)

        with torch.set_grad_enabled(train):
            logits_parity, logits_prime = model(images)
            loss = crit(logits_parity, y_parity) + crit(logits_prime, y_prime)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()

        # Weight each batch by its size so a short final batch is not over-counted.
        batch = images.size(0)
        seen += batch
        loss_sum += loss.item() * batch
        parity_sum += step_metrics(logits_parity, y_parity) * batch
        prime_sum += step_metrics(logits_prime, y_prime) * batch

    return loss_sum/seen, parity_sum/seen, prime_sum/seen

# Per-epoch metrics for each split; every entry appended below is a
# (loss, parity_accuracy, prime_accuracy) tuple as returned by run_epoch.
history = {'train':[], 'val':[]}
# Wall-clock start, used for the total training time printed at the end.
t0 = time.time()
for epoch in range(1, EPOCHS+1):
    # One optimisation pass over the training loader, then one evaluation-only
    # pass over the validation loader.
    tr_loss, tr_acc_p, tr_acc_q = run_epoch(train_loader, train=True)
    va_loss, va_acc_p, va_acc_q = run_epoch(val_loader,   train=False)
    history['train'].append((tr_loss, tr_acc_p, tr_acc_q))
    history['val'].append((va_loss, va_acc_p, va_acc_q))
    # T/V = train / validation.
    print(f"Epoch {epoch:02d} | loss T/V: {tr_loss:.3f}/{va_loss:.3f} | parity acc T/V: {tr_acc_p:.3f}/{va_acc_p:.3f} | prime acc T/V: {tr_acc_q:.3f}/{va_acc_q:.3f}")
print(f"Training time: {time.time()-t0:.1f}s")



### Quick curves (accuracy)


In [None]:

# Each history entry is (loss, parity_acc, prime_acc); pull out the accuracy
# series per split and per task.
tr_p = [entry[1] for entry in history['train']]
va_p = [entry[1] for entry in history['val']]
tr_q = [entry[2] for entry in history['train']]
va_q = [entry[2] for entry in history['val']]

# One figure per task: training and validation accuracy against epoch index.
curve_specs = (
    (tr_p, 'train parity', va_p, 'val parity', 'Parity accuracy'),
    (tr_q, 'train prime', va_q, 'val prime', 'Prime accuracy'),
)
for train_curve, train_label, val_curve, val_label, title in curve_specs:
    plt.figure(figsize=(6,3.5))
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title(title)
    plt.legend()
    plt.show()



## 5) Evaluation report (on the validation subset)


In [None]:

# Run the model over the whole validation loader, collect logits and targets
# for both tasks, then print a scikit-learn classification report per task.
model.eval()
all_par_logits, all_prm_logits = [], []
all_par_true, all_prm_true = [], []
with torch.no_grad():
    for xb, ypar, yprm, _ydig in val_loader:
        lp, lq = model(xb.to(device))
        # Logits are moved to the CPU before being stored; targets are stored
        # exactly as the loader yields them.
        all_par_logits.append(lp.cpu()); all_prm_logits.append(lq.cpu())
        all_par_true.append(ypar);       all_prm_true.append(yprm)
# Concatenate the per-batch pieces and convert to NumPy for scikit-learn.
par_logits = torch.cat(all_par_logits); prm_logits = torch.cat(all_prm_logits)
par_true   = torch.cat(all_par_true).numpy(); prm_true = torch.cat(all_prm_true).numpy()
# Predicted class = index of the larger of the two logits.
par_pred   = par_logits.argmax(1).numpy();    prm_pred = prm_logits.argmax(1).numpy()

print("=== Parity (0=even, 1=odd) ===")
print(classification_report(par_true, par_pred, digits=4))
print("=== Prime (0=non-prime, 1=prime) ===")
print(classification_report(prm_true, prm_pred, digits=4))



## 6) Draw your own digit ✍️ and test


In [None]:

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch

# Side length, in pixels, of the square drawing canvas used by run_canvas.
CANVAS_SIZE = 280
# Brush radius in canvas pixels: each paint call fills a disc of this radius.
BRUSH = 14

def run_canvas():
    """Open a matplotlib drawing canvas and return the sketch as a model input.

    The canvas starts white (all ones). Dragging with the left mouse button
    paints black discs of radius ``BRUSH``; pressing ENTER closes the figure.
    After ``plt.show()`` returns, the canvas is downscaled to 28x28 with
    bilinear resampling, rescaled to [0, 1] and inverted so that strokes are
    bright on a dark background.

    Returns:
        A float tensor of shape (1, 1, 28, 28).

    NOTE(review): this presumably relies on ``plt.show()`` blocking until the
    window is closed; with a non-interactive backend the canvas may still be
    blank when it is converted. Confirm for the target environment.
    """
    # Canvas pixels: 1.0 = background, 0.0 = ink.
    img = np.ones((CANVAS_SIZE, CANVAS_SIZE), dtype=np.float32)
    # Stroke state shared by the event handlers below.
    drawing = False
    last = None

    fig, ax = plt.subplots(figsize=(4,4))
    ax.imshow(img, vmin=0, vmax=1, cmap='gray', interpolation='nearest')
    ax.set_title("Draw a digit. Left-drag; press ENTER to finish.")
    ax.axis('off')

    # Set every canvas pixel within BRUSH of (x, y) to ink.
    def paint(x, y):
        yy, xx = np.ogrid[:CANVAS_SIZE, :CANVAS_SIZE]
        mask = (xx - x)**2 + (yy - y)**2 <= BRUSH**2
        img[mask] = 0.0

    # Left button pressed inside the axes: start a stroke at the cursor.
    def on_press(event):
        nonlocal drawing, last
        if event.inaxes != ax: return
        if event.button == 1:
            drawing = True
            # Add 0.5 before truncating to round data coordinates to the nearest pixel.
            x, y = int(event.xdata+0.5), int(event.ydata+0.5)
            paint(x,y); last = (x,y)
            ax.images[0].set_data(img); fig.canvas.draw_idle()

    # Mouse moved while a stroke is active: paint from the previous point
    # towards the current one so that fast drags leave no gaps.
    def on_move(event):
        nonlocal last
        if not drawing or event.inaxes != ax: return
        x, y = int(event.xdata+0.5), int(event.ydata+0.5)
        if last is not None:
            lx, ly = last
            # Roughly one sample per pixel along the longer axis. t/steps stays
            # below 1, so the samples start at the previous point and stop just
            # short of the current one.
            steps = max(abs(x-lx), abs(y-ly)) + 1
            for t in range(steps):
                xi = int(lx + (x-lx)*t/steps)
                yi = int(ly + (y-ly)*t/steps)
                paint(xi, yi)
        else:
            paint(x,y)
        last = (x,y)
        ax.images[0].set_data(img); fig.canvas.draw_idle()

    # Left button released: end the current stroke.
    def on_release(event):
        nonlocal drawing, last
        if event.button == 1:
            drawing = False
            last = None

    # ENTER closes the figure.
    def on_key(event):
        if event.key == 'enter':
            plt.close(fig)

    fig.canvas.mpl_connect('button_press_event', on_press)
    fig.canvas.mpl_connect('motion_notify_event', on_move)
    fig.canvas.mpl_connect('button_release_event', on_release)
    fig.canvas.mpl_connect('key_press_event', on_key)

    plt.show()

    # Convert the canvas to an 8-bit image and shrink it to 28x28.
    small = Image.fromarray((img*255).astype(np.uint8)).resize((28,28), Image.BILINEAR)
    arr = np.asarray(small).astype(np.float32)/255.0
    # Invert: ink (0 on the canvas) becomes 1, background becomes 0.
    x28 = 1.0 - arr
    # Add leading batch and channel axes.
    x_tensor = torch.from_numpy(x28[None,None,:,:]).float()
    return x_tensor

# Let the user draw a digit, run it through the model and print both
# predictions. Any exception raised along the way is caught and reported so
# that the notebook can continue with the image-loading alternative.
try:
    x_draw = run_canvas().to(device)
    with torch.no_grad():
        lp, lq = model(x_draw)
        # Class index 1 means "odd" for parity and "prime" for the prime head.
        pred_par = lp.argmax(1).item()
        pred_prm = lq.argmax(1).item()
    print(f"Predictions → Parity: {'odd' if pred_par==1 else 'even'} | Prime: {'prime' if pred_prm==1 else 'non-prime'}")
except Exception as e:
    print("Canvas failed (likely no GUI). Use the next cell to load an image instead.\n", e)



### Alternative: load a 28×28 PNG and test


In [None]:

def predict_image(path):
    """Load an image file and print the model's parity and prime predictions.

    The image is converted to 8-bit grayscale, resized to 28x28, scaled to
    [0, 1] and inverted before being fed to the model, so it should show a
    dark digit on a light background.

    Args:
        path: Location of the image file, passed to ``PIL.Image.open``.

    Relies on the module-level ``model`` and ``device``. Returns ``None``; the
    result is printed.
    """
    # Open through a context manager so the underlying file is closed as soon
    # as the pixel data has been read, instead of depending on Pillow's
    # internal bookkeeping or on garbage collection. convert() and resize()
    # return new in-memory images, so `img` stays valid after the block.
    with Image.open(path) as source:
        img = source.convert('L').resize((28,28))
    arr = np.array(img).astype(np.float32)/255.0
    # Invert: dark pixels in the file become high values.
    x28 = 1.0 - arr
    # Add leading batch and channel axes and move to the model's device.
    x_tensor = torch.from_numpy(x28[None,None,:,:]).float().to(device)
    with torch.no_grad():
        lp, lq = model(x_tensor)
        # Class index 1 means "odd" for parity and "prime" for the prime head.
        pred_par = lp.argmax(1).item()
        pred_prm = lq.argmax(1).item()
    print(f"Predictions → Parity: {'odd' if pred_par==1 else 'even'} | Prime: {'prime' if pred_prm==1 else 'non-prime'}")

# predict_image('my_digit_28x28.png')



## 7) Wrap‑up
Multi‑task learning with a shared CNN is compact and effective. The **interactive canvas** makes it memorable:
you draw a digit, the model classifies **even/odd** and **prime/non‑prime**.
