In [1]:
!pip install torch torchvision



In [2]:

# Cell 0: Imports, device, seeds
import math, random, copy, time
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.manual_seed(42); np.random.seed(42); random.seed(42)


Using device: cuda


In [4]:

# Cell 1: MNIST loading (train as private; test as public proxy for KD)
#   In LEOShot the server generates synthetic data (no public set).
#   For this runnable demo, we use MNIST test images as the "server distillation" proxy.
#   (A very common practice in FedMD-style KD demos.)

transform = transforms.Compose([
    transforms.ToTensor(),                      # [0,1]
    transforms.Normalize((0.1307,), (0.3081,))  # standard MNIST normalization
])

mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_test  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

print("Train size:", len(mnist_train), "Test size:", len(mnist_test))


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.41MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 160kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.52MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 27.9MB/s]

Train size: 60000 Test size: 10000





In [5]:

# Cell 2: Build non-IID orbit splits for teachers
# We'll create 5 orbits with different label sets.
orbits_labels = [
    [0,1,2],         # orbit 1
    [3,4],           # orbit 2
    [5,6],           # orbit 3
    [7,8],           # orbit 4
    [9,0,1]          # orbit 5 (overlaps to reflect heterogeneity)
]

def indices_for_labels(dataset, labels):
    idxs = [i for i,(x,y) in enumerate(dataset) if y in labels]
    return idxs

orbit_train_subsets = []
orbit_val_subsets   = []

# Split per-orbit train into train/val (80/20)
for lbls in orbits_labels:
    idxs = indices_for_labels(mnist_train, lbls)
    random.shuffle(idxs)
    split = int(0.8 * len(idxs))
    train_idxs, val_idxs = idxs[:split], idxs[split:]
    orbit_train_subsets.append(Subset(mnist_train, train_idxs))
    orbit_val_subsets.append(Subset(mnist_train, val_idxs))

for i,lbls in enumerate(orbits_labels, start=1):
    print(f"Orbit {i} labels {lbls}: train {len(orbit_train_subsets[i-1])}, val {len(orbit_val_subsets[i-1])}")


Orbit 1 labels [0, 1, 2]: train 14898, val 3725
Orbit 2 labels [3, 4]: train 9578, val 2395
Orbit 3 labels [5, 6]: train 9071, val 2268
Orbit 4 labels [7, 8]: train 9692, val 2424
Orbit 5 labels [9, 0, 1]: train 14891, val 3723


In [6]:

# Cell 3: Define heterogeneous teacher CNNs and a smaller student
# Teachers vary widths/depths to simulate heterogeneous architectures.
class TeacherA(nn.Module):
    def __init__(self, width=32, n_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, width, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, width*2, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(width*2)
        self.fc    = nn.Linear(7*7*width*2, n_classes)
    def forward(self,x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))),2)
        x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))),2)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class TeacherB(nn.Module):
    def __init__(self, width=24, n_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, width, 5, padding=2)
        self.bn1   = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width*2, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(width*2)
        self.fc    = nn.Linear(7*7*width*2, n_classes)
    def forward(self,x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))),2)
        x = F.max_pool2d(x,2)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class StudentSmall(nn.Module):
    def __init__(self, width=16, n_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, width, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(width)
        self.fc    = nn.Linear(7*7*width, n_classes)
    def forward(self,x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))),2)
        x = F.max_pool2d(x,2)
        x = x.view(x.size(0), -1)
        return self.fc(x)


In [7]:

# Cell 4: Train one teacher
def train_teacher(model, train_subset, val_subset, epochs=3, lr=1e-3, bs=128):
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    train_loader = DataLoader(train_subset, batch_size=bs, shuffle=True)
    val_loader   = DataLoader(val_subset, batch_size=bs)
    for ep in range(epochs):
        model.train(); total=0; correct=0; loss_sum=0
        for x,y in train_loader:
            x,y = x.to(device), y.to(device)
            logits = model(x)
            loss   = F.cross_entropy(logits, y)
            opt.zero_grad(); loss.backward(); opt.step()
            loss_sum += loss.item()*x.size(0)
            pred = logits.argmax(1); correct += (pred==y).sum().item(); total += x.size(0)
        # quick val
        model.eval(); vtotal=0; vcorrect=0
        with torch.no_grad():
            for x,y in val_loader:
                x,y = x.to(device), y.to(device)
                out = model(x)
                vcorrect += (out.argmax(1)==y).sum().item(); vtotal += x.size(0)
        print(f"Epoch {ep+1}: train loss {loss_sum/total:.3f}, train acc {correct/total:.3f}, val acc {vcorrect/vtotal:.3f}")
    return model.eval()

# Train 5 teachers with heterogeneous architectures
teachers = []
for i,(tr,va) in enumerate(zip(orbit_train_subsets, orbit_val_subsets), start=1):
    print(f"\nTraining Teacher for Orbit {i}, labels {orbits_labels[i-1]}")
    model = TeacherA(width=32) if i%2==1 else TeacherB(width=24)
    t = train_teacher(model, tr, va, epochs=3, lr=1e-3, bs=128)
    teachers.append(t)



Training Teacher for Orbit 1, labels [0, 1, 2]
Epoch 1: train loss 0.073, train acc 0.978, val acc 0.993
Epoch 2: train loss 0.012, train acc 0.996, val acc 0.995
Epoch 3: train loss 0.007, train acc 0.998, val acc 0.995

Training Teacher for Orbit 2, labels [3, 4]
Epoch 1: train loss 0.073, train acc 0.974, val acc 1.000
Epoch 2: train loss 0.009, train acc 0.997, val acc 0.999
Epoch 3: train loss 0.005, train acc 0.999, val acc 1.000

Training Teacher for Orbit 3, labels [5, 6]
Epoch 1: train loss 0.078, train acc 0.976, val acc 0.994
Epoch 2: train loss 0.012, train acc 0.996, val acc 0.996
Epoch 3: train loss 0.006, train acc 0.999, val acc 0.997

Training Teacher for Orbit 4, labels [7, 8]
Epoch 1: train loss 0.110, train acc 0.955, val acc 0.993
Epoch 2: train loss 0.012, train acc 0.997, val acc 0.995
Epoch 3: train loss 0.009, train acc 0.997, val acc 0.997

Training Teacher for Orbit 5, labels [9, 0, 1]
Epoch 1: train loss 0.063, train acc 0.981, val acc 0.995
Epoch 2: train 

In [9]:

# Cell 5: Build the server "distillation dataset"
# We'll use MNIST TEST split as the proxy that the server can query teachers on.
kd_loader = DataLoader(mnist_test, batch_size=256, shuffle=True)

# A held-out subset of TEST for eval (we'll keep it simple: reuse test for eval)
eval_loader = DataLoader(mnist_test, batch_size=256)



In [12]:

# Cell 6: Phase 2 KD — student learns teacher ensemble consensus (Eq. 16–17)
@dataclass
class KDConfig:
    T: float = 4.0          # temperature
    lr: float = 1e-3
    epochs: int = 8
    clip_grad: float = 1.0

def kd_train(student:nn.Module, teachers:List[nn.Module], kd_loader, eval_loader, cfg:KDConfig):
    student = student.to(device)
    opt = torch.optim.Adam(student.parameters(), lr=cfg.lr)
    for ep in range(cfg.epochs):
        student.train(); loss_sum=0; total=0
        for x,y in kd_loader:
            x = x.to(device)   # NOTE: y not used in KD (we match teachers, not true labels)
            with torch.no_grad():
                # teacher ensemble logits (mean)
                t_logits = [m(x) for m in teachers]
                D_teacher = torch.stack(t_logits).mean(dim=0)      # [B,10]
            # student logits
            D_student = student(x)
            # temperature-softened distributions
            p = F.softmax(D_teacher / cfg.T, dim=1)                # teacher probs
            q_log = F.log_softmax(D_student / cfg.T, dim=1)        # student log-probs
            # KL(p || q) * T^2 (common KD scaling)
            loss = F.kl_div(q_log, p, reduction='batchmean') * (cfg.T*cfg.T)
            opt.zero_grad(); loss.backward()
            if cfg.clip_grad: nn.utils.clip_grad_norm_(student.parameters(), cfg.clip_grad)
            opt.step()
            loss_sum += loss.item()*x.size(0); total += x.size(0)
        # quick top-1 on eval
        student.eval()
        etotal=0
        ecorrect=0
        with torch.no_grad():
            for x,y in eval_loader:
                x,y = x.to(device), y.to(device)
                out = student(x)
                ecorrect += (out.argmax(1)==y).sum().item(); etotal += x.size(0)
        print(f"KD Epoch {ep+1}: train KL {loss_sum/total:.4f}, eval acc {ecorrect/etotal:.3f}")
    return student.eval()

student = StudentSmall(width=16)
kd_cfg  = KDConfig(T=4.0, lr=1e-3, epochs=100)
student = kd_train(student, teachers, kd_loader, eval_loader, kd_cfg)


KD Epoch 1: train KL 0.3885, eval acc 0.211
KD Epoch 2: train KL 0.0480, eval acc 0.211
KD Epoch 3: train KL 0.0316, eval acc 0.211
KD Epoch 4: train KL 0.0244, eval acc 0.211
KD Epoch 5: train KL 0.0219, eval acc 0.211
KD Epoch 6: train KL 0.0197, eval acc 0.211
KD Epoch 7: train KL 0.0176, eval acc 0.211
KD Epoch 8: train KL 0.0169, eval acc 0.211
KD Epoch 9: train KL 0.0156, eval acc 0.211
KD Epoch 10: train KL 0.0146, eval acc 0.211
KD Epoch 11: train KL 0.0146, eval acc 0.211
KD Epoch 12: train KL 0.0138, eval acc 0.211
KD Epoch 13: train KL 0.0149, eval acc 0.211
KD Epoch 14: train KL 0.0128, eval acc 0.211
KD Epoch 15: train KL 0.0122, eval acc 0.211
KD Epoch 16: train KL 0.0125, eval acc 0.211
KD Epoch 17: train KL 0.0131, eval acc 0.211
KD Epoch 18: train KL 0.0117, eval acc 0.211
KD Epoch 19: train KL 0.0115, eval acc 0.211
KD Epoch 20: train KL 0.0116, eval acc 0.211
KD Epoch 21: train KL 0.0110, eval acc 0.211
KD Epoch 22: train KL 0.0112, eval acc 0.211
KD Epoch 23: train 

In [13]:

# Cell 7: Final evaluation
student.eval()
total=0
correct=0
with torch.no_grad():
    for x,y in eval_loader:
        x,y = x.to(device), y.to(device)
        out = student(x)
        correct += (out.argmax(1)==y).sum().item(); total += x.size(0)
print(f"Final Student Accuracy on MNIST test: {correct/total:.3f}")


Final Student Accuracy on MNIST test: 0.211


## 2) (Optional) Server‑local virtual retraining (Phase 3 idea)
This simulates LEOShot’s Phase 3 to further improve accuracy without any new satellite communication:
clone several virtual students, partition the proxy data per orbit labels, train locally, then average.

In [16]:

# Cell 8 (Optional): Partition MNIST test by orbit label sets
def subset_by_labels(dataset, labels, max_items=None):
    idxs = [i for i,(x,y) in enumerate(dataset) if y in labels]
    if max_items: idxs = idxs[:max_items]
    return Subset(dataset, idxs)

parts = [ subset_by_labels(mnist_test, lbls, max_items=400) for lbls in orbits_labels ]

def train_virtual(model_init, ds, epochs=2, lr=5e-4, bs=128):
    m = copy.deepcopy(model_init).to(device)
    opt = torch.optim.Adam(m.parameters(), lr=lr)
    loader = DataLoader(ds, batch_size=bs, shuffle=True)
    for _ in range(epochs):
        m.train()
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            loss = F.cross_entropy(m(x), y)   # standard supervised loss
            opt.zero_grad(); loss.backward(); opt.step()
    return m.eval()

def avg_state_dicts(dicts):
    avg={}
    for k in dicts[0].keys():
        avg[k] = sum(d[k] for d in dicts) / len(dicts)
    return avg

rounds=50
for r in range(rounds):
    v_states=[]
    for ds in parts:
        vm = train_virtual(student, ds, epochs=100, lr=5e-4, bs=128)
        v_states.append(vm.state_dict())
    student.load_state_dict(avg_state_dicts(v_states))

# Re-evaluate
student.eval(); total=0; correct=0
with torch.no_grad():
    for x,y in eval_loader:
        x,y = x.to(device), y.to(device)
        out = student(x)
        correct += (out.argmax(1)==y).sum().item(); total += x.size(0)
print(f"After Virtual Retraining: Student Accuracy: {correct/total:.3f}")


After Virtual Retraining: Student Accuracy: 0.780
