## Cell 0: Setup


In [2]:

import random, copy, time
from dataclasses import dataclass
from typing import List
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms, models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
torch.manual_seed(42); np.random.seed(42); random.seed(42)

# Speed knob: set FAST=True for quick runs (fewer samples/epochs)
FAST = False



Device: cuda


# Cell 1: MNIST loaders (resize to 224, 3-channels for ResNet)
# Paper used MNIST and CIFAR‑10; we adapt MNIST to ResNet input shape.
# Reference dataset: Deng, 2012. [1](https://o365khu-my.sharepoint.com/personal/2025315503_office_khu_ac_kr/Documents/Microsoft%20Copilot%20Chat%20Files/One-Shot%20Federated%20Learning%20for%20LEO%20Constellations.pdf)


In [3]:


tf_train = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
])
tf_test = tf_train  # same normalization for test/distillation

mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=tf_train)
mnist_test  = datasets.MNIST(root="./data", train=False, download=True, transform=tf_test)

print("Train size:", len(mnist_train), "Test size:", len(mnist_test))


Train size: 60000 Test size: 10000



# Cell 2: Non‑IID orbit splits (5 orbits; 2 with 4 classes, 3 with 6 classes)
# Mirrors the paper's non‑IID setting (4 vs 6 classes per orbit). [1](https://o365khu-my.sharepoint.com/personal/2025315503_office_khu_ac_kr/Documents/Microsoft%20Copilot%20Chat%20Files/One-Shot%20Federated%20Learning%20for%20LEO%20Constellations.pdf


In [4]:

# Cell 2: Non‑IID orbit splits (5 orbits; 2 with 4 classes, 3 with 6 classes)
# Mirrors the paper's non‑IID setting (4 vs 6 classes per orbit). [1](https://o365khu-my.sharepoint.com/personal/2025315503_office_khu_ac_kr/Documents/Microsoft%20Copilot%20Chat%20Files/One-Shot%20Federated%20Learning%20for%20LEO%20Constellations.pdf)

# Fixed deterministic split
orbit_labels = [
    [0,1,2,3],                # Orbit 1 (4 classes)
    [4,5,6,7],                # Orbit 2 (4 classes)
    [0,1,2,3,4,5],            # Orbit 3 (6 classes)
    [6,7,8,9,0,1],            # Orbit 4 (6 classes)
    [2,3,4,5,8,9]             # Orbit 5 (6 classes)
]

def idxs_for_labels(dataset, labels):
    return [i for i,(x,y) in enumerate(dataset) if y in labels]

train_orbit_subsets, val_orbit_subsets = [], []
for lbls in orbit_labels:
    idxs = idxs_for_labels(mnist_train, lbls)
    random.shuffle(idxs)
    split = int(0.85 * len(idxs))  # 85/15 train/val
    tr_idx, va_idx = idxs[:split], idxs[split:]
    if FAST:
        tr_idx = tr_idx[:800]  # shrink per orbit for speed
        va_idx = va_idx[:200]
    train_orbit_subsets.append(Subset(mnist_train, tr_idx))
    val_orbit_subsets.append(Subset(mnist_train, va_idx))

for i,lbls in enumerate(orbit_labels, 1):
    print(f"Orbit {i}: labels {lbls}, train {len(train_orbit_subsets[i-1])}, val {len(val_orbit_subsets[i-1])}")


Orbit 1: labels [0, 1, 2, 3], train 21040, val 3714
Orbit 2: labels [4, 5, 6, 7], train 19929, val 3517
Orbit 3: labels [0, 1, 2, 3, 4, 5], train 30614, val 5403
Orbit 4: labels [6, 7, 8, 9, 0, 1], train 31150, val 5498
Orbit 5: labels [2, 3, 4, 5, 8, 9], train 29879, val 5273



# Cell 3: Build ResNet-50 teachers (first conv already 3-channels; we train from scratch)
# Paper trained ResNet‑50 locally on clients and used ResNet‑18 as the student at server. [1](https://o365khu-my.sharepoint.com/personal/2025315503_office_khu_ac_kr/Documents/Microsoft%20Copilot%20Chat%20Files/One-Shot%20Federated%20Learning%20for%20LEO%20Constellations.pdf)

In [16]:


def make_resnet50(n_classes=10):
    m = models.resnet50(weights=None)  # no pretrain (offline-friendly)
    m.fc = nn.Linear(m.fc.in_features, n_classes)
    return m

def make_resnet18(n_classes=10):
    m = models.resnet18(weights=None)
    m.fc = nn.Linear(m.fc.in_features, n_classes)
    return m



# Cell 4: Train one teacher (ResNet‑50) on an orbit split


In [9]:

def train_teacher(model, train_subset, val_subset, epochs=3, lr=1e-3, bs=64):
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    train_loader = DataLoader(train_subset, batch_size=bs, shuffle=True, num_workers=2, pin_memory=True)
    val_loader   = DataLoader(val_subset, batch_size=bs, num_workers=2, pin_memory=True)
    for ep in range(epochs):
        model.train(); total=0; correct=0; loss_sum=0.0
        for x,y in train_loader:
            x,y = x.to(device), y.to(device)
            logits = model(x)
            loss   = F.cross_entropy(logits, y)
            opt.zero_grad(); loss.backward(); opt.step()
            loss_sum += loss.item() * x.size(0)
            correct  += (logits.argmax(1)==y).sum().item()
            total    += x.size(0)
        # val
        model.eval(); vtot=0; vcor=0
        with torch.no_grad():
            for x,y in val_loader:
                x,y = x.to(device), y.to(device)
                out = model(x)
                vcor += (out.argmax(1)==y).sum().item()
                vtot += x.size(0)
        print(f"Teacher ep {ep+1}: train loss {loss_sum/total:.3f}, train acc {correct/total:.3f}, val acc {vcor/vtot:.3f}")
    return model.eval()

# Train 5 teachers
teachers = []
for i,(tr,va) in enumerate(zip(train_orbit_subsets, val_orbit_subsets), 1):
    print(f"\nTraining ResNet-50 Teacher for Orbit {i} (labels {orbit_labels[i-1]})")
    t = make_resnet50()
    t = train_teacher(t, tr, va, epochs=(2 if FAST else 6), lr=1e-3, bs=(32 if FAST else 64))
    teachers.append(t)



Training ResNet-50 Teacher for Orbit 1 (labels [0, 1, 2, 3])
Teacher ep 1: train loss 0.712, train acc 0.751, val acc 0.310
Teacher ep 2: train loss 0.167, train acc 0.948, val acc 0.920

Training ResNet-50 Teacher for Orbit 2 (labels [4, 5, 6, 7])
Teacher ep 1: train loss 0.668, train acc 0.769, val acc 0.495
Teacher ep 2: train loss 0.138, train acc 0.948, val acc 0.920

Training ResNet-50 Teacher for Orbit 3 (labels [0, 1, 2, 3, 4, 5])
Teacher ep 1: train loss 1.046, train acc 0.637, val acc 0.140
Teacher ep 2: train loss 0.265, train acc 0.919, val acc 0.775

Training ResNet-50 Teacher for Orbit 4 (labels [6, 7, 8, 9, 0, 1])
Teacher ep 1: train loss 0.894, train acc 0.664, val acc 0.590
Teacher ep 2: train loss 0.284, train acc 0.909, val acc 0.890

Training ResNet-50 Teacher for Orbit 5 (labels [2, 3, 4, 5, 8, 9])
Teacher ep 1: train loss 1.166, train acc 0.590, val acc 0.445
Teacher ep 2: train loss 0.316, train acc 0.891, val acc 0.815


# Cell 5: Server distillation dataset (proxy)


In [10]:

# In the paper, Phase 1 generates synthetic data; here we use MNIST test as the server's proxy for KD,
# which is a standard practice to demonstrate Phase 2 mechanics (data-free generator omitted for brevity). [1](https://o365khu-my.sharepoint.com/personal/2025315503_office_khu_ac_kr/Documents/Microsoft%20Copilot%20Chat%20Files/One-Shot%20Federated%20Learning%20for%20LEO%20Constellations.pdf)

kd_loader   = DataLoader(mnist_test, batch_size=(64 if FAST else 128), shuffle=True, num_workers=2, pin_memory=True)
eval_loader = DataLoader(mnist_test, batch_size=(64 if FAST else 128), num_workers=2, pin_memory=True)

# Quick evaluation utilities
def eval_acc(model, loader):
    model.eval(); tot=0; cor=0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            out = model(x)
            cor += (out.argmax(1)==y).sum().item(); tot += x.size(0)
    return cor/tot

def eval_ensemble_acc(teachers, loader, weights=None):
    for m in teachers: m.eval()
    tot=0; cor=0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            logits = [m(x) for m in teachers]
            if weights is None:
                ens = torch.stack(logits).mean(dim=0)
            else:
                w = torch.tensor(weights, device=device).view(-1,1,1)
                ens = (torch.stack(logits)*w).sum(dim=0)
            cor += (ens.argmax(1)==y).sum().item(); tot += x.size(0)
    return cor/tot

# Teacher and ensemble baseline accuracy (sanity check)
teacher_accs = [eval_acc(t, eval_loader) for t in teachers]
ens_acc = eval_ensemble_acc(teachers, eval_loader)
print("\nTeacher accuracies:", teacher_accs)
print("Unweighted ensemble accuracy:", ens_acc)



Teacher accuracies: [0.3893, 0.3578, 0.4571, 0.5294, 0.5036]
Unweighted ensemble accuracy: 0.5672


# Cell 6: Phase 2 KD — train ResNet‑18 student to match teacher ensemble (KL loss with temperature)


In [11]:

# Paper's KD objective: R_KL^S = KL(D_teacher, D_student); student updates by SGD/Adam. [1](https://o365khu-my.sharepoint.com/personal/2025315503_office_khu_ac_kr/Documents/Microsoft%20Copilot%20Chat%20Files/One-Shot%20Federated%20Learning%20for%20LEO%20Constellations.pdf)

@dataclass
class KDConfig:
    T: float = 3.0           # temperature
    lr: float = 5e-4
    epochs: int = (6 if FAST else 12)
    clip_grad: float = 1.0
    conf_th: float = 0.70     # keep confident teacher consensus
    use_weights: bool = True  # weight teachers by accuracy

# Teacher weights from their eval accuracy
weights = np.array(teacher_accs)
weights = weights / (weights.sum() + 1e-8)

student = make_resnet18().to(device)
opt = torch.optim.SGD(student.parameters(), lr=KDConfig.lr, momentum=0.9, weight_decay=1e-4)

for ep in range(KDConfig.epochs):
    student.train(); loss_sum=0; kept=0
    for x,_ in kd_loader:
        x = x.to(device)
        with torch.no_grad():
            logits_list = [m(x) for m in teachers]
            if KDConfig.use_weights:
                w = torch.tensor(weights, device=device).view(-1,1,1)
                D_teacher = (torch.stack(logits_list) * w).sum(dim=0)
            else:
                D_teacher = torch.stack(logits_list).mean(dim=0)
        D_student = student(x)

        # Temperature-softened distributions
        p = F.softmax(D_teacher / KDConfig.T, dim=1)
        q_log = F.log_softmax(D_student / KDConfig.T, dim=1)

        # Confidence filter: only learn from strong consensus
        maxp, _ = p.max(dim=1)
        mask = (maxp >= KDConfig.conf_th)
        if mask.sum() == 0:
            continue
        kept += mask.sum().item()
        p_sel  = p[mask]
        q_log_sel = q_log[mask]

        loss = F.kl_div(q_log_sel, p_sel, reduction='batchmean') * (KDConfig.T**2)
        opt.zero_grad(); loss.backward()
        if KDConfig.clip_grad: nn.utils.clip_grad_norm_(student.parameters(), KDConfig.clip_grad)
        opt.step()
        loss_sum += loss.item()

    # quick eval
    student.eval(); tot=0; cor=0
    with torch.no_grad():
        for x,y in eval_loader:
            x,y = x.to(device), y.to(device)
            out = student(x)
            cor += (out.argmax(1)==y).sum().item(); tot += x.size(0)
    print(f"KD Epoch {ep+1}: train KL {loss_sum/max(1,kept):.4f}, eval acc {cor/tot:.3f}, kept {kept}")


KD Epoch 1: train KL 1.9537, eval acc 0.114, kept 531
KD Epoch 2: train KL 0.5747, eval acc 0.114, kept 531
KD Epoch 3: train KL 0.3968, eval acc 0.114, kept 531
KD Epoch 4: train KL 0.3871, eval acc 0.114, kept 531
KD Epoch 5: train KL 0.2814, eval acc 0.114, kept 531
KD Epoch 6: train KL 0.2656, eval acc 0.126, kept 531


# Cell 7: Final evaluation + (optional) Phase 3 virtual retraining


In [12]:

print("Final student accuracy:", eval_acc(student, eval_loader))

# OPTIONAL: Server-local virtual retraining (Phase 3 idea from the paper). [1](https://o365khu-my.sharepoint.com/personal/2025315503_office_khu_ac_kr/Documents/Microsoft%20Copilot%20Chat%20Files/One-Shot%20Federated%20Learning%20for%20LEO%20Constellations.pdf)
# Clone several virtual students, train each on an "orbit-style" partition of MNIST test, then average.

def subset_by_labels(dataset, labels, limit=None):
    idxs = [i for i,(x,y) in enumerate(dataset) if y in labels]
    if FAST and limit: idxs = idxs[:limit]
    return Subset(dataset, idxs)

parts = [
    subset_by_labels(mnist_test, orbit_labels[2], limit=1000),  # mimic a 6-class partition
    subset_by_labels(mnist_test, orbit_labels[3], limit=1000),
    subset_by_labels(mnist_test, orbit_labels[4], limit=1000),
]

def train_virtual(init_model, ds, epochs=(1 if FAST else 3), lr=5e-4, bs=64):
    m = copy.deepcopy(init_model).to(device)
    opt = torch.optim.Adam(m.parameters(), lr=lr)
    loader = DataLoader(ds, batch_size=bs, shuffle=True, num_workers=2, pin_memory=True)
    for _ in range(epochs):
        m.train()
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            loss = F.cross_entropy(m(x), y)
            opt.zero_grad(); loss.backward(); opt.step()
    return m.eval()

def avg_state_dicts(dicts):
    avg = {}
    for k in dicts[0].keys():
        avg[k] = sum(d[k] for d in dicts) / len(dicts)
    return avg

# One virtual round (optional)
virtual_states = []
for ds in parts:
    vm = train_virtual(student, ds, epochs=(1 if FAST else 2), lr=5e-4, bs=64)
    virtual_states.append(vm.state_dict())
student.load_state_dict(avg_state_dicts(virtual_states))

print("After virtual retraining, student accuracy:", eval_acc(student, eval_loader))


Final student accuracy: 0.126
After virtual retraining, student accuracy: 0.2424
