
# ✅ MLP Utils — Tests & Diagnostics

This notebook runs a **battery of tests** for your `mlp_utils.py` helpers.  
It looks for `mlp_utils.py` in `.`, `backend/`, `backend/ml/`, and `ml/` (relative to the working directory). If none of those copies can be imported, it builds a **compatible fallback** so the tests still run.

**What we test**
- Import & environment sanity (CPU/GPU availability)
- Model construction (shapes, parameter counts)
- Forward/Backward passes (no NaNs, gradients finite)
- Training on synthetic data (loss goes down, accuracy up)
- Reproducibility (fixed seeds)
- Save/Load round-trip equivalence
- Simple learning rate & batch size sweeps (optional)


In [None]:

import os, sys, math, json, copy, time, pathlib, tempfile
import numpy as np
import matplotlib.pyplot as plt

# PyTorch
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import TensorDataset, DataLoader
    TORCH_OK = True
except Exception as e:
    TORCH_OK = False
    print("[warn] PyTorch not available:", e)


In [None]:

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module and NumPy's global generator. When
    PyTorch has already been imported, also seeds its CPU and CUDA generators
    and sets ``cudnn.deterministic = True`` / ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)

    # Use the module object registered in sys.modules instead of the global
    # name `torch`: the module can be loaded (e.g. pulled in by another
    # package) without that name being bound here, which would otherwise
    # raise NameError.
    torch_mod = sys.modules.get("torch")
    if torch_mod is not None:
        torch_mod.manual_seed(seed)
        torch_mod.cuda.manual_seed_all(seed)
        torch_mod.backends.cudnn.deterministic = True
        torch_mod.backends.cudnn.benchmark = False


set_seed(1337)


## 1) Import `mlp_utils` (or provide fallback)

In [None]:

# MLP_UTILS ends up bound to the imported module, or stays None when no
# candidate could be imported. ERR holds the most recent import failure and is
# cleared again as soon as a candidate imports successfully.
MLP_UTILS = None
ERR = None

if TORCH_OK:
    # Try common repo locations, in order; the first importable copy wins.
    candidate_paths = [".", "backend", "backend/ml", "ml"]
    for cp in candidate_paths:
        p = pathlib.Path(cp)
        if not (p / "mlp_utils.py").exists():
            continue
        search_dir = str(p.resolve())
        sys.path.insert(0, search_dir)
        try:
            import mlp_utils as MLP_UTILS  # type: ignore
        except Exception as e:
            # Remember the failure, and undo the sys.path change so a broken
            # copy's directory does not keep shadowing other imports.
            ERR = e
            if search_dir in sys.path:
                sys.path.remove(search_dir)
            print(f"[warn] could not import {p / 'mlp_utils.py'}:", e)
        else:
            # A stale error from an earlier candidate must not make the
            # successfully imported module look like a failure.
            ERR = None
            break

if MLP_UTILS is None and TORCH_OK:
    # Fallback minimal implementation (API compatible with tests). It is only
    # defined when no real mlp_utils module was imported.
    class MLP(nn.Module):
        """Feed-forward classifier.

        Each entry of ``hidden`` adds Linear -> ReLU -> Dropout; a final
        Linear maps the last width to ``out_dim``.
        """

        def __init__(self, in_dim, hidden, out_dim, dropout=0.0):
            super().__init__()
            layers = []
            last = in_dim
            for h in hidden:
                layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]
                last = h
            layers.append(nn.Linear(last, out_dim))
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            """Return the output of the final linear layer for ``x``."""
            return self.net(x)

    def make_mlp(in_dim, hidden, out_dim, dropout=0.0):
        """Build an ``MLP`` with the given input size, hidden widths and output size."""
        return MLP(in_dim, hidden, out_dim, dropout)

    def count_params(model):
        """Return the total number of elements across ``model.parameters()``."""
        return sum(p.numel() for p in model.parameters())

    def save_model(model, path):
        """Write ``model.state_dict()`` to ``path`` with ``torch.save``."""
        torch.save(model.state_dict(), path)

    def load_model(model, path, map_location="cpu"):
        """Load a state dict from ``path`` into ``model`` and return ``model``.

        NOTE(review): ``torch.load`` unpickles the file; only point this at
        checkpoints you trust.
        """
        sd = torch.load(path, map_location=map_location)
        model.load_state_dict(sd)
        return model

    def train_one_epoch(model, loader, opt, device="cpu", scheduler=None):
        """Run one pass over ``loader`` and return the mean training loss.

        Each batch is moved to ``device``, scored with cross-entropy,
        back-propagated, gradient-clipped to a norm of 5.0 and applied with
        ``opt``. ``scheduler.step()`` is called once per batch when a
        scheduler is given. The returned loss is averaged per sample, so a
        shorter final batch is not over-weighted; 0.0 for an empty loader.
        """
        model.train()
        total, seen = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            loss = F.cross_entropy(model(xb), yb)
            loss.backward()
            # Clip the global gradient norm so one bad batch cannot blow up the step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            opt.step()
            if scheduler is not None:
                scheduler.step()
            # `loss` is the batch mean; scale by the batch size before summing.
            batch = yb.numel()
            total += float(loss.detach().cpu()) * batch
            seen += batch
        return total / max(1, seen)

    def evaluate(model, loader, device="cpu"):
        """Score ``model`` on ``loader`` without tracking gradients.

        Returns:
            ``{"loss": ..., "acc": ...}`` with the per-sample mean
            cross-entropy and the fraction of correct argmax predictions.
            Both are 0.0 for an empty loader.
        """
        model.eval()
        n, correct, total_loss = 0, 0, 0.0
        with torch.no_grad():
            for xb, yb in loader:
                xb, yb = xb.to(device), yb.to(device)
                logits = model(xb)
                # Sum (not mean) per batch so dividing by n gives a true per-sample mean.
                total_loss += float(F.cross_entropy(logits, yb, reduction="sum").cpu())
                pred = logits.argmax(dim=1)
                correct += int((pred == yb).sum().cpu())
                n += yb.numel()
        return {"loss": total_loss / max(1, n), "acc": correct / max(1, n)}

    # Bundle the helpers on a plain object so callers can use the same
    # `MLP_UTILS.<name>` attribute access as with the real module.
    class _NS:
        pass

    MLP_UTILS = _NS()
    MLP_UTILS.make_mlp = make_mlp  # type: ignore
    MLP_UTILS.count_params = count_params  # type: ignore
    MLP_UTILS.save_model = save_model  # type: ignore
    MLP_UTILS.load_model = load_model  # type: ignore
    MLP_UTILS.train_one_epoch = train_one_epoch  # type: ignore
    MLP_UTILS.evaluate = evaluate  # type: ignore

# Report where the helpers came from, based only on the object that was
# actually bound: an imported module carries `__file__`, the in-notebook
# fallback bundle does not, and None means nothing was set up at all.
print(
    "mlp_utils source:",
    "unknown" if MLP_UTILS is None else getattr(MLP_UTILS, "__file__", "fallback"),
)


## 2) Environment & Device

In [None]:

# Use CUDA only when PyTorch imported cleanly and reports it as available.
if TORCH_OK and torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("PyTorch:", TORCH_OK, "| Device:", device)


## 3) Synthetic Classification Dataset

In [None]:

# Nothing below can run without a working PyTorch import.
if not TORCH_OK:
    raise RuntimeError("PyTorch is required for these tests.")

set_seed(123)

# Problem size.
n_classes = 3
n_features = 16
n_train = 4000
n_val = 1000

# A random linear map defines the ground truth: each sample's label is the
# argmax of its logits under (W_true, b_true).
W_true = torch.randn(n_features, n_classes)
b_true = torch.randn(n_classes)

# Training split.
X_tr = torch.randn(n_train, n_features)
logits_tr = X_tr @ W_true + b_true
y_tr = logits_tr.argmax(dim=1)

# Validation split, labelled by the same linear map.
X_va = torch.randn(n_val, n_features)
logits_va = X_va @ W_true + b_true
y_va = logits_va.argmax(dim=1)

# Only the training loader shuffles.
train_ds = TensorDataset(X_tr, y_tr)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, drop_last=False)
val_ds = TensorDataset(X_va, y_va)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)


## 4) Model Build — Shapes & Params

In [None]:

# Build the two-hidden-layer model under test and place it on the target device.
model = MLP_UTILS.make_mlp(in_dim=n_features, hidden=[64, 32], out_dim=n_classes, dropout=0.1)  # type: ignore
model = model.to(device)

# One training batch is enough to check the output shape: one row of
# n_classes scores per input row.
x = next(iter(train_loader))[0]
x = x.to(device)
out = model(x)
assert tuple(out.shape) == (x.shape[0], n_classes), f"Unexpected output shape: {out.shape}"

# The model must expose at least one parameter element.
params = MLP_UTILS.count_params(model)  # type: ignore
print("Param count:", params)
assert params > 0, "Parameter count should be > 0"


## 5) Training Loop — Loss Should Decrease

In [None]:

opt = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-2)
# Per-epoch metrics, appended in epoch order.
history = {"train_loss": [], "val_loss": [], "val_acc": []}

epochs = 12
best = 1e9  # lowest validation loss seen so far
for ep in range(1, epochs + 1):
    tl = MLP_UTILS.train_one_epoch(model, train_loader, opt, device=device)  # type: ignore
    ev = MLP_UTILS.evaluate(model, val_loader, device=device)  # type: ignore
    history["train_loss"].append(tl)
    history["val_loss"].append(ev["loss"])
    history["val_acc"].append(ev["acc"])
    best = min(best, ev["loss"])
    print(f"ep {ep:02d} | train {tl:.4f} | val {ev['loss']:.4f} | acc {ev['acc']:.3f}")

# The last epoch must be strictly better than the first. The earlier form
# (`first > last - 1e-6`) also passed when the loss stayed flat or rose slightly.
assert history["val_loss"][-1] < history["val_loss"][0], "Validation loss did not decrease"


### Plot Loss Curves

In [None]:

# Overlay the per-epoch training and validation losses recorded in `history`.
plt.figure(figsize=(8, 4))
for series_key, series_label in (("train_loss", "train"), ("val_loss", "val")):
    plt.plot(history[series_key], label=series_label)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("Loss Curves")
plt.legend()
plt.show()


## 6) Reproducibility — Fixed Seeds

In [None]:

def train_quick(seed=777):
    """Train a small MLP for three epochs from a fixed seed and summarise it.

    Seeds the RNGs, builds a one-hidden-layer (32 units) model, trains it with
    SGD (lr=1e-2) on ``train_loader`` and evaluates it on ``val_loader``.

    Returns:
        ``(val_loss, val_acc, vec)`` where ``vec`` is a NumPy array holding at
        most 50 values: the first 20 entries of each flattened parameter,
        concatenated in parameter order.
    """
    set_seed(seed)
    net = MLP_UTILS.make_mlp(n_features, [32], n_classes).to(device)  # type: ignore
    sgd = torch.optim.SGD(net.parameters(), lr=1e-2)

    for _epoch in range(3):
        MLP_UTILS.train_one_epoch(net, train_loader, sgd, device=device)  # type: ignore
    metrics = MLP_UTILS.evaluate(net, val_loader, device=device)  # type: ignore

    # A short weight fingerprint is enough to tell two runs apart.
    heads = [param.detach().flatten()[:20].cpu() for param in net.parameters()]
    fingerprint = torch.cat(heads)[:50].numpy()
    return metrics["loss"], metrics["acc"], fingerprint

# Two runs from the same seed should give matching metrics and weights.
l1, a1, v1 = train_quick(seed=2024)
l2, a2, v2 = train_quick(seed=2024)

loss_gap = abs(l1 - l2)
acc_gap = abs(a1 - a2)
vec_gap = np.linalg.norm(v1 - v2)

print("Repeat loss diff:", loss_gap)
print("Repeat acc  diff:", acc_gap)
print("Vec L2 diff    :", float(vec_gap))
assert loss_gap < 1e-4 and acc_gap < 1e-4 and vec_gap < 1e-6, "Non-deterministic outcome with fixed seed"


## 7) Save / Load Round-Trip

In [None]:

# Save the trained model, load the weights into a freshly built model with the
# same layer sizes, and check that both produce the same outputs.
with tempfile.TemporaryDirectory() as td:
    p = os.path.join(td, "mlp.pt")
    MLP_UTILS.save_model(model, p)  # type: ignore
    m2 = MLP_UTILS.make_mlp(n_features, [64, 32], n_classes).to(device)  # type: ignore
    MLP_UTILS.load_model(m2, p, map_location=device)  # type: ignore
    # Put both models in eval mode before comparing: a freshly built module
    # starts in training mode, where dropout would randomise the outputs and
    # make the comparison depend on the dropout rate and on cell order.
    model.eval()
    m2.eval()
    # Compare a forward sample
    x = next(iter(val_loader))[0].to(device)[:8]
    with torch.no_grad():
        y1 = model(x).cpu().numpy()
        y2 = m2(x).cpu().numpy()
    diff = np.abs(y1 - y2).max()
    print("Max forward diff after load:", diff)
    assert diff < 1e-6, "Model weights did not round-trip exactly"


## 8) Gradient Finite & Norm Check

In [None]:

# The weights themselves must be finite before gradients are inspected.
for p in model.parameters():
    assert torch.isfinite(p).all(), "Parameter contains NaN/Inf"

# Backprop one step
xb, yb = next(iter(train_loader))
xb, yb = xb.to(device), yb.to(device)
# Clear gradients through the model itself so the check does not depend on
# which optimizer the notebook-level `opt` currently refers to; stale
# gradients on `model` would otherwise accumulate across re-runs of this cell.
model.zero_grad(set_to_none=True)
loss = F.cross_entropy(model(xb), yb)
loss.backward()

# Track the largest absolute gradient entry across all parameters.
gmax = 0.0
for p in model.parameters():
    if p.grad is not None:
        # Check finiteness first so a NaN/Inf is reported as such.
        assert torch.isfinite(p.grad).all(), "Gradient contains NaN/Inf"
        gmax = max(gmax, float(p.grad.detach().abs().max().cpu()))
print("Max grad abs:", gmax)
assert gmax > 0.0, "Zero gradients observed"


## 9) (Optional) Mini LR Sweep

In [None]:

def _val_metrics_for_lr(lr):
    """Train a fresh one-hidden-layer MLP for one epoch with Adam at ``lr``.

    The model and optimizer are local to this function, so the sweep does not
    rebind the notebook-level model, optimizer or metric variables.

    Returns:
        ``(lr, val_loss, val_acc)`` measured on ``val_loader``.
    """
    # Same seed for every learning rate, so runs differ only in `lr`.
    set_seed(99)
    net = MLP_UTILS.make_mlp(n_features, [32], n_classes).to(device)  # type: ignore
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    MLP_UTILS.train_one_epoch(net, train_loader, optimizer, device=device)  # type: ignore
    metrics = MLP_UTILS.evaluate(net, val_loader, device=device)  # type: ignore
    return lr, metrics["loss"], metrics["acc"]


lrs = [1e-3, 3e-3, 1e-2]
results = [_val_metrics_for_lr(lr) for lr in lrs]

for lr, vl, va in results:
    print(f"lr={lr:.0e}: val_loss={vl:.4f} acc={va:.3f}")
