# In [None]:
# FSNN top-k (fixed gather) — Param-parity vs Compute-parity on MNIST
import time, random, math, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)

def set_seed(seed=42):
    """Seed every random number generator this script draws from.

    Seeds, in order: Python's ``random``, NumPy's global generator, the
    PyTorch default generator, and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


# Seed once at import time so the run starts from a fixed RNG state.
set_seed(42)

# ----- data -----
BATCH_SIZE = 128
D = 28 * 28          # feature size used for the flattened inputs
NUM_CLASSES = 10
transform = transforms.ToTensor()

# MNIST train/test splits, stored under ./data (downloaded if missing).
train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST("./data", train=False, download=True, transform=transform)

# The two loaders differ only in shuffling: the training stream is shuffled,
# the test stream keeps dataset order.
_loader_opts = {"batch_size": BATCH_SIZE, "num_workers": 2, "pin_memory": True}
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

def count_params(m):
    """Return the total element count of ``m``'s parameters that require gradients."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

@torch.no_grad()
def latency_bs1(model, d=D, iters=120, warmup=30):
    """Time forward passes of ``model`` on one random input of shape ``[1, d]``.

    The model is switched to eval mode (and left there). ``warmup`` untimed
    passes run first, followed by ``iters`` timed passes. When the
    module-level ``device`` is a CUDA device, the GPU is synchronized before
    and after every timed pass so the wall-clock interval covers the whole
    forward computation.

    Returns:
        ``(median, 90th percentile, mean)`` of the per-pass durations, in
        seconds, as Python floats.
    """
    on_cuda = device.type == "cuda"
    model.eval()
    sample = torch.randn(1, d, device=device)

    # Untimed passes to get one-off setup costs out of the measurement.
    for _ in range(warmup):
        model(sample)
        if on_cuda:
            torch.cuda.synchronize()

    durations = []
    for _ in range(iters):
        if on_cuda:
            torch.cuda.synchronize()
        started = time.perf_counter()
        model(sample)
        if on_cuda:
            torch.cuda.synchronize()
        durations.append(time.perf_counter() - started)

    durations = np.array(durations)
    median = float(np.median(durations))
    p90 = float(np.percentile(durations, 90))
    mean = float(np.mean(durations))
    return median, p90, mean

def evaluate(model):
    """Score ``model`` on ``test_loader`` and return ``(mean loss, accuracy)``.

    The model is put in eval mode and run without gradient tracking. Each
    batch is flattened to ``[batch, -1]`` and moved to ``device``. The loss is
    summed cross-entropy divided by the number of labels seen; accuracy is the
    fraction of argmax predictions that equal the label.
    """
    model.eval()
    seen = 0
    hits = 0
    summed_loss = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            flat = images.view(images.size(0), -1).to(device)
            labels = labels.to(device)
            scores = model(flat)
            summed_loss += F.cross_entropy(scores, labels, reduction="sum").item()
            hits += (scores.argmax(1) == labels).sum().item()
            seen += labels.numel()
    return summed_loss / seen, hits / seen

def train(model, epochs=2, lr=1e-3, wd=0.0):
    """Fit ``model`` on ``train_loader`` with AdamW and cross-entropy loss.

    The model is moved to ``device`` and put in train mode once, before the
    first epoch. After every epoch the sample-weighted mean training loss is
    printed. Nothing is returned; the model is updated in place.

    Args:
        model: module mapping flattened inputs ``[batch, features]`` to logits.
        epochs: number of full passes over ``train_loader``.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    model.train()
    for ep in range(1, epochs + 1):
        run = 0.0   # sum of per-sample losses seen this epoch
        n = 0       # number of samples seen this epoch
        for images, labels in train_loader:
            batch = images.view(images.size(0), -1).to(device)
            labels = labels.to(device)

            optimizer.zero_grad(set_to_none=True)
            loss = F.cross_entropy(model(batch), labels)
            loss.backward()
            optimizer.step()

            # cross_entropy returns the batch mean; weight it by the batch
            # size so a short final batch does not skew the epoch average.
            batch_size = batch.size(0)
            run += loss.item() * batch_size
            n += batch_size
        print(f"epoch {ep}: train_loss={run/n:.4f}")

# ----- Baseline -----
class BaselineMLP(nn.Module):
    """Dense reference network: ``d -> hidden -> d -> num_classes``.

    GELU follows each of the first two linear layers; the head output is
    returned as raw logits. All three linear layers use Xavier-uniform
    weights and zero biases.
    """

    def __init__(self, d, hidden, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden)
        self.fc2 = nn.Linear(hidden, d)
        self.head = nn.Linear(d, num_classes)
        self.act = nn.GELU()
        for layer in (self.fc1, self.fc2, self.head):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, d]`` to logits ``[batch, num_classes]``."""
        widened = self.act(self.fc1(x))
        restored = self.act(self.fc2(widened))
        return self.head(restored)

# ----- FSNN-MoM (two CPU/GPU-friendly variants) -----
class FSNNMoMTopK(nn.Module):
    """
    Mixture of ``M`` two-layer experts that evaluates only the top-k routed
    experts for each sample.

    A linear router scores all experts for every input row. The ``k`` highest
    scoring experts are selected, their weights and biases are looked up per
    sample, and their outputs are blended with a softmax over the ``k``
    selected router logits before the classification head.

    Note that the lookup materializes per-sample copies of the selected
    expert matrices (``[B, k, d, h]`` and ``[B, k, h, d]``).

    Args:
        d: input feature size; each expert maps ``d -> h -> d``.
        num_classes: number of logits produced by the head.
        M: number of experts.
        k: experts evaluated per sample, ``1 <= k <= M``.
        h: hidden width of each expert.

    Parameter shapes: ``W1 [M, d, h]``, ``b1 [M, h]``, ``W2 [M, h, d]``,
    ``b2 [M, d]``.
    """
    def __init__(self, d, num_classes, M=6, k=2, h=64):
        super().__init__(); assert 1 <= k <= M
        self.M, self.k, self.d, self.h = M, k, d, h
        self.W1 = nn.Parameter(torch.empty(M, d, h));   self.b1 = nn.Parameter(torch.zeros(M, h))
        self.W2 = nn.Parameter(torch.empty(M, h, d));   self.b2 = nn.Parameter(torch.zeros(M, d))
        # Xavier-uniform per expert. Calling nn.init.xavier_uniform_ on the
        # stacked 3-D tensor would compute the fans as for a conv kernel
        # (e.g. fan_in = d*h for W1) and start every expert far too small.
        # Each expert matrix is d x h or h x d, so fan_in + fan_out == d + h.
        bound = math.sqrt(6.0 / (d + h))
        nn.init.uniform_(self.W1, -bound, bound)
        nn.init.uniform_(self.W2, -bound, bound)
        self.router = nn.Linear(d, M)
        self.head   = nn.Linear(d, num_classes)
        nn.init.xavier_uniform_(self.router.weight); nn.init.zeros_(self.router.bias)
        nn.init.xavier_uniform_(self.head.weight);   nn.init.zeros_(self.head.bias)
        self.act = nn.GELU()

    def _gather_expert(self, t, topi):
        """Select expert slices: ``t`` is ``[M, ...]``, ``topi`` is ``[B, k]``; returns ``[B, k, ...]``."""
        # Integer-tensor indexing on dim 0 gives out[b, j] = t[topi[b, j]],
        # the same values as expanding t to [B, M, ...] and gathering on
        # dim 1, but without a batch-sized view of the whole parameter (and
        # without the matching [B, M, ...] gradient buffer in backward).
        return t[topi]

    def forward(self, x):             # x: [B, d]
        logits = self.router(x)       # [B, M]
        topv, topi = torch.topk(logits, k=self.k, dim=-1)   # [B, k]
        # Mixing weights are normalized over the selected experts only.
        weights = torch.softmax(topv, dim=-1)               # [B, k]

        W1_sel = self._gather_expert(self.W1, topi)  # [B, k, d, h]
        b1_sel = self._gather_expert(self.b1, topi)  # [B, k, h]
        W2_sel = self._gather_expert(self.W2, topi)  # [B, k, h, d]
        b2_sel = self._gather_expert(self.b2, topi)  # [B, k, d]

        y1 = torch.einsum('bd,bkdh->bkh', x, W1_sel) + b1_sel
        y1 = self.act(y1)
        y  = torch.einsum('bkh,bkhd->bkd', y1, W2_sel) + b2_sel
        y  = (y * weights.unsqueeze(-1)).sum(dim=1)         # [B, d]
        return self.head(y)

class FSNNMoMAll(nn.Module):
    """
    Compute-all variant: evaluate all ``M`` experts in one batched einsum,
    then keep only the top-k experts through a weighted mask.

    Experts outside the top-k get a mask weight of zero; the selected ones get
    the softmax over their ``k`` router logits. This can be faster on CPU for
    small ``M`` because it avoids per-sample parameter lookups.

    Args:
        d: input feature size; each expert maps ``d -> h -> d``.
        num_classes: number of logits produced by the head.
        M: number of experts.
        k: experts kept per sample, ``1 <= k <= M``.
        h: hidden width of each expert.

    Parameter shapes: ``W1 [M, d, h]``, ``b1 [M, h]``, ``W2 [M, h, d]``,
    ``b2 [M, d]``.
    """
    def __init__(self, d, num_classes, M=6, k=2, h=64):
        super().__init__(); assert 1 <= k <= M
        self.M, self.k, self.d, self.h = M, k, d, h
        self.W1 = nn.Parameter(torch.empty(M, d, h));   self.b1 = nn.Parameter(torch.zeros(M, h))
        self.W2 = nn.Parameter(torch.empty(M, h, d));   self.b2 = nn.Parameter(torch.zeros(M, d))
        # Xavier-uniform per expert. Calling nn.init.xavier_uniform_ on the
        # stacked 3-D tensor would compute the fans as for a conv kernel
        # (e.g. fan_in = d*h for W1) and start every expert far too small.
        # Each expert matrix is d x h or h x d, so fan_in + fan_out == d + h.
        bound = math.sqrt(6.0 / (d + h))
        nn.init.uniform_(self.W1, -bound, bound)
        nn.init.uniform_(self.W2, -bound, bound)
        self.router = nn.Linear(d, M); self.head = nn.Linear(d, num_classes)
        nn.init.xavier_uniform_(self.router.weight); nn.init.zeros_(self.router.bias)
        nn.init.xavier_uniform_(self.head.weight);   nn.init.zeros_(self.head.bias)
        self.act = nn.GELU()

    def forward(self, x):             # x: [B, d]
        B = x.size(0)
        logits = self.router(x)       # [B, M]
        topv, topi = torch.topk(logits, k=self.k, dim=-1)
        # Mixing weights are normalized over the selected experts only.
        weights = torch.softmax(topv, dim=-1)               # [B, k]
        # Compute all experts: [B, M, d]
        y1 = torch.einsum('bd,mdh->bmh', x, self.W1) + self.b1
        y1 = self.act(y1)
        y_all = torch.einsum('bmh,mhd->bmd', y1, self.W2) + self.b2
        # Build mask with top-k weights. It is allocated with the dtype and
        # device of `weights` because scatter_ requires source and destination
        # dtypes to match; a default float32 mask breaks float64/half models.
        mask = weights.new_zeros(B, self.M)
        mask.scatter_(1, topi, weights)
        y = (y_all * mask.unsqueeze(-1)).sum(dim=1)         # [B, d]
        return self.head(y)

# ----- comparison harness -----
def run_case(title, baseline_hidden, mom_cfg, epochs=2, variant="topk"):
    """Train, evaluate and time a dense baseline and an FSNN model.

    Prints a header and one summary line per model (trainable parameter
    count, test accuracy, median and 90th-percentile batch-size-1 latency in
    milliseconds). Returns nothing.

    Args:
        title: label printed in the header.
        baseline_hidden: hidden width of the ``BaselineMLP``.
        mom_cfg: keyword arguments forwarded to the FSNN constructor; the
            keys ``'k'`` and ``'h'`` are also read for the summary line.
        epochs: training epochs for both models.
        variant: ``"topk"`` selects ``FSNNMoMTopK``; any other value selects
            ``FSNNMoMAll``.
    """
    print(f"\n=== {title} | FSNN variant: {variant} ===")

    # Dense baseline
    base = BaselineMLP(D, baseline_hidden, NUM_CLASSES).to(device)
    train(base, epochs=epochs, lr=1e-3)
    _, ab = evaluate(base)
    p50b, p90b, _ = latency_bs1(base)
    pb = count_params(base)
    print(f"Baseline: params={pb:,}  acc={ab:.4f}  p50={p50b*1e3:.2f} ms  p90={p90b*1e3:.2f} ms")

    # FSNN mixture
    if variant == "topk":
        fsnn_cls = FSNNMoMTopK
    else:
        fsnn_cls = FSNNMoMAll
    fsnn = fsnn_cls(D, NUM_CLASSES, **mom_cfg).to(device)
    train(fsnn, epochs=epochs, lr=1e-3)
    _, af = evaluate(fsnn)
    p50f, p90f, _ = latency_bs1(fsnn)
    pf = count_params(fsnn)
    k = mom_cfg['k']
    h = mom_cfg['h']
    print(f"FSNN  : params={pf:,}  acc={af:.4f}  p50={p50f*1e3:.2f} ms  p90={p90f*1e3:.2f} ms  (active width ~ {k*h})")

# 1) PARAM-PARITY: the baseline hidden width equals M*h (6*64 = 384), while the
#    top-k FSNN evaluates only k of its M experts per sample.
run_case(
    "Param-Parity (sparse compute)",
    baseline_hidden=384,
    mom_cfg={"M": 6, "k": 2, "h": 64},
    epochs=2,
    variant="topk",
)

# Same configuration with the compute-all masking variant (worth trying on CPU).
run_case(
    "Param-Parity (sparse compute)",
    baseline_hidden=384,
    mom_cfg={"M": 6, "k": 2, "h": 64},
    epochs=2,
    variant="all",
)

# 2) COMPUTE-PARITY: the FSNN's active width (k*h = 128) equals the baseline
#    hidden width of 128.
run_case(
    "Compute-Parity (richer capacity)",
    baseline_hidden=128,
    mom_cfg={"M": 6, "k": 2, "h": 64},
    epochs=2,
    variant="topk",
)
