In [1]:
# CIFAR-10: Baseline ResNet-20 vs FSNN-MoM block in stage-3 (channels=64)
# FSNN computes only k of M expert residual branches per block (global top-k per batch).
import time, random, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

def set_seed(s=42):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(42)

# ---------- data ----------
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
test_ds  = datasets.CIFAR10("./data", train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

# ---------- utils ----------
def count_params(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)

@torch.no_grad()
def bs1_latency(model, shape=(1,3,32,32), iters=200, warmup=50):
    model.eval()
    x = torch.randn(*shape, device=device)
    # warmup
    for _ in range(warmup):
        _ = model(x)
        if device.type == "cuda": torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        if device.type == "cuda": torch.cuda.synchronize()
        t0 = time.perf_counter()
        _ = model(x)
        if device.type == "cuda": torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    a = np.array(times)
    return float(np.median(a)), float(np.percentile(a,90)), float(np.mean(a))

@torch.no_grad()
def evaluate(model):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss_sum += F.cross_entropy(logits, yb, reduction="sum").item()
        pred = logits.argmax(1)
        correct += (pred==yb).sum().item()
        total += yb.numel()
    return loss_sum/total, correct/total

def train(model, epochs=1, lr=0.1, wd=5e-4):
    model.to(device)
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd, nesterov=True)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs*len(train_loader))
    model.train()
    for ep in range(epochs):
        run, n = 0.0, 0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            loss.backward()
            opt.step(); sched.step()
            run += loss.item()*xb.size(0); n += xb.size(0)
        print(f"epoch {ep+1}: train_loss={run/n:.4f}")

# ---------- model pieces ----------
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1   = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2   = nn.BatchNorm2d(planes)
        self.downsample = None
        if stride != 1 or in_planes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes))
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        identity = x
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.act(out + identity)
        return out

class Router(nn.Module):
    """Global-average-pooled router -> logits over M experts (one routing for whole batch)."""
    def __init__(self, channels, M):
        super().__init__()
        self.fc = nn.Linear(channels, M)
        nn.init.xavier_uniform_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self, x):
        # x: [B,C,H,W] -> GAP over H,W -> [B,C] -> batch-mean -> [C] -> logits [M]
        b, c, h, w = x.shape
        z = x.mean(dim=[2,3])              # [B, C]
        z = z.mean(dim=0, keepdim=True)    # [1, C] global routing for the batch (fast path)
        return self.fc(z)                  # [1, M]

class ExpertBlock(nn.Module):
    """One expert residual branch (like BasicBlock) with given in/out and stride."""
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1   = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2   = nn.BatchNorm2d(planes)
        self.act   = nn.ReLU(inplace=True)
        self.downsample = None
        if stride != 1 or in_planes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes))
    def forward(self, x):
        identity = x
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        return out + identity              # no final ReLU: FSNN block will add and relu

class FSNNMoMBlock(nn.Module):
    """
    FSNN module-of-modules residual block:
      - M experts; router picks global top-k; we compute only those experts
      - weighted sum of selected experts' outputs, then ReLU
    """
    def __init__(self, in_planes, planes, stride=1, M=6, k=2):
        super().__init__()
        assert 1 <= k <= M
        self.M, self.k = M, k
        self.router = Router(in_planes, M)
        self.experts = nn.ModuleList([ExpertBlock(in_planes, planes, stride) for _ in range(M)])
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        logits = self.router(x)            # [1, M]
        topv, topi = torch.topk(logits, k=self.k, dim=-1)   # [1, k]
        weights = torch.softmax(topv, dim=-1).view(-1)      # [k]
        y = 0
        for j in range(self.k):            # k is small (1–2), loop OK
            m = int(topi[0, j])
            y = y + weights[j] * self.experts[m](x)
        return self.act(y)

class ResNetCIFAR(nn.Module):
    def __init__(self, block, layers, num_classes=10, fsnn_last_stage=False, M=6, k=2):
        super().__init__()
        self.in_planes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1   = nn.BatchNorm2d(16)
        self.act   = nn.ReLU(inplace=True)
        # stages
        self.layer1 = self._make_layer(block, 16,  layers[0], stride=1)
        self.layer2 = self._make_layer(block, 32,  layers[1], stride=2)
        # stage-3: optionally FSNN blocks
        if fsnn_last_stage:
            blocks = []
            for i in range(layers[2]):
                stride = 2 if i == 0 else 1
                blocks.append(FSNNMoMBlock(self.in_planes, 64, stride=stride, M=M, k=k))
                self.in_planes = 64
            self.layer3 = nn.Sequential(*blocks)
        else:
            self.layer3 = self._make_layer(block, 64,  layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc      = nn.Linear(64, num_classes)
        nn.init.xavier_uniform_(self.fc.weight); nn.init.zeros_(self.fc.bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.act(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x)
        x = self.avgpool(x).flatten(1)
        return self.fc(x)

def resnet20_cifar10():
    return ResNetCIFAR(BasicBlock, [3,3,3], num_classes=10, fsnn_last_stage=False)

def fsnn_resnet20_cifar10(M=6, k=2):
    return ResNetCIFAR(BasicBlock, [3,3,3], num_classes=10, fsnn_last_stage=True, M=M, k=k)

# ---------- run ----------
print("\n=== Baseline ResNet-20 ===")
baseline = resnet20_cifar10().to(device)
train(baseline, epochs=1, lr=0.1)
loss_b, acc_b = evaluate(baseline)
p50_b, p90_b, _ = bs1_latency(baseline)
pb = count_params(baseline)
print(f"Baseline: params={pb:,}  acc={acc_b:.4f}  p50={p50_b*1e3:.2f} ms  p90={p90_b*1e3:.2f} ms")

print("\n=== FSNN-MoM ResNet-20 (stage-3 only, M=6, k=2) ===")
fsnn = fsnn_resnet20_cifar10(M=6, k=2).to(device)
train(fsnn, epochs=1, lr=0.1)
loss_f, acc_f = evaluate(fsnn)
p50_f, p90_f, _ = bs1_latency(fsnn)
pf = count_params(fsnn)
print(f"FSNN   : params={pf:,}  acc={acc_f:.4f}  p50={p50_f*1e3:.2f} ms  p90={p90_f*1e3:.2f} ms")
print("\n=== Summary ===")
print(f"Baseline -> params={pb:,}, acc={acc_b:.4f}, p50={p50_b*1e3:.2f} ms")
print(f"FSNN     -> params={pf:,}, acc={acc_f:.4f}, p50={p50_f*1e3:.2f} ms  (M={6}, k={2})")


device: cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100.0%


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified

=== Baseline ResNet-20 ===
epoch 1: train_loss=1.6870
Baseline: params=272,474  acc=0.4815  p50=3.24 ms  p90=4.66 ms

=== FSNN-MoM ResNet-20 (stage-3 only, M=6, k=2) ===
epoch 1: train_loss=1.5476
FSNN   : params=1,301,932  acc=0.5495  p50=9.02 ms  p90=13.68 ms

=== Summary ===
Baseline -> params=272,474, acc=0.4815, p50=3.24 ms
FSNN     -> params=1,301,932, acc=0.5495, p50=9.02 ms  (M=6, k=2)


In [5]:
# CIFAR-10: Baseline Bottleneck ResNet-20 vs FSNN Bottleneck (stage-3 only, shared 1x1, top-k=1)
import time, random, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

def set_seed(s=42):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(42)

# ------------------ Data ------------------
tf_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
tf_test  = transforms.Compose([transforms.ToTensor()])

train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=tf_train)
test_ds  = datasets.CIFAR10("./data", train=False, download=True, transform=tf_test)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

# ------------------ Utils ------------------
def count_params(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)

@torch.no_grad()
def bs1_latency(model, shape=(1,3,32,32), iters=200, warmup=50):
    model.eval()
    x = torch.randn(*shape, device=device)
    for _ in range(warmup):
        _ = model(x)
        if device.type == "cuda": torch.cuda.synchronize()
    ts=[]
    for _ in range(iters):
        if device.type == "cuda": torch.cuda.synchronize()
        t0 = time.perf_counter(); _ = model(x)
        if device.type == "cuda": torch.cuda.synchronize()
        ts.append(time.perf_counter()-t0)
    a = np.array(ts); return float(np.median(a)), float(np.percentile(a,90)), float(np.mean(a))

@torch.no_grad()
def evaluate(model):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss_sum += F.cross_entropy(logits, yb, reduction="sum").item()
        pred = logits.argmax(1)
        correct += (pred==yb).sum().item(); total += yb.numel()
    return loss_sum/total, correct/total

def train(model, epochs=1, lr=0.1, wd=5e-4):
    model.to(device)
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd, nesterov=True)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs*len(train_loader))
    model.train()
    for ep in range(epochs):
        run, n = 0.0, 0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            loss.backward(); opt.step(); sched.step()
            run += loss.item()*xb.size(0); n += xb.size(0)
        print(f"epoch {ep+1}: train_loss={run/n:.4f}")

# ------------------ Conv helpers ------------------
def conv1x1(in_ch, out_ch, stride=1):
    return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False)

def conv3x3(in_ch, out_ch, stride=1):
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)

# ------------------ Blocks ------------------
class BottleneckBlock(nn.Module):
    """Baseline bottleneck: 1x1 reduce -> 3x3 mid -> 1x1 expand (+ optional downsample)."""
    def __init__(self, in_ch, out_ch, stride=1, r=4):
        super().__init__()
        mid = out_ch // r
        self.reduce = conv1x1(in_ch, mid, stride=stride)
        self.bn1    = nn.BatchNorm2d(mid)
        self.mid    = conv3x3(mid, mid, stride=1)
        self.bn2    = nn.BatchNorm2d(mid)
        self.expand = conv1x1(mid, out_ch, stride=1)
        self.bn3    = nn.BatchNorm2d(out_ch)
        self.act    = nn.ReLU(inplace=True)
        self.down   = None
        if stride != 1 or in_ch != out_ch:
            self.down = nn.Sequential(conv1x1(in_ch, out_ch, stride=stride),
                                      nn.BatchNorm2d(out_ch))
    def forward(self, x):
        identity = x
        y = self.act(self.bn1(self.reduce(x)))
        y = self.act(self.bn2(self.mid(y)))
        y = self.bn3(self.expand(y))
        if self.down is not None: identity = self.down(x)
        return self.act(y + identity)

class Router(nn.Module):
    """Global per-batch router over reduced features (GAP -> linear)."""
    def __init__(self, ch, M):
        super().__init__()
        self.fc = nn.Linear(ch, M)
        nn.init.xavier_uniform_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self, y_reduced):  # y_reduced: [B, C, H, W]
        z = y_reduced.mean(dim=[0,2,3], keepdim=True)  # [1, C]
        return self.fc(z.view(1, -1))                  # [1, M]

class FSNNBottleneckMoM(nn.Module):
    """
    FSNN bottleneck with shared 1x1 reduce/expand.
    Only the 3x3 mid conv is modular: M experts; compute top-1 (k=1).
    """
    def __init__(self, in_ch, out_ch, stride=1, r=4, M=4, k=1):
        super().__init__(); assert 1 <= k <= M
        self.k, self.M = k, M
        mid = out_ch // r
        # Shared reduce/expand
        self.reduce = conv1x1(in_ch, mid, stride=stride)
        self.bn1    = nn.BatchNorm2d(mid)
        self.expand = conv1x1(mid, out_ch, stride=1)
        self.bn3    = nn.BatchNorm2d(out_ch)
        # Expert 3x3 mids
        self.mids = nn.ModuleList([conv3x3(mid, mid, stride=1) for _ in range(M)])
        self.mbn  = nn.ModuleList([nn.BatchNorm2d(mid) for _ in range(M)])
        # Router on reduced features
        self.router = Router(mid, M)
        self.act = nn.ReLU(inplace=True)
        self.down = None
        if stride != 1 or in_ch != out_ch:
            self.down = nn.Sequential(conv1x1(in_ch, out_ch, stride=stride),
                                      nn.BatchNorm2d(out_ch))
    def forward(self, x):
        identity = x
        y1 = self.act(self.bn1(self.reduce(x)))      # shared reduce
        logits = self.router(y1)                     # [1, M]
        topv, topi = torch.topk(logits, k=self.k, dim=-1)
        weights = torch.softmax(topv, dim=-1).view(-1)   # [k]
        y = 0
        for j in range(self.k):                      # k is tiny (1–2)
            m = int(topi[0, j])
            y_mid = self.act(self.mbn[m](self.mids[m](y1)))
            y = y + weights[j] * y_mid
        y = self.bn3(self.expand(y))                 # shared expand
        if self.down is not None: identity = self.down(x)
        return self.act(y + identity)

# ------------------ ResNet skeleton ------------------
class ResNetCIFAR(nn.Module):
    def __init__(self, block1, block3, layers=(3,3,3), r=4, fsnn_cfg=None, num_classes=10):
        super().__init__()
        self.in_ch = 16
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        # stage 1 (16)
        self.layer1 = self._make_layer(block1, 16, layers[0], stride=1, r=r)
        # stage 2 (32)
        self.layer2 = self._make_layer(block1, 32, layers[1], stride=2, r=r)
        # stage 3 (64): baseline or FSNN
        if fsnn_cfg is None:
            self.layer3 = self._make_layer(block3, 64, layers[2], stride=2, r=r)
        else:
            self.layer3 = self._make_layer_fsnn(64, layers[2], stride=2, r=r, **fsnn_cfg)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc   = nn.Linear(64, num_classes)
        nn.init.xavier_uniform_(self.fc.weight); nn.init.zeros_(self.fc.bias)

    def _make_layer(self, block, out_ch, n, stride, r):
        layers = []
        for i in range(n):
            s = stride if i == 0 else 1
            layers.append(block(self.in_ch, out_ch, stride=s, r=r))
            self.in_ch = out_ch
        return nn.Sequential(*layers)

    def _make_layer_fsnn(self, out_ch, n, stride, r, M=4, k=1):
        layers = []
        for i in range(n):
            s = stride if i == 0 else 1
            layers.append(FSNNBottleneckMoM(self.in_ch, out_ch, stride=s, r=r, M=M, k=k))
            self.in_ch = out_ch
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)

def resnet20_bottleneck():
    return ResNetCIFAR(BottleneckBlock, BottleneckBlock, layers=(3,3,3), r=4)

def fsnn_resnet20_bottleneck(M=4, k=1):
    return ResNetCIFAR(BottleneckBlock, BottleneckBlock, layers=(3,3,3), r=4,
                       fsnn_cfg=dict(M=M, k=k))

# ------------------ Run ------------------
EPOCHS = 1     # use 50–100 for real accuracy
M, K = 4, 1    # experts and active experts (compute top-1)

print("\n=== Baseline (Bottleneck, stage-3) ===")
baseline = resnet20_bottleneck().to(device)
train(baseline, epochs=EPOCHS, lr=0.1)
loss_b, acc_b = evaluate(baseline)
p50_b, p90_b, _ = bs1_latency(baseline)
pb = count_params(baseline)
print(f"Baseline: params={pb:,}  acc={acc_b:.4f}  p50={p50_b*1e3:.2f} ms  p90={p90_b*1e3:.2f} ms")

print(f"\n=== FSNN (Bottleneck MoM, stage-3; shared 1x1; M={M}, k={K}) ===")
fsnn = fsnn_resnet20_bottleneck(M=M, k=K).to(device)
train(fsnn, epochs=EPOCHS, lr=0.1)
loss_f, acc_f = evaluate(fsnn)
p50_f, p90_f, _ = bs1_latency(fsnn)
pf = count_params(fsnn)
print(f"FSNN   : params={pf:,}  acc={acc_f:.4f}  p50={p50_f*1e3:.2f} ms  p90={p90_f*1e3:.2f} ms")

print("\n=== Summary ===")
print(f"Baseline -> params={pb:,}, acc={acc_b:.4f}, p50={p50_b*1e3:.2f} ms")
print(f"FSNN     -> params={pf:,}, acc={acc_f:.4f}, p50={p50_f*1e3:.2f} ms  (M={M}, k={K})")


device: cuda
Files already downloaded and verified
Files already downloaded and verified

=== Baseline (Bottleneck, stage-3) ===
epoch 1: train_loss=1.6770
Baseline: params=21,370  acc=0.4704  p50=6.54 ms  p90=9.04 ms

=== FSNN (Bottleneck MoM, stage-3; shared 1x1; M=4, k=1) ===
epoch 1: train_loss=1.7184
FSNN   : params=42,598  acc=0.4578  p50=8.99 ms  p90=12.68 ms

=== Summary ===
Baseline -> params=21,370, acc=0.4704, p50=6.54 ms
FSNN     -> params=42,598, acc=0.4578, p50=8.99 ms  (M=4, k=1)


In [6]:
# Transformer FFN microbench: Baseline vs FSNN-MoM (top-k) with torch.compile + theory FLOP ratios
import time, math, numpy as np, torch, torch.nn as nn, torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# --- helpers ---
def bs1_latency(model, x, iters=400, warmup=80):
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x)
            if device.type=="cuda": torch.cuda.synchronize()
        ts=[]
        for _ in range(iters):
            if device.type=="cuda": torch.cuda.synchronize()
            t0=time.perf_counter()
            _=model(x)
            if device.type=="cuda": torch.cuda.synchronize()
            ts.append(time.perf_counter()-t0)
    a=np.array(ts); return float(np.median(a)), float(np.percentile(a,90)), float(np.mean(a))

def try_compile(m):
    try:
        return torch.compile(m)  # PyTorch 2+
    except Exception:
        return m

# --- models ---
class FFN(nn.Module):
    # Baseline Transformer FFN: d -> d_ff -> d
    def __init__(self, d=512, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d, d_ff)
        self.fc2 = nn.Linear(d_ff, d)
        self.act = nn.GELU()
    def forward(self, x):          # x: [B, T, d]
        y = self.act(self.fc1(x))
        return self.fc2(y)

class FFN_MoM_TopK(nn.Module):
    """
    FSNN Mixture-of-Modules FFN:
      - M experts, each d->h->d
      - router runs on pooled tokens (global per batch) -> pick top-k experts
      - compute ONLY those k experts, weighted sum
      - dormant capacity: M*h ~= baseline d_ff
      - active compute:   k*h  << baseline d_ff
    """
    def __init__(self, d=512, M=8, k=2, h=256):
        super().__init__(); assert 1<=k<=M
        self.M, self.k, self.d, self.h = M, k, d, h
        self.W1 = nn.Parameter(torch.empty(M, d, h));   self.b1 = nn.Parameter(torch.zeros(M, h))
        self.W2 = nn.Parameter(torch.empty(M, h, d));   self.b2 = nn.Parameter(torch.zeros(M, d))
        nn.init.xavier_uniform_(self.W1); nn.init.xavier_uniform_(self.W2)
        self.router = nn.Linear(d, M)
        nn.init.xavier_uniform_(self.router.weight); nn.init.zeros_(self.router.bias)
        self.act = nn.GELU()

    def forward(self, x):          # x: [B, T, d]
        B,T,d = x.shape
        g = x.mean(dim=[0,1], keepdim=True)     # [1,1,d] pooled
        logits = self.router(g.view(1,d))       # [1,M]
        topv, topi = torch.topk(logits, k=self.k, dim=-1)  # [1,k]
        w = torch.softmax(topv, dim=-1).view(-1)           # [k]
        y = 0
        # k is tiny (1–2), so looping is OK; we still do large fused matmuls
        for j in range(self.k):
            m = int(topi[0, j])
            y1 = self.act(torch.matmul(x, self.W1[m]) + self.b1[m])   # [B,T,h]
            y  = y + w[j] * (torch.matmul(y1, self.W2[m]) + self.b2[m])  # [B,T,d]
        return y

# --- shapes & configs ---
B, T, d = 1, 128, 512
d_ff = 2048                   # baseline inner width
M, k, h = 8, 2, 256           # FSNN: dormant = M*h = 2048 (same), active = k*h = 512 (4× smaller)
x = torch.randn(B, T, d, device=device)

# Theoretical FLOPs (ignoring bias/activation):
# Baseline ~ 2 * B * T * d * d_ff
# FSNN    ~ 2 * B * T * d * (k*h)
flops_base = 2*B*T*d*d_ff
flops_fsnn = 2*B*T*d*(k*h)
print(f"Theory FLOPs ratio FSNN/Baseline ≈ {flops_fsnn/flops_base:.2f}  (expected ~{(k*h)/d_ff:.2f}x compute)")

# --- build & (optionally) compile ---
ffn_base = try_compile(FFN(d=d, d_ff=d_ff).to(device))
ffn_fsnn = try_compile(FFN_MoM_TopK(d=d, M=M, k=k, h=h).to(device))

# --- measure ---
p50b, p90b, _ = bs1_latency(ffn_base, x)
p50f, p90f, _ = bs1_latency(ffn_fsnn, x)
pb = sum(p.numel() for p in ffn_base.parameters())
pf = sum(p.numel() for p in ffn_fsnn.parameters())

print(f"\nBaseline FFN: params={pb:,}  d_ff={d_ff}  p50={p50b*1e3:.2f} ms  p90={p90b*1e3:.2f} ms")
print(f"FSNN-MoM    : params={pf:,}  M={M},k={k},h={h}  (active {k*h}, dormant {M*h})")
print(f"              p50={p50f*1e3:.2f} ms  p90={p90f*1e3:.2f} ms  -> speedup ~{p50b/max(p50f,1e-9):.2f}×")


device: cuda
Theory FLOPs ratio FSNN/Baseline ≈ 0.25  (expected ~0.25x compute)




BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
