In [1]:

# Cell 0: Imports & device
import math, random, copy, time
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
from PIL import Image, ImageDraw

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Prefer the GPU when CUDA is visible to this process; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Seed the three independent RNGs used in this notebook (torch, numpy, stdlib)
# so that dataset generation and weight initialisation are repeatable.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


Using device: cuda


In [2]:

# Cell 1: Shape drawing utilities
# Side length, in pixels, of every generated (square) image.
IMG_SZ = 32
# Shape names; a name's position in this list is its integer class id.
CLASSES = ["circle", "square", "triangle", "star"]
# Reverse lookup: shape name -> integer class id.
CLASS_TO_ID = dict(zip(CLASSES, range(len(CLASSES))))

def draw_shape(shape:str)->Image.Image:
    """Draw one white-outlined shape on a black IMG_SZ x IMG_SZ grayscale canvas.

    The shape is inset from the canvas border by a random padding of 4-8
    pixels. Circles and squares additionally get a random outline width of
    1-3 pixels; triangles and stars use the default polygon outline.

    Args:
        shape: One of "circle", "square", "triangle" or "star".

    Returns:
        A PIL image in mode "L" (8-bit grayscale).

    Raises:
        ValueError: If ``shape`` is not one of the supported names. Without
            this check an unknown name would silently yield a blank image.
    """
    img = Image.new("L", (IMG_SZ, IMG_SZ), color=0)  # grayscale
    d = ImageDraw.Draw(img)
    pad = random.randint(4, 8)
    # Bounding box of the shape: the canvas shrunk by `pad` on every side.
    x0, y0 = pad, pad
    x1, y1 = IMG_SZ-pad, IMG_SZ-pad
    if shape == "circle":
        d.ellipse([x0,y0,x1,y1], outline=255, width=random.randint(1,3))
    elif shape == "square":
        d.rectangle([x0,y0,x1,y1], outline=255, width=random.randint(1,3))
    elif shape == "triangle":
        # Apex at the top centre, base along the bottom edge of the box.
        p1 = (IMG_SZ//2, y0)
        p2 = (x0, y1)
        p3 = (x1, y1)
        d.polygon([p1,p2,p3], outline=255)
    elif shape == "star":
        # simple 5-point star: 10 vertices spaced pi/5 apart, alternating
        # between the outer radius r and the inner radius r//2
        cx, cy, r = IMG_SZ//2, IMG_SZ//2, (IMG_SZ//2 - pad)
        pts = []
        for k in range(10):
            ang = k * math.pi/5
            rk = r if k%2==0 else r//2
            pts.append((cx + int(rk*math.cos(ang)), cy + int(rk*math.sin(ang))))
        d.polygon(pts, outline=255)
    else:
        raise ValueError(
            f"unknown shape {shape!r}; expected one of "
            "'circle', 'square', 'triangle', 'star'"
        )
    return img

def to_tensor(img:Image.Image)->torch.Tensor:
    """Convert an image to a float32 tensor scaled by 1/255 with a leading axis.

    For a 2-D grayscale image this yields a tensor of shape [1, H, W] whose
    values lie in [0, 1] when the source pixels are 8-bit.
    """
    pixels = np.array(img, dtype=np.float32)
    scaled = pixels / 255.0
    # Insert the channel axis in front of the existing dimensions.
    return torch.from_numpy(scaled).unsqueeze(0)


In [3]:

# Cell 2: Torch Dataset classes
class ShapeDataset(Dataset):
    """In-memory dataset of randomly drawn shape images and their class ids.

    All ``n`` samples are generated eagerly in the constructor and stored in
    ``self.samples`` as ``(image_tensor, int_label)`` pairs.

    Args:
        classes: Shape names to sample from (uniformly, with replacement).
        n: Number of samples to generate.
        label_noise: Probability with which a sample's label is replaced by a
            class id drawn uniformly from all of ``CLASSES`` (the draw may
            coincide with the true label).
    """

    def __init__(self, classes:List[str], n:int, label_noise:float=0.0):
        self.samples = []
        highest_id = len(CLASSES) - 1
        for _ in range(n):
            shape_name = random.choice(classes)
            picture = draw_shape(shape_name)
            label = CLASS_TO_ID[shape_name]
            # The noise draw is made for every sample, even when
            # label_noise is 0, so the RNG stream does not depend on it.
            corrupt = random.random() < label_noise
            if corrupt:
                label = random.randint(0, highest_id)
            self.samples.append((to_tensor(picture), label))

    def __len__(self):
        """Number of stored samples."""
        return len(self.samples)

    def __getitem__(self, i):
        """Return sample ``i`` as ``(image_tensor, label)`` with an int64 scalar label."""
        image, label = self.samples[i]
        return image, torch.tensor(label, dtype=torch.long)

def make_orbit_private(classes:List[str], n_train:int, n_val:int)->Tuple[Dataset,Dataset]:
    """Build one orbit's private data: a (train, val) pair of ShapeDatasets.

    Both splits are sampled independently from the same ``classes``; the
    training split is generated first, then the validation split.
    """
    train_split = ShapeDataset(classes, n_train)
    val_split = ShapeDataset(classes, n_val)
    return train_split, val_split

def make_synthetic_mixed(n:int)->Dataset:
    """Return ``n`` samples drawn from every class in ``CLASSES``.

    Used as the server-side synthetic pool (a stand-in for the Phase-1
    output).
    """
    return ShapeDataset(classes=CLASSES, n=n)


In [4]:

# Cell 3: CNNs with BatchNorm (teachers) & smaller student
class TeacherCNN(nn.Module):
    """Three-block conv net (Conv -> BatchNorm -> ReLU) with a linear head.

    The first block keeps the spatial size; the second and third are each
    followed by 2x2 max pooling, so the head is sized for single-channel
    IMG_SZ x IMG_SZ inputs reduced to (IMG_SZ // 4) per side.

    Args:
        width: Channel count of the first two blocks; the third uses 2 * width.
        n_classes: Number of output logits.
    """

    def __init__(self, width: int = 32, n_classes: int = 4):
        super().__init__()
        wide = width * 2
        pooled_side = IMG_SZ // 4  # two 2x2 poolings halve each side twice
        self.conv1 = nn.Conv2d(1, width, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, wide, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(wide)
        self.head = nn.Linear(pooled_side * pooled_side * wide, n_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape [batch, n_classes]."""
        feats = F.relu(self.bn1(self.conv1(x)))
        feats = F.relu(self.bn2(self.conv2(feats)))
        feats = F.max_pool2d(feats, 2)
        feats = F.relu(self.bn3(self.conv3(feats)))
        feats = F.max_pool2d(feats, 2)
        flat = feats.view(feats.size(0), -1)
        return self.head(flat)

class StudentCNN(nn.Module):
    """Two-block conv net (Conv -> BatchNorm -> ReLU) with a linear head.

    After the second block the feature map is max-pooled 2x2 twice in a row,
    so the head is sized for single-channel IMG_SZ x IMG_SZ inputs reduced to
    (IMG_SZ // 4) per side.

    Args:
        width: Channel count of both conv blocks.
        n_classes: Number of output logits.
    """

    def __init__(self, width: int = 16, n_classes: int = 4):
        super().__init__()
        pooled_side = IMG_SZ // 4  # two 2x2 poolings halve each side twice
        self.conv1 = nn.Conv2d(1, width, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        self.head = nn.Linear(pooled_side * pooled_side * width, n_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape [batch, n_classes]."""
        feats = F.relu(self.bn1(self.conv1(x)))
        feats = F.relu(self.bn2(self.conv2(feats)))
        # Two consecutive poolings with no convolution in between.
        feats = F.max_pool2d(feats, 2)
        feats = F.max_pool2d(feats, 2)
        flat = feats.view(feats.size(0), -1)
        return self.head(flat)


In [5]:

# Cell 4: Make private datasets for 5 orbits (non-IID)
# Class subset visible to each orbit; the subsets overlap but none of them
# covers all four classes, which makes the private data non-IID.
orbits = [
    ["circle", "square"],            # Orbit 1: 2 classes
    ["triangle", "star"],            # Orbit 2: 2 classes
    ["circle", "triangle", "star"],  # Orbit 3: 3 classes
    ["square", "triangle"],          # Orbit 4: 2 classes
    ["circle", "star"],              # Orbit 5: 2 classes
]

# One (train, val) pair per orbit, generated in orbit order.
orbit_splits = [
    make_orbit_private(subset, n_train=600, n_val=150) for subset in orbits
]
private_train = [train_split for train_split, _ in orbit_splits]
private_val = [val_split for _, val_split in orbit_splits]

(len(private_train), len(private_val))


(5, 5)

In [6]:

# Cell 5: Utilities: train one teacher quickly
def train_one_teacher(model, train_ds, val_ds, epochs=5, lr=1e-3, bs=64):
    """Train ``model`` with Adam and cross-entropy, printing per-epoch metrics.

    Args:
        model: Network to train; it is moved to the global ``device``.
        train_ds: Dataset of (input, label) pairs, reshuffled every epoch.
        val_ds: Dataset of (input, label) pairs evaluated after every epoch.
        epochs: Number of passes over ``train_ds``.
        lr: Adam learning rate.
        bs: Batch size for both loaders.

    Returns:
        The trained model (left in eval mode by the last validation pass
        when ``epochs`` is at least 1).
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_batches = DataLoader(train_ds, batch_size=bs, shuffle=True)
    val_batches = DataLoader(val_ds, batch_size=bs)

    for epoch_no in range(1, epochs + 1):
        # ---- training pass ----
        model.train()
        seen = 0
        hits = 0
        running_loss = 0
        for inputs, targets in train_batches:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_n = inputs.size(0)
            # Weight the batch-mean loss by batch size to get a per-sample mean.
            running_loss += loss.item()*batch_n
            hits += (logits.argmax(1) == targets).sum().item()
            seen += batch_n

        # ---- validation pass ----
        model.eval()
        val_seen = 0
        val_hits = 0
        with torch.no_grad():
            for inputs, targets in val_batches:
                inputs = inputs.to(device)
                targets = targets.to(device)
                predictions = model(inputs).argmax(1)
                val_hits += (predictions == targets).sum().item()
                val_seen += inputs.size(0)

        print(f"Epoch {epoch_no}: train loss {running_loss/seen:.3f}, train acc {hits/seen:.2f}, val acc {val_hits/val_seen:.2f}")
    return model

# Train one teacher per orbit on that orbit's private data.
teachers = []
for i, (orbit_classes, tr, va) in enumerate(zip(orbits, private_train, private_val), start=1):
    print(f"\nTraining Teacher (Orbit {i}) on classes {orbit_classes}")
    t = TeacherCNN(width=32).to(device)
    t = train_one_teacher(t, tr, va, epochs=5, lr=1e-3, bs=64)
    # Freeze on server: eval() only switches BatchNorm to its running
    # statistics, so gradients are disabled explicitly as well.
    t.eval()
    t.requires_grad_(False)
    teachers.append(t)



Training Teacher (Orbit 1) on classes ['circle', 'square']
Epoch 1: train loss 0.295, train acc 0.88, val acc 0.46
Epoch 2: train loss 0.003, train acc 1.00, val acc 0.46
Epoch 3: train loss 0.001, train acc 1.00, val acc 0.94
Epoch 4: train loss 0.000, train acc 1.00, val acc 1.00
Epoch 5: train loss 0.000, train acc 1.00, val acc 1.00

Training Teacher (Orbit 2) on classes ['triangle', 'star']
Epoch 1: train loss 0.174, train acc 0.93, val acc 1.00
Epoch 2: train loss 0.000, train acc 1.00, val acc 1.00
Epoch 3: train loss 0.000, train acc 1.00, val acc 1.00
Epoch 4: train loss 0.000, train acc 1.00, val acc 1.00
Epoch 5: train loss 0.000, train acc 1.00, val acc 1.00

Training Teacher (Orbit 3) on classes ['circle', 'triangle', 'star']
Epoch 1: train loss 0.279, train acc 0.89, val acc 0.86
Epoch 2: train loss 0.001, train acc 1.00, val acc 1.00
Epoch 3: train loss 0.000, train acc 1.00, val acc 1.00
Epoch 4: train loss 0.000, train acc 1.00, val acc 1.00
Epoch 5: train loss 0.000,

In [7]:

# Cell 6: Server synthetic pool (mix of all classes)
# Server-side synthetic data covering all classes: a distillation pool and a
# smaller held-out set. The pool is generated first, then the held-out set.
synthetic_ds = make_synthetic_mixed(1200)
synthetic_val = make_synthetic_mixed(300)

# Only the distillation pool is reshuffled between epochs.
synthetic_loader = DataLoader(dataset=synthetic_ds, batch_size=64, shuffle=True)
synthetic_val_loader = DataLoader(dataset=synthetic_val, batch_size=64)


In [8]:

# Cell 7: KD utilities
@dataclass
class KDConfig:
    """Hyper-parameters for knowledge distillation.

    Note: ``kd_train`` in this notebook iterates over pre-built loaders and
    does not read ``bs``; the effective batch size is whatever the loaders
    were created with.
    """

    # Softmax temperature applied to both teacher and student logits.
    T: float = 4.0
    # Adam learning rate for the student.
    lr: float = 1e-3
    # Number of passes over the synthetic loader.
    epochs: int = 10
    # Batch size (informational; see the class docstring).
    bs: int = 64
    # Max gradient norm for clipping; a falsy value disables clipping.
    clip_grad: float = 1.0

def kd_train(student:nn.Module,
             teachers:List[nn.Module],
             synth_loader,
             val_loader,
             cfg:KDConfig):
    """Distil an ensemble of teachers into ``student`` on unlabeled batches.

    For every batch the teachers' logits are averaged, both the averaged
    teacher logits and the student logits are softened with temperature
    ``cfg.T``, and the student minimises KL(teacher || student) scaled by
    ``T**2``. Labels from ``synth_loader`` are ignored; labels from
    ``val_loader`` are used only to report top-1 accuracy after each epoch.

    Args:
        student: Model to train; it is moved to the global ``device``.
        teachers: Non-empty list of trained models, expected to already live
            on ``device``. They are switched to eval mode here so that
            BatchNorm layers use (and do not update) their running
            statistics while producing targets.
        synth_loader: Iterable of (input, label) batches used for distillation.
        val_loader: Iterable of (input, label) batches used for validation.
        cfg: Hyper-parameters; ``T``, ``lr``, ``epochs`` and ``clip_grad``
            are read. ``cfg.bs`` is not used because batching is determined
            by the loaders.

    Returns:
        The distilled student in eval mode.

    Raises:
        ValueError: If ``teachers`` is empty.
    """
    if len(teachers) == 0:
        raise ValueError("kd_train requires at least one teacher model")

    # Do not rely on the caller: a teacher left in train mode would normalise
    # with batch statistics and silently drift its BatchNorm running stats.
    for teacher in teachers:
        teacher.eval()

    student = student.to(device)
    opt = torch.optim.Adam(student.parameters(), lr=cfg.lr)
    # The T^2 factor is the common KD scaling for the softened KL term.
    temperature_sq = cfg.T*cfg.T

    for ep in range(cfg.epochs):
        student.train()
        loss_sum = 0
        total = 0

        for x, _ in synth_loader:
            x = x.to(device)

            with torch.no_grad():
                # Ensemble target: mean of the teachers' logits, then the
                # temperature-softened distribution. No gradient is needed.
                logits_teacher = torch.stack([m(x) for m in teachers]).mean(dim=0)  # [B,C]
                p = F.softmax(logits_teacher / cfg.T, dim=1)

            # Student log-probabilities at the same temperature.
            logits_student = student(x)
            q_log = F.log_softmax(logits_student / cfg.T, dim=1)

            # kl_div expects log-probs as input and probs as target.
            loss = F.kl_div(q_log, p, reduction="batchmean") * temperature_sq
            opt.zero_grad()
            loss.backward()

            if cfg.clip_grad:
                nn.utils.clip_grad_norm_(student.parameters(), cfg.clip_grad)

            opt.step()

            # Weight the batch-mean loss by batch size for a per-sample mean.
            loss_sum += loss.item()*x.size(0)
            total += x.size(0)

        # quick validation: top-1 accuracy on the labelled validation loader
        student.eval()
        vtotal = 0
        vcorrect = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out = student(x)
                vcorrect += (out.argmax(1)==y).sum().item()
                vtotal += x.size(0)
        print(f"KD Epoch {ep+1}: train KL {loss_sum/total:.4f}, val acc {vcorrect/vtotal:.3f}")
    return student.eval()

# Distil the teacher ensemble into a freshly initialised, smaller student.
student = StudentCNN(width=16).to(device)
kd_cfg = KDConfig(T=4.0, lr=1e-3, epochs=10, bs=64)
student = kd_train(
    student=student,
    teachers=teachers,
    synth_loader=synthetic_loader,
    val_loader=synthetic_val_loader,
    cfg=kd_cfg,
)



KD Epoch 1: train KL 0.7917, val acc 0.757
KD Epoch 2: train KL 0.0282, val acc 0.757
KD Epoch 3: train KL 0.0098, val acc 0.757
KD Epoch 4: train KL 0.0081, val acc 0.757
KD Epoch 5: train KL 0.0035, val acc 0.757
KD Epoch 6: train KL 0.0032, val acc 0.757
KD Epoch 7: train KL 0.0030, val acc 0.757
KD Epoch 8: train KL 0.0033, val acc 0.757
KD Epoch 9: train KL 0.0038, val acc 0.757
KD Epoch 10: train KL 0.0024, val acc 0.757


In [10]:

# Cell 8: Final evaluation on fresh mixed test
# Freshly generated mixed-class samples that were not used for distillation.
test_ds = make_synthetic_mixed(n=500)
test_loader = DataLoader(test_ds, batch_size=64)

student.eval()
total = 0
correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        predicted = student(x).argmax(1)
        correct += (predicted == y).sum().item()
        total += x.size(0)
print(f"Final Student Accuracy on fresh synthetic test: {correct/total:.3f}")



Final Student Accuracy on fresh synthetic test: 0.730
