In [1]:
# FSNN-Compact (one cell): rank sweep for Transformer FFN + optional quick MNIST MLP check
# - Measures bs=1 latency & params vs baseline FFN for ranks r in RANKS
# - Optional: quick 2-epoch MNIST MLP accuracy check with factorized layers
# Run on GPU for best signal (Runtime -> Change runtime type -> GPU)

import time, math, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ---------------------------- Config ----------------------------
DO_MNIST_CHECK = True          # set False to skip the MNIST mini-train
B, T, d   = 1, 128, 512        # FFN microbench shape
d_ff      = 2048               # baseline FFN inner width
RANKS     = [512, 448, 384, 320, 256, 192, 128]  # sweep ranks for FSNN-Compact FFN
MNIST_R   = [192, 160, 128, 96]                  # ranks for MNIST demo
EPOCHS_MN = 2                                     # tiny train for MNIST demo

# ---------------------------- Setup ----------------------------
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

def set_seed(s=42):
    import random
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(42)

def try_compile(m):
    try:
        return torch.compile(m)  # PyTorch 2.x
    except Exception:
        return m

def count_params(m): 
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

@torch.no_grad()
def bs1_latency(model, x, iters=300, warmup=60):
    model.eval()
    for _ in range(warmup):
        _ = model(x)
        if device.type == "cuda": torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        if device.type == "cuda": torch.cuda.synchronize()
        t0 = time.perf_counter()
        _ = model(x)
        if device.type == "cuda": torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    a = np.array(times)
    return float(np.median(a)), float(np.percentile(a, 90)), float(np.mean(a))

# ---------------------------- Models (FFN) ----------------------------
class FFN_Base(nn.Module):
    # Transformer FFN: d -> d_ff -> d
    def __init__(self, d=512, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d, d_ff)
        self.fc2 = nn.Linear(d_ff, d)
        self.act = nn.GELU()
    def forward(self, x):  # x: [B,T,d]
        y = self.act(self.fc1(x))
        return self.fc2(y)

class FFN_Factorized(nn.Module):
    """
    FSNN-Compact FFN (low-rank, kernel-friendly; no routing):
      W1 ≈ U1 V1^T (U1:[d,r], V1:[d_ff,r]); W2 ≈ U2 V2^T (U2:[d_ff,r], V2:[d,r])
      Forward: (x@U1)@V1^T -> GELU -> (·@U2)@V2^T
    """
    def __init__(self, d=512, d_ff=2048, r=384):
        super().__init__()
        self.U1 = nn.Parameter(torch.empty(d,     r)); self.V1 = nn.Parameter(torch.empty(d_ff,  r))
        self.U2 = nn.Parameter(torch.empty(d_ff,  r)); self.V2 = nn.Parameter(torch.empty(d,     r))
        nn.init.xavier_uniform_(self.U1); nn.init.xavier_uniform_(self.V1)
        nn.init.xavier_uniform_(self.U2); nn.init.xavier_uniform_(self.V2)
        self.act = nn.GELU()
    def forward(self, x):  # [B,T,d]
        z  = x.matmul(self.U1)                 # [B,T,r]
        y1 = self.act(z.matmul(self.V1.t()))   # [B,T,d_ff]
        y2 = y1.matmul(self.U2).matmul(self.V2.t())  # [B,T,d]
        return y2

# ---------------------------- FFN Rank Sweep ----------------------------
x = torch.randn(B, T, d, device=device)

base = try_compile(FFN_Base(d=d, d_ff=d_ff).to(device))
p50b, p90b, _ = bs1_latency(base, x)
pb = count_params(base)

print("\n=== FFN microbench (bs=1) ===")
print(f"Baseline FFN : params={pb:,}  d={d}, d_ff={d_ff}  p50={p50b*1e3:.2f} ms  p90={p90b*1e3:.2f} ms")

print("\nFSNN-Compact FFN sweep:")
print(f"{'r':>5} | {'params':>10} | {'p50 (ms)':>8} | {'p90 (ms)':>8} | {'Δp50%':>6} | {'theory FLOPs%':>13}")
print("-"*60)
for r in RANKS:
    ffn = try_compile(FFN_Factorized(d=d, d_ff=d_ff, r=r).to(device))
    p50f, p90f, _ = bs1_latency(ffn, x)
    pf = count_params(ffn)
    # Theoretical FLOPs ratio (rough): base ~ 2*d*d_ff; compact ~ 2*(d*r + d_ff*r)
    flops_ratio = (d*r + d_ff*r) / (d*d_ff)  # per half; factor 2 cancels
    dpct = (p50f - p50b) / max(p50b, 1e-9) * 100.0
    print(f"{r:5d} | {pf:10,} | {p50f*1e3:8.2f} | {p90f*1e3:8.2f} | {dpct:6.1f} | {flops_ratio*100:13.1f}")

# ---------------------------- (Optional) MNIST Mini-Train ----------------------------
if DO_MNIST_CHECK:
    print("\n=== MNIST quick check (2 epochs) — Baseline MLP vs FSNN-Compact MLP ===")
    transform = transforms.ToTensor()
    train_ds = datasets.MNIST("./data", train=True,  download=True, transform=transform)
    test_ds  = datasets.MNIST("./data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

    D_in, H, NUM_CLASSES = 28*28, 512, 10

    class MLP_Base(nn.Module):
        def __init__(self, d=D_in, h=H):
            super().__init__()
            self.fc1 = nn.Linear(d, h); self.fc2 = nn.Linear(h, d); self.head = nn.Linear(d, NUM_CLASSES)
            self.act = nn.GELU()
        def forward(self, x):
            x = self.act(self.fc1(x)); x = self.act(self.fc2(x)); return self.head(x)

    class FactorizedLinear(nn.Module):
        def __init__(self, d_in, d_out, r):
            super().__init__()
            self.U = nn.Parameter(torch.empty(d_in, r)); self.V = nn.Parameter(torch.empty(d_out, r))
            nn.init.xavier_uniform_(self.U); nn.init.xavier_uniform_(self.V)
        def forward(self, x):  # [B, d_in] -> [B, d_out]
            return x.matmul(self.U).matmul(self.V.t())

    class MLP_Compact(nn.Module):
        def __init__(self, d=D_in, h=H, r1=128, r2=128):
            super().__init__()
            self.fl1 = FactorizedLinear(d, h, r1); self.fl2 = FactorizedLinear(h, d, r2)
            self.head = nn.Linear(d, NUM_CLASSES); self.act = nn.GELU()
        def forward(self, x):
            x = self.act(self.fl1(x)); x = self.act(self.fl2(x)); return self.head(x)

    @torch.no_grad()
    def evaluate(model):
        model.eval(); tot=0; correct=0; loss_sum=0.0
        for xb, yb in test_loader:
            xb = xb.view(xb.size(0), -1).to(device); yb = yb.to(device)
            logits = model(xb); loss_sum += F.cross_entropy(logits, yb, reduction="sum").item()
            pred = logits.argmax(1); correct += (pred==yb).sum().item(); tot += yb.numel()
        return loss_sum/tot, correct/tot

    def train_one(model, epochs=EPOCHS_MN, lr=1e-3):
        model.to(device); opt = torch.optim.AdamW(model.parameters(), lr=lr)
        model.train()
        for ep in range(epochs):
            run=0; n=0
            for xb, yb in train_loader:
                xb = xb.view(xb.size(0), -1).to(device); yb = yb.to(device)
                opt.zero_grad(set_to_none=True); logits = model(xb)
                loss = F.cross_entropy(logits, yb); loss.backward(); opt.step()
                run += loss.item()*xb.size(0); n += xb.size(0)
            print(f"[MNIST] epoch {ep+1}: train_loss={run/n:.4f}")

    mlp_b = MLP_Base().to(device); train_one(mlp_b); lb, ab = evaluate(mlp_b); pb = count_params(mlp_b)
    print(f"Baseline MLP    : params={pb:,}  acc={ab:.4f}")

    for r in MNIST_R:
        mlp_c = MLP_Compact(r1=r, r2=r).to(device); train_one(mlp_c)
        lc, ac = evaluate(mlp_c); pc = count_params(mlp_c)
        print(f"FSNN-Compact MLP (r={r:>3}): params={pc:,}  acc={ac:.4f}  Δacc={100*(ac-ab):+.2f} pp")


device: cuda




BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
