In [3]:
# FSNN-Compact: kernel-friendly low-rank (no runtime routing)
# Demo A: Transformer FFN microbench (latency win)
# Demo B: MNIST MLP quick train (param/FLOP win with tiny accuracy loss)

import os
os.environ.setdefault("TORCHINDUCTOR_DISABLE_CUDAGRAPHS", "1")  # stabilizes on Windows laptops
import torch, torch._dynamo


torch._dynamo.config.suppress_errors = True   # if compile fails, silently fall back to eager

# Let tensor cores accelerate float32 (RTX 4050 supports TF32)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True


import time, math, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms





device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
torch.backends.cudnn.benchmark = True

# ---------- Common helpers ----------
def bs1_latency(model, x, iters=300, warmup=60):
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x);  torch.cuda.synchronize() if device.type=="cuda" else None
        ts=[]
        for _ in range(iters):
            torch.cuda.synchronize() if device.type=="cuda" else None
            t0=time.perf_counter(); _=model(x)
            torch.cuda.synchronize() if device.type=="cuda" else None
            ts.append(time.perf_counter()-t0)
    a=np.array(ts); return float(np.median(a)), float(np.percentile(a,90)), float(np.mean(a))

def count_params(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)

def try_compile(m):
    try: return torch.compile(m)
    except: return m

# ========== Demo A: Transformer FFN (baseline vs factorized) ==========
class FFN_Base(nn.Module):
    def __init__(self, d=512, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d, d_ff)
        self.fc2 = nn.Linear(d_ff, d)
        self.act = nn.GELU()
    def forward(self, x):   # x: [B,T,d]
        y = self.act(self.fc1(x)); return self.fc2(y)

class FFN_Factorized(nn.Module):
    """
    Factorized FFN (FSNN-Compact):
      W1 ≈ U1 V1^T (U1:[d,r], V1:[d_ff,r])
      W2 ≈ U2 V2^T (U2:[d_ff,r], V2:[d,r])
    Compute: (x@U1)@V1^T -> act -> (·@U2)@V2^T
    """
    def __init__(self, d=512, d_ff=2048, r=512):
        super().__init__()
        self.U1 = nn.Parameter(torch.empty(d,     r)*0.02)
        self.V1 = nn.Parameter(torch.empty(d_ff,  r)*0.02)
        self.U2 = nn.Parameter(torch.empty(d_ff,  r)*0.02)
        self.V2 = nn.Parameter(torch.empty(d,     r)*0.02)
        nn.init.xavier_uniform_(self.U1); nn.init.xavier_uniform_(self.V1)
        nn.init.xavier_uniform_(self.U2); nn.init.xavier_uniform_(self.V2)
        self.act = nn.GELU()
    def forward(self, x):
        z  = x.matmul(self.U1)                 # [B,T,r]
        y1 = self.act(z.matmul(self.V1.t()))   # [B,T,d_ff]
        y2 = y1.matmul(self.U2).matmul(self.V2.t())  # [B,T,d]
        return y2

# Shapes for microbench
B,T,d = 1, 128, 512
d_ff = 2048
# Choose r so active compute ~ (d*r + d_ff*r + d_ff*r + d*r) vs baseline (d*d_ff + d_ff*d) -> roughly 4*r*(d + d_ff)/ (2*d*d_ff)
# r = 512 gives ~4x less inner width equivalent than baseline, strong speedup w/o routing.
r = 512

x = torch.randn(B,T,d, device=device)
ffn_base = try_compile(FFN_Base(d=d, d_ff=d_ff).to(device))
ffn_comp = try_compile(FFN_Factorized(d=d, d_ff=d_ff, r=r).to(device))

# Theory FLOPs (rough)
flops_base = 2*B*T*d*d_ff
flops_comp = 2*B*T*(d*r + d_ff*r)  # per matmul stage (two stages total but symmetric; factor 2 absorbed)
ratio = flops_comp / flops_base
print(f"[FFN] Theory FLOPs ratio (FSNN-Compact/Base) ≈ {ratio:.2f}")

p50b, p90b, _ = bs1_latency(ffn_base, x)
p50c, p90c, _ = bs1_latency(ffn_comp, x)
pb = count_params(ffn_base); pc = count_params(ffn_comp)
print(f"Baseline FFN: params={pb:,}  p50={p50b*1e3:.2f} ms  p90={p90b*1e3:.2f} ms")
print(f"FSNN-Compact: params={pc:,}  p50={p50c*1e3:.2f} ms  p90={p90c*1e3:.2f} ms  (r={r})")

# ========== Demo B: MNIST MLP (baseline vs factorized) ==========
transform = transforms.ToTensor()
train_ds = datasets.MNIST("./data", train=True,  download=True, transform=transform)
test_ds  = datasets.MNIST("./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

D = 28*28; NUM_CLASSES = 10

class MLP_Base(nn.Module):
    def __init__(self, d=D, h=512, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(d, h); self.fc2 = nn.Linear(h, d); self.head = nn.Linear(d, num_classes)
        self.act = nn.GELU()
    def forward(self, x):
        x = self.act(self.fc1(x)); x = self.act(self.fc2(x)); return self.head(x)

class FactorizedLinear(nn.Module):
    # Linear(d_in -> d_out) via U:[d_in,r], V:[d_out,r]; y = (x@U)@V^T
    def __init__(self, d_in, d_out, r):
        super().__init__()
        self.U = nn.Parameter(torch.empty(d_in,  r)); self.V = nn.Parameter(torch.empty(d_out, r))
        nn.init.xavier_uniform_(self.U); nn.init.xavier_uniform_(self.V)
    def forward(self, x):  # [B, d_in] -> [B, d_out]
        return x.matmul(self.U).matmul(self.V.t())

class MLP_FSNNCompact(nn.Module):
    def __init__(self, d=D, h=512, r1=128, r2=128, num_classes=10):
        super().__init__()
        self.fl1  = FactorizedLinear(d, h, r1)
        self.fl2  = FactorizedLinear(h, d, r2)
        self.head = nn.Linear(d, num_classes)
        self.act  = nn.GELU()
    def forward(self, x):
        x = self.act(self.fl1(x)); x = self.act(self.fl2(x)); return self.head(x)

def evaluate(model):
    model.eval(); tot=0; correct=0; loss_sum=0.0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb = xb.view(xb.size(0), -1).to(device); yb = yb.to(device)
            logits = model(xb); loss_sum += F.cross_entropy(logits, yb, reduction="sum").item()
            pred = logits.argmax(1); correct += (pred==yb).sum().item(); tot += yb.numel()
    return loss_sum/tot, correct/tot

def train_one(model, epochs=2, lr=1e-3):
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for ep in range(epochs):
        run=0; n=0
        for xb, yb in train_loader:
            xb = xb.view(xb.size(0), -1).to(device); yb = yb.to(device)
            opt.zero_grad(set_to_none=True); logits = model(xb)
            loss = F.cross_entropy(logits, yb); loss.backward(); opt.step()
            run += loss.item()*xb.size(0); n += xb.size(0)
        print(f"[MLP] epoch {ep+1}: train_loss={run/n:.4f}")

print("\n[MLP] === Baseline vs FSNN-Compact ===")
mlp_b = MLP_Base(h=512).to(device);           train_one(mlp_b, epochs=2)
lb, ab = evaluate(mlp_b); pb = count_params(mlp_b)

mlp_c = MLP_FSNNCompact(h=512, r1=128, r2=128).to(device);  train_one(mlp_c, epochs=2)
lc, ac = evaluate(mlp_c); pc = count_params(mlp_c)

print(f"Baseline MLP    : params={pb:,}  acc={ab:.4f}")
print(f"FSNN-Compact MLP: params={pc:,}  acc={ac:.4f}  (r1=r2=128)")


W0819 11:33:00.829000 47856 site-packages\torch\_dynamo\convert_frame.py:1125] WON'T CONVERT forward C:\Users\dpanc\AppData\Local\Temp\ipykernel_47856\2209687021.py line 59 
W0819 11:33:00.829000 47856 site-packages\torch\_dynamo\convert_frame.py:1125] due to: 
W0819 11:33:00.829000 47856 site-packages\torch\_dynamo\convert_frame.py:1125] Traceback (most recent call last):
W0819 11:33:00.829000 47856 site-packages\torch\_dynamo\convert_frame.py:1125]   File "c:\Users\dpanc\miniconda3\envs\nvidia_env\lib\site-packages\torch\_dynamo\output_graph.py", line 1446, in _call_user_compiler
W0819 11:33:00.829000 47856 site-packages\torch\_dynamo\convert_frame.py:1125]     compiled_fn = compiler_fn(gm, self.example_inputs())
W0819 11:33:00.829000 47856 site-packages\torch\_dynamo\convert_frame.py:1125]   File "c:\Users\dpanc\miniconda3\envs\nvidia_env\lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 129, in __call__
W0819 11:33:00.829000 47856 site-packages\torch\_dynamo\convert_fram

device: cuda
[FFN] Theory FLOPs ratio (FSNN-Compact/Base) ≈ 1.25


W0819 11:33:01.032000 47856 site-packages\torch\_dynamo\convert_frame.py:1125] WON'T CONVERT forward C:\Users\dpanc\AppData\Local\Temp\ipykernel_47856\2209687021.py line 78 
W0819 11:33:01.032000 47856 site-packages\torch\_dynamo\convert_frame.py:1125] due to: 
W0819 11:33:01.032000 47856 site-packages\torch\_dynamo\convert_frame.py:1125] Traceback (most recent call last):
W0819 11:33:01.032000 47856 site-packages\torch\_dynamo\convert_frame.py:1125]   File "c:\Users\dpanc\miniconda3\envs\nvidia_env\lib\site-packages\torch\_dynamo\output_graph.py", line 1446, in _call_user_compiler
W0819 11:33:01.032000 47856 site-packages\torch\_dynamo\convert_frame.py:1125]     compiled_fn = compiler_fn(gm, self.example_inputs())
W0819 11:33:01.032000 47856 site-packages\torch\_dynamo\convert_frame.py:1125]   File "c:\Users\dpanc\miniconda3\envs\nvidia_env\lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 129, in __call__
W0819 11:33:01.032000 47856 site-packages\torch\_dynamo\convert_fram

Baseline FFN: params=2,099,712  p50=0.18 ms  p90=0.24 ms
FSNN-Compact: params=2,621,440  p50=0.21 ms  p90=0.41 ms  (r=512)

[MLP] === Baseline vs FSNN-Compact ===
[MLP] epoch 1: train_loss=0.2694
[MLP] epoch 2: train_loss=0.0952
[MLP] epoch 1: train_loss=0.2877
[MLP] epoch 2: train_loss=0.1200
Baseline MLP    : params=811,962  acc=0.9728
FSNN-Compact MLP: params=339,626  acc=0.9679  (r1=r2=128)
