11/22/2025 In this notebook, we benchmark againt a different control (aside from the 2g x 2g baseline): a randomly generated activation going from 2 \to 2g, comparing apples to apples against our best genus 30 AJ models on MNIST. This sort of test was suggested by Mike Douglas.  A random fourier features setup does well but not as well (and an MLP alternative for the 2\to2g does not do well).   AJ ~92% vs RFF ~ 88 % according to initial runs.  To do:  inject this into Noam's now much faster pipeline which has AJ working easily up to genus 150 on MNIST (as predicted competitively to 2g x 2g but with order 2g parameters, much fewer as we go up in genus).



In [1]:
# @title Load from Drive
!pip install -q torch torchvision numpy

import os, math, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from google.colab import drive

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ===========================
# 1) Drive + Paths (match Notebook 1)
# ===========================
DRIVE_FOLDER = "AJ_Tables_g30"  # same as Notebook 1
drive.mount('/content/drive', force_remount=True)
SAVE_DIR = f"/content/drive/MyDrive/{DRIVE_FOLDER}"
INTEGRALS_PATH = os.path.join(SAVE_DIR, "aj_integrals_genus30.pt")
OMEGAS_PATH    = os.path.join(SAVE_DIR, "aj_omegas_genus30.pt")
assert os.path.exists(INTEGRALS_PATH), "Integrals file not found in Drive"
assert os.path.exists(OMEGAS_PATH),    "Omegas file not found in Drive"



# ===========================
# 2) Load tables
# ===========================
ints = torch.load(INTEGRALS_PATH, map_location='cpu', weights_only=False)
omeg = torch.load(OMEGAS_PATH,    map_location='cpu', weights_only=False)

# ints = torch.load(INTEGRALS_PATH, map_location='cpu')  ### old version
# omeg = torch.load(OMEGAS_PATH,    map_location='cpu')

genus       = int(ints["genus"])
grid_r_np   = np.array(ints["grid_r"])
grid_i_np   = np.array(ints["grid_i"])
branch_pts  = np.array(ints["branch_pts"])
I_plus      = ints["I_plus"]            # (g, H, W) complex
Om_plus     = omeg["omega_plus"]        # (g, H, W) complex

H, W = I_plus.shape[-2:]
grid_r = torch.tensor(grid_r_np, dtype=torch.float32)
grid_i = torch.tensor(grid_i_np, dtype=torch.float32)
branch_pts_t = torch.tensor(branch_pts)   # complex128 → complex64 in torch if needed

print(f"Loaded: genus={genus}, grid=({H}x{W})")


Mounted at /content/drive
Loaded: genus=30, grid=(96x96)


In [2]:
# @title Base class AJMNIST_Anchored (needed for PeriodicHead)

import torch
import torch.nn as nn
import torch.nn.functional as F

class AJMNIST_Anchored(nn.Module):
    """
    Anchored AJ model:
      - conv trunk (1→64)
      - learned per-point embeddings (g×embed_dim)
      - shared point_head producing (x,y, sheet_logit) for each point
      - per-point bias initialized to ω-aware anchors
      - AJGridActivationNorm does lookup, standardization, sheet sign, and sum
      - classifier: linear 2g→10
    """
    def __init__(self, genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma, embed_dim=8):
        super().__init__()
        self.genus = genus
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.embed = nn.Parameter(torch.empty(genus, embed_dim))
        nn.init.uniform_(self.embed, -2.0, 2.0)

        # shared point_head (no bias) + per-point bias
        self.point_head = nn.Linear(64 + embed_dim, 3, bias=False)
        nn.init.xavier_uniform_(self.point_head.weight)
        self.point_bias = nn.Parameter(torch.zeros(genus, 3))

        # AJ activation with normalization
        self.aj = AJGridActivationNorm(I_plus, Om_plus, grid_r, grid_i, branch_pts, mu, sigma)
        self.classifier = nn.Linear(2*genus, 10)

        # ω-aware initialization for biases
        rmin, rmax = float(grid_r.min()), float(grid_r.max())
        imin, imax = float(grid_i.min()), float(grid_i.max())
        def logit(p): return float(torch.log(p/(1-p)))
        with torch.no_grad():
            for i in range(genus):
                x0, y0 = float(anchors_xy[i,0]), float(anchors_xy[i,1])
                px = (x0 - rmin) / (rmax - rmin)
                py = (y0 - imin) / (imax - imin)
                self.point_bias[i, 0] = logit(torch.tensor(px))
                self.point_bias[i, 1] = logit(torch.tensor(py))
                target_sign = 0.8 if (i % 2 == 0) else -0.8
                self.point_bias[i, 2] = float(torch.atanh(torch.tensor(target_sign)))

    def forward(self, x, return_aux=False):
        B = x.size(0)
        h = self.conv(x).view(B, -1)
        h_exp = h.unsqueeze(1).expand(-1, self.genus, -1)
        emb   = self.embed.unsqueeze(0).expand(B, -1, -1)
        inp   = torch.cat([h_exp, emb], dim=2)

        out   = self.point_head(inp) + self.point_bias.unsqueeze(0)  # (B,g,3)
        raw_xy, sheet_logits = out[..., :2], out[..., 2]
        coords, aux = self.aj(raw_xy, sheet_logits, return_aux=True)
        logits = self.classifier(coords)
        if return_aux:
            return logits, aux
        return logits


In [3]:
# @title AJ model with ω-aware init + AJ normalization + learnable gain

import torch
import torch.nn as nn
import torch.nn.functional as F

def _pack_complex_table(table_gHW: torch.Tensor) -> torch.Tensor:
    re, im = table_gHW.real, table_gHW.imag
    return torch.cat([re, im], dim=0).unsqueeze(0).contiguous()  # (1, 2g, H, W)

class AJGridActivationNorm(nn.Module):
    """
    AJ lookup with:
      - bilinear sampling
      - continuous sheet sign via tanh
      - channel-wise standardization: (I - mu)/sigma
      - learnable global gain gamma
      - boundary/branch penalties (for diagnostics/regularization)
    """
    def __init__(self, I_plus, Om_plus, grid_r, grid_i, branch_pts, mu, sigma):
        super().__init__()
        self.g = I_plus.shape[0]
        self.register_buffer("I_plus",  _pack_complex_table(I_plus))
        self.register_buffer("Om_plus", _pack_complex_table(Om_plus))

        # stats
        self.register_buffer("mu",    mu.view(1, 1, -1))     # (1,1,2g)
        self.register_buffer("sigma", sigma.view(1, 1, -1))  # (1,1,2g)
        self.gamma = nn.Parameter(torch.tensor(1.0))         # learnable gain

        # bounds
        self.register_buffer("r_min", torch.tensor(float(grid_r.min())))
        self.register_buffer("r_max", torch.tensor(float(grid_r.max())))
        self.register_buffer("i_min", torch.tensor(float(grid_i.min())))
        self.register_buffer("i_max", torch.tensor(float(grid_i.max())))

        # branch points
        self.register_buffer("bp_real", branch_pts.real.float())
        self.register_buffer("bp_imag", branch_pts.imag.float())

    def _map_raw_to_bounds(self, raw_xy):
        # sigmoid mapping into box
        xr = self.r_min + (self.r_max - self.r_min) * torch.sigmoid(raw_xy[..., 0])
        yi = self.i_min + (self.i_max - self.i_min) * torch.sigmoid(raw_xy[..., 1])
        return xr, yi

    def _norm_to_grid(self, xr, yi):
        gx = 2.0 * (xr - self.r_min) / (self.r_max - self.r_min) - 1.0
        gy = 2.0 * (yi - self.i_min) / (self.i_max - self.i_min) - 1.0
        return gx, gy

    def forward(self, raw_xy: torch.Tensor, sheet_logits: torch.Tensor, return_aux=True):
        B, g, _ = raw_xy.shape
        assert g == self.g

        xr, yi = self._map_raw_to_bounds(raw_xy)
        gx, gy = self._norm_to_grid(xr, yi)
        grid = torch.stack([gx, gy], dim=-1).view(B*g, 1, 1, 2)

        # Sample integrals and standardize
        I = F.grid_sample(self.I_plus.expand(B*g, -1, -1, -1),
                          grid, mode="bilinear", align_corners=True).view(B, g, -1)
        I_std = (I - self.mu) / self.sigma                        # (B, g, 2g)

        # Continuous sheet sign (initialized away from 0 in the constructor below)
        #sign = torch.tanh(sheet_logits).unsqueeze(-1)             # (B, g, 1)
        # inside AJGridActivationNorm.forward(...)

        sign = torch.tanh(sheet_logits).unsqueeze(-1)  # (B,g,1)
        contrib = sign * I_std
        coords  = self.gamma * contrib.sum(dim=1)      # (B,2g)


        # contrib = sign * I_std
        # coords = self.gamma * contrib.sum(dim=1)                  # (B, 2g)

        aux = None
        if return_aux:
            margin = 0.95
            bpen = ((gx.abs() - margin).clamp_min(0)**2 +
                    (gy.abs() - margin).clamp_min(0)**2).mean()

            dx = xr.unsqueeze(-1) - self.bp_real
            dy = yi.unsqueeze(-1) - self.bp_imag
            d2 = dx*dx + dy*dy
            tau = 0.07
            rpen = torch.exp(-d2 / (2*tau*tau)).mean()

            # aux = {"bound_penalty": bpen, "branch_penalty": rpen,
            #        "gx": gx, "gy": gy, "x": xr, "y": yi}
            # inside AJGridActivationNorm.forward(...), in the return_aux block
            aux = {
                "x": xr, "y": yi, "gx": gx, "gy": gy,
                "bound_penalty": bpen, "branch_penalty": rpen,
                "sheet_sign": sign.squeeze(-1)   # keep this
                # "omega_channels": Om  <-- delete this line
            }

            # aux = {
            #     "x": xr, "y": yi, "gx": gx, "gy": gy,
            #     "bound_penalty": bpen, "branch_penalty": rpen,
            #     "omega_channels": Om,
            #     "sheet_sign": sign.squeeze(-1)  # <-- add this line
            # }
        return coords, aux

def logit(p):  # inverse sigmoid
    p = np.clip(p, 1e-6, 1-1e-6)
    return float(np.log(p/(1-p)))





In [4]:
# AJ periodic head WITHOUT tau: axis-aligned Fourier features (global K)
import torch, torch.nn as nn, torch.nn.functional as F

class TorusFeatures(nn.Module):
    def __init__(self, dim: int, K: int = 2, freqs=None, learnable: bool = False):
        super().__init__()
        if freqs is None:
            freqs = torch.tensor([0.5, 1.0], dtype=torch.float32)[:K]
        else:
            freqs = torch.as_tensor(freqs, dtype=torch.float32)[:K]
        if learnable:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs)
        self.dim, self.K = dim, len(freqs)
    def forward(self, u):
        B, D = u.shape
        f = (self.freqs if isinstance(self.freqs, torch.Tensor) else torch.as_tensor(self.freqs)).view(1,1,-1).to(u.device)
        ang = u.unsqueeze(-1) * f
        return torch.cat([torch.cos(ang), torch.sin(ang)], dim=-1).view(B, D*2*self.K)

class AJMNIST_AxisPeriodic(nn.Module):
    """
    Drop-in head: AJ (anchored+norm) → axis-aligned periodic features → classifier
    Expects you already have AJMNIST_Anchored and AJGridActivationNorm defined.
    """
    def __init__(self, genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma, embed_dim=8, K=2, learnable_freqs=False):
        super().__init__()
        self.base = AJMNIST_Anchored(genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                                     anchors_xy, mu, sigma, embed_dim=embed_dim)
        D = 2*genus
        self.torus = TorusFeatures(D, K=K, learnable=learnable_freqs)
        self.classifier = nn.Linear(D*2*K, 10)
    @property
    def genus(self): return self.base.genus
    def forward(self, x, return_aux=False):
        B = x.size(0)
        h = self.base.conv(x).view(B, -1)
        h_exp = h.unsqueeze(1).expand(-1, self.genus, -1)
        emb   = self.base.embed.unsqueeze(0).expand(B, -1, -1)
        out   = self.base.point_head(torch.cat([h_exp, emb], dim=2)) + self.base.point_bias.unsqueeze(0)
        raw_xy, sheet_logits = out[..., :2], out[..., 2]
        coords, aux = self.base.aj(raw_xy, sheet_logits, return_aux=True)  # (B, 2g)
        feats = self.torus(coords)
        logits = self.classifier(feats)
        if return_aux: return logits, aux
        return logits


In [5]:
# @title Compute ω-aware anchors and AJ normalization
import torch
import numpy as np

# 1) Build a scalar "ω-strength" map on the grid to avoid dead-gradient regions
#    Use sum_k |ω_k| (or |ω_k|^2); both work similarly.
with torch.no_grad():
    # Om_plus: (g, H, W) complex
    Wmap = Om_plus.abs().sum(dim=0)  # (H, W)
    Wmap_np = Wmap.cpu().numpy()

# 2) Mask out boundaries and branch points (stay away from peaks/singularities too)
H, W = Wmap_np.shape
gy = np.linspace(-1, 1, H)[:, None]
gx = np.linspace(-1, 1, W)[None, :]
edge_mask = (np.abs(gx) < 0.92) & (np.abs(gy) < 0.92)

# Distance from branch points in grid coordinates
grid_x = grid_r_np[None, :].repeat(H, axis=0)  # (H,W)
grid_y = grid_i_np[:, None].repeat(W, axis=1)  # (H,W)
bp_real = np.real(branch_pts).reshape(-1, 1, 1)
bp_imag = np.imag(branch_pts).reshape(-1, 1, 1)
d2 = (grid_x - bp_real)**2 + (grid_y - bp_imag)**2  # (P, H, W)
bp_mask = (d2.min(axis=0) > 0.25)  # keep points with dist > ~0.5

mask = edge_mask & bp_mask

# 3) Pick g anchor points among top-quantile ω regions, well separated
score = np.where(mask, Wmap_np, -np.inf)
th = np.quantile(score[score > -np.inf], 0.85)  # top 15% by ω-strength
cand = np.argwhere(score >= th)                 # list of (iy, ix)

# Farthest-point sampling to pick 'genus' diverse anchors
def farthest_k(points_hw, k):
    pts = points_hw.copy()
    # Start from the global maximum
    start = pts[np.argmax(score[tuple(pts.T)])]
    chosen = [start]
    if k == 1: return np.array(chosen)
    # Precompute physical coordinates for distances
    coords = np.stack([grid_x[tuple(pts.T)], grid_y[tuple(pts.T)]], axis=1)
    c0 = np.array([grid_x[start[0], start[1]], grid_y[start[0], start[1]]])[None, :]
    mind = np.sum((coords - c0)**2, axis=1)
    for _ in range(1, k):
        j = np.argmax(mind)
        chosen.append(pts[j])
        cj = coords[j][None, :]
        mind = np.minimum(mind, np.sum((coords - cj)**2, axis=1))
    return np.array(chosen)

anchors_hw = farthest_k(cand, genus)  # shape (g, 2) with (iy, ix)

# 4) Convert anchors to (x0, y0) coordinates
x0 = grid_r_np[anchors_hw[:, 1]]
y0 = grid_i_np[anchors_hw[:, 0]]
anchors_xy = np.stack([x0, y0], axis=1)  # (g, 2)

# 5) Compute per-channel mean/std of AJ coordinates for standardization
#    Pack integrals into 2g channels [Re..., Im...], then compute stats over H×W.
I_re = I_plus.real    # (g, H, W)
I_im = I_plus.imag    # (g, H, W)
I_ch = torch.cat([I_re, I_im], dim=0)  # (2g, H, W)
mu = I_ch.mean(dim=(1, 2))             # (2g,)
sigma = I_ch.std(dim=(1, 2)).clamp_min(1e-6)

# Save for later use
anchors_xy_t = torch.tensor(anchors_xy, dtype=torch.float32)
mu_t = mu.float()
sigma_t = sigma.float()

print("Chosen anchors (x0,y0):")
for i, (x, y) in enumerate(anchors_xy):
    print(f"  point {i:2d}: x0={x:+.3f}, y0={y:+.3f}")
print("\nAJ normalization ready: per-channel mean/std computed.")


Chosen anchors (x0,y0):
  point  0: x0=+2.084, y0=-4.232
  point  1: x0=-0.442, y0=+5.495
  point  2: x0=-5.242, y0=-0.947
  point  3: x0=+4.358, y0=+1.453
  point  4: x0=-2.463, y0=-4.611
  point  5: x0=-3.853, y0=+2.968
  point  6: x0=+4.611, y0=-1.958
  point  7: x0=+2.463, y0=+4.105
  point  8: x0=-3.600, y0=-2.589
  point  9: x0=-0.189, y0=-4.989
  point 10: x0=-4.989, y0=+1.200
  point 11: x0=-2.084, y0=+3.979
  point 12: x0=+3.979, y0=-3.853
  point 13: x0=+1.200, y0=+5.116
  point 14: x0=+1.200, y0=-5.242
  point 15: x0=-4.863, y0=-2.211
  point 16: x0=-3.474, y0=-3.853
  point 17: x0=-4.737, y0=+0.063
  point 18: x0=+3.095, y0=-4.737
  point 19: x0=-1.200, y0=+4.611
  point 20: x0=-4.232, y0=+1.958
  point 21: x0=-3.095, y0=+3.726
  point 22: x0=+4.105, y0=-2.842
  point 23: x0=-1.453, y0=-4.484
  point 24: x0=+2.968, y0=-3.726
  point 25: x0=+1.453, y0=+4.232
  point 26: x0=+0.316, y0=+4.863
  point 27: x0=+0.695, y0=-4.484
  point 28: x0=+2.211, y0=-5.116
  point 29: x0=-4.3

In [6]:
# ===========================
# 5) Data
# ===========================
tfm = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
train_ds = torchvision.datasets.MNIST(root="/content/data", train=True, download=True, transform=tfm)
test_ds  = torchvision.datasets.MNIST(root="/content/data", train=False, download=True, transform=tfm)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = torch.utils.data.DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

# ===========================
# 6) Train / Eval
# ===========================
def train_epoch(model, loader, opt, clip=1.0, lam_branch=1e-3, lam_bound=1e-3):
    model.train()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        logits, aux = model(x, return_aux=True)

        loss = ce(logits, y)
        loss = loss + lam_branch * aux["branch_penalty"] + lam_bound * aux["bound_penalty"]
        loss.backward()
        if clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()

        tot += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

@torch.no_grad()
def eval_epoch(model, loader):
    model.eval()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        tot += ce(logits, y).item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

100%|██████████| 9.91M/9.91M [00:00<00:00, 12.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 338kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.19MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.44MB/s]


In [10]:
# @title Train τ-free axis-periodic AJ model with smaller batch & AMP (OOM-safe)

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# -------------------- sanity: required globals --------------------
required = [
    "genus", "I_plus", "Om_plus", "grid_r", "grid_i", "branch_pts_t",
    "anchors_xy_t", "mu_t", "sigma_t",
    "train_loader", "test_loader",
    "AJMNIST_AxisPeriodic"
]
for name in required:
    if name not in globals():
        raise RuntimeError(f"Missing required object `{name}` before training cell.")

print(f"Current genus = {genus}")

# -------------------- build smaller DataLoaders to save memory --------------------
from torch.utils.data import DataLoader

orig_train_loader = train_loader
orig_test_loader  = test_loader

# Use a conservative batch size (e.g. 64) for AJ model
BASE_BS = 64
orig_bs = getattr(orig_train_loader, "batch_size", None)
small_bs = BASE_BS if (orig_bs is None) else min(BASE_BS, orig_bs)

train_loader_axis = DataLoader(
    orig_train_loader.dataset,
    batch_size=small_bs,
    shuffle=True,
    num_workers=getattr(orig_train_loader, "num_workers", 0),
    pin_memory=getattr(orig_train_loader, "pin_memory", False),
    drop_last=True
)
test_loader_axis = DataLoader(
    orig_test_loader.dataset,
    batch_size=min(256, small_bs*2),
    shuffle=False,
    num_workers=getattr(orig_test_loader, "num_workers", 0),
    pin_memory=getattr(orig_test_loader, "pin_memory", False),
    drop_last=False
)

print(f"Original train batch size: {orig_bs}")
print(f"AJ-axis train batch size : {small_bs}")

# -------------------- helpers --------------------
def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def heads_only_params(m: nn.Module) -> int:
    # conv trunk is in m.base.conv
    total = count_params(m)
    conv_params = sum(p.numel() for p in m.base.conv.parameters() if p.requires_grad)
    return total - conv_params

@torch.no_grad()
def eval_epoch(model: nn.Module, loader):
    model.eval()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        tot += ce(logits, y).item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

# -------------------- AMP-aware training epoch --------------------
USE_AMP = (device.type == "cuda")
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

def train_epoch_amp(model: nn.Module, loader, opt,
                    clip: float = 1.0,
                    lam_branch: float = 1e-3,
                    lam_bound:  float = 1e-3):
    model.train()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits, aux = model(x, return_aux=True)
            loss = ce(logits, y)
            if aux is not None:
                loss = loss + lam_branch*aux.get("branch_penalty", 0.0) \
                             + lam_bound *aux.get("bound_penalty",  0.0)
        scaler.scale(loss).backward()
        if clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(opt)
        scaler.update()

        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n       += x.size(0)
    return tot/n, 100.0*correct/n

# optional AJ diagnostics hook if you defined one earlier
def maybe_aj_diags(model, loader, n_batches=2):
    if "aj_diags" in globals():
        aj_diags(model, loader, device, n_batches=n_batches)

# -------------------- instantiate axis-periodic AJ model --------------------
K = 2                 # number of Fourier harmonics per coordinate
LEARN_FREQS = False   # fixed frequencies for τ-free head
EMBED_DIM  = 4        # small embedding to keep params modest

aj_axis = AJMNIST_AxisPeriodic(
    genus,
    I_plus, Om_plus,
    grid_r, grid_i,
    branch_pts_t,
    anchors_xy_t.to(device),
    mu_t.to(device), sigma_t.to(device),
    embed_dim=EMBED_DIM,
    K=K,
    learnable_freqs=LEARN_FREQS
).to(device)

total_params = count_params(aj_axis)
head_params  = heads_only_params(aj_axis)
print(f"\nAJ axis-periodic (g={genus}, K={K}) params: {total_params:,}")
print(f"  heads-only (total minus conv trunk): {head_params:,}")

# -------------------- optimizer (two-tier LR) --------------------
fast_params = list(aj_axis.base.point_head.parameters()) + \
              [aj_axis.base.point_bias] + \
              list(aj_axis.torus.parameters()) + \
              list(aj_axis.classifier.parameters())

base_params = list(aj_axis.base.conv.parameters()) + \
              list(aj_axis.base.aj.parameters())

opt = torch.optim.AdamW(
    [
        {"params": base_params, "lr": 3e-4},
        {"params": fast_params, "lr": 1e-3},
    ],
    weight_decay=1e-4
)

# -------------------- training loop --------------------
EPOCHS     = 8
CLIP_NORM  = 1.0
LAM_BRANCH = 1e-3
LAM_BOUND  = 1e-3

print("\nStarting training with AMP =", USE_AMP)
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(
        aj_axis, train_loader_axis, opt,
        clip=CLIP_NORM,
        lam_branch=LAM_BRANCH,
        lam_bound=LAM_BOUND
    )
    te_loss, te_acc = eval_epoch(aj_axis, test_loader_axis)
    print(f"[AJ axis K={K}] Epoch {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")
    maybe_aj_diags(aj_axis, train_loader_axis, n_batches=2)


Using device: cuda
Current genus = 30
Original train batch size: 128
AJ-axis train batch size : 64

AJ axis-periodic (g=30, K=2) params: 22,251
  heads-only (total minus conv trunk): 3,435

Starting training with AMP = True


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[AJ axis K=2] Epoch 01 | train 1.5287 / 43.11% | test 0.9165 / 67.98%
[AJ axis K=2] Epoch 02 | train 0.8225 / 71.01% | test 0.5715 / 80.85%
[AJ axis K=2] Epoch 03 | train 0.6117 / 79.12% | test 0.4531 / 85.51%
[AJ axis K=2] Epoch 04 | train 0.4956 / 83.48% | test 0.3597 / 88.29%
[AJ axis K=2] Epoch 05 | train 0.4202 / 86.33% | test 0.4384 / 85.59%
[AJ axis K=2] Epoch 06 | train 0.3756 / 87.84% | test 0.3611 / 88.01%
[AJ axis K=2] Epoch 07 | train 0.3244 / 89.66% | test 0.2530 / 92.23%
[AJ axis K=2] Epoch 08 | train 0.2728 / 91.40% | test 0.2718 / 91.28%


In [16]:
# @title RFF–AxisPeriodic (simple): conv → 2 → RFF(2→4gK) → 10  | AMP + small-batch

import math, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# ---------------- device & sanity ----------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)
assert 'genus' in globals(), "Please define `genus` first (e.g., genus = 30)."
assert 'train_loader' in globals() and 'test_loader' in globals(), "Please create train/test loaders first."
print(f"Genus = {genus} → 2g = {2*genus}")

# -------------- small-batch loaders --------------
BASE_BS = 64
orig_train_loader = train_loader
orig_test_loader  = test_loader
orig_bs = getattr(orig_train_loader, "batch_size", None)
small_bs = BASE_BS if (orig_bs is None) else min(BASE_BS, orig_bs)

train_loader_rff_axis = DataLoader(
    orig_train_loader.dataset,
    batch_size=small_bs, shuffle=True,
    num_workers=getattr(orig_train_loader, "num_workers", 0),
    pin_memory=getattr(orig_train_loader, "pin_memory", False),
    drop_last=True
)
test_loader_rff_axis = DataLoader(
    orig_test_loader.dataset,
    batch_size=min(256, small_bs*2), shuffle=False,
    num_workers=getattr(orig_test_loader, "num_workers", 0),
    pin_memory=getattr(orig_test_loader, "pin_memory", False),
    drop_last=False
)
print(f"RFF-axis train batch size: {train_loader_rff_axis.batch_size}")

# ---------------- helpers (scoped) ----------------
def make_conv_trunk_rffaxis():
    return nn.Sequential(
        nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.AdaptiveAvgPool2d((1, 1))
    )

def init_two_band_rffaxis_(tensor, low=1.0, high=3.0):
    with torch.no_grad():
        sign = (torch.randint(0, 2, tensor.shape, device=tensor.device)*2 - 1).float()
        mag  = torch.empty_like(tensor, dtype=torch.float32).uniform_(low, high)
        tensor.copy_(sign * mag)

def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def conv_param_count(m):
    return sum(p.numel() for n,p in m.named_parameters()
               if p.requires_grad and n.startswith("conv"))

def heads_only_params(m):
    return count_params(m) - conv_param_count(m)

# ---------------- model (simple, no norm/dropout) ----------------
class RFFAxisPeriodicSimple(nn.Module):
    """
    conv(1→64) → Linear(64→2) →
      RFF lift: z ↦ [cos(Wz+b), sin(Wz+b)]
        with M = 2 g K so D_feat = 2M = 4 g K (matches AJ axis-periodic head)
      → Linear(D_feat→10)

    - W ∈ R^{M×2}, b ∈ R^M are frozen (random features)
    - No BN/LN/Dropout in the head (like earlier RFF baselines that trained well)
    """
    def __init__(self, genus: int, K: int = 2, omega_std: float = 1.0, rff_seed: int = 1234):
        super().__init__()
        self.genus, self.K = genus, K
        g = genus

        # conv + 2D bottleneck
        self.conv = make_conv_trunk_rffaxis()
        self.fc_down = nn.Linear(64, 2)
        init_two_band_rffaxis_(self.fc_down.weight); nn.init.zeros_(self.fc_down.bias)

        # RFF frequencies & phases (frozen)
        M = 2 * g * K                # M such that D_feat = 2M = 4 g K
        gen = torch.Generator().manual_seed(rff_seed)
        W = torch.randn(M, 2, generator=gen) * omega_std
        b = 2*math.pi * torch.rand(M, generator=gen)
        self.register_buffer("W", W.float())     # (M,2)
        self.register_buffer("b", b.float())     # (M,)

        # Classifier on 4 g K features
        self.classifier = nn.Linear(2*M, 10)

    def forward(self, x):
        B = x.size(0)
        h = self.conv(x).view(B, -1)                 # (B,64)
        z = self.fc_down(h)                          # (B,2)

        W = self.W.to(z.device, dtype=z.dtype)
        b = self.b.to(z.device, dtype=z.dtype)
        theta = z @ W.T + b                          # (B, M)
        feats = torch.cat([torch.cos(theta), torch.sin(theta)], dim=1)  # (B, 2M)

        logits = self.classifier(feats)              # (B,10)
        return logits

# --------------- AMP training utils ---------------
USE_AMP = (device.type == 'cuda')
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
ce = nn.CrossEntropyLoss()

def train_epoch_amp(model, loader, opt, clip=1.0):
    model.train()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits = model(x)
            loss   = ce(logits, y)
        scaler.scale(loss).backward()
        if clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(opt); scaler.update()
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n       += x.size(0)
    return tot/n, 100.0*correct/n

@torch.no_grad()
def eval_epoch(model, loader):
    model.eval()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss   = ce(logits, y)
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n       += x.size(0)
    return tot/n, 100.0*correct/n

# --------------- instantiate & train ---------------
K = 2               # match AJ axis-periodic harmonics
OMEGA_STD = 1.0     # RFF bandwidth; 0.6–1.2 are typical

rff_axis_simple = RFFAxisPeriodicSimple(genus, K=K, omega_std=OMEGA_STD).to(device)

total = count_params(rff_axis_simple)
conv  = conv_param_count(rff_axis_simple)
heads = heads_only_params(rff_axis_simple)
print(f"\nRFF–AxisPeriodic SIMPLE (g={genus}, K={K})")
print(f"  total params : {total:,}")
print(f"  conv trunk   : {conv:,}")
print(f"  heads-only   : {heads:,}  (total minus conv trunk)")

# Local (model-only) LRs — TUNE HERE if needed
BASE_LR = 3e-4   # conv trunk
HEAD_LR = 1e-3   # fc_down + classifier  ← increase to 1.5e-3 or 2e-3 if it's flat

base_params = list(rff_axis_simple.conv.parameters())
fast_params = list(rff_axis_simple.fc_down.parameters()) + list(rff_axis_simple.classifier.parameters())

opt_rff_axis_simple = torch.optim.AdamW(
    [{"params": base_params, "lr": BASE_LR},
     {"params": fast_params, "lr": HEAD_LR}],
    weight_decay=1e-4
)
print(f"Using BASE_LR={BASE_LR:.1e}, HEAD_LR={HEAD_LR:.1e}")

EPOCHS, CLIP = 8, 1.0
print("\nStarting RFF–AxisPeriodic SIMPLE training (AMP =", USE_AMP, ")")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(rff_axis_simple, train_loader_rff_axis, opt_rff_axis_simple, clip=CLIP)
    te_loss, te_acc = eval_epoch(rff_axis_simple, test_loader_rff_axis)
    print(f"[RFF–AxisPeriodic SIMPLE K={K}] Ep {ep:02d} | "
          f"train {tr_loss:.4f}/{tr_acc:.2f}% | test {te_loss:.4f}/{te_acc:.2f}%")


Using device: cuda
Genus = 30 → 2g = 60
RFF-axis train batch size: 64

RFF–AxisPeriodic SIMPLE (g=30, K=2)
  total params : 21,356
  conv trunk   : 18,816
  heads-only   : 2,540  (total minus conv trunk)
Using BASE_LR=3.0e-04, HEAD_LR=1.0e-03

Starting RFF–AxisPeriodic SIMPLE training (AMP = True )


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[RFF–AxisPeriodic SIMPLE K=2] Ep 01 | train 1.2317/50.80% | test 0.8181/69.36%
[RFF–AxisPeriodic SIMPLE K=2] Ep 02 | train 0.8063/69.63% | test 0.6601/75.70%
[RFF–AxisPeriodic SIMPLE K=2] Ep 03 | train 0.6733/75.07% | test 0.6995/75.03%
[RFF–AxisPeriodic SIMPLE K=2] Ep 04 | train 0.6128/77.97% | test 0.5532/80.57%
[RFF–AxisPeriodic SIMPLE K=2] Ep 05 | train 0.5592/80.51% | test 0.5604/80.69%
[RFF–AxisPeriodic SIMPLE K=2] Ep 06 | train 0.5173/82.38% | test 0.4555/85.17%
[RFF–AxisPeriodic SIMPLE K=2] Ep 07 | train 0.4899/83.38% | test 0.4275/85.87%
[RFF–AxisPeriodic SIMPLE K=2] Ep 08 | train 0.4726/84.17% | test 0.3813/87.61%


In [11]:
# @title Compact τ-free axis-periodic AJ model (definition + training)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math, os, time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# --------- sanity checks for precomputed objects ---------
needed = [
    "genus", "I_plus", "Om_plus", "grid_r", "grid_i", "branch_pts_t",
    "anchors_xy_t", "mu_t", "sigma_t", "AJMNIST_Anchored", "TorusFeatures",
    "train_loader", "test_loader"
]
for name in needed:
    if name not in globals():
        raise RuntimeError(f"Missing `{name}`; make sure your genus-{globals().get('genus','?')} setup cells ran.")

print(f"Current genus = {genus}")

# If you already have smaller-batch loaders from the previous cell, reuse them:
train_loader_compact = train_loader_axis if 'train_loader_axis' in globals() else train_loader
test_loader_compact  = test_loader_axis  if 'test_loader_axis'  in globals() else test_loader
print("train_loader_compact batch_size:", train_loader_compact.batch_size)

# --------- helper: fixed orthogonal projector ---------
def make_fixed_orthogonal(D: int, r: int, seed: int = 1234):
    """
    Returns a D x r matrix with orthonormal columns (Q), as float32.
    Used as a non-trainable bottleneck from feature dim D to r.
    """
    assert r <= D, f"r={r} must be <= D={D}"
    gen = torch.Generator().manual_seed(seed)
    A = torch.randn(D, r, generator=gen)
    Q, _ = torch.linalg.qr(A, mode='reduced')  # D x r
    return Q[:, :r].float()

# --------- compact axis-periodic model ---------
class AJMNIST_AxisPeriodic_Compact(nn.Module):
    """
    AJ base (anchored + normalized) + axis-aligned Fourier torus features,
    followed by a FIXED orthogonal projector P: R^{D_feat} → R^r and a tiny classifier.

    D_feat = 2*(2g)*K (cos/sin per coordinate per harmonic)
    """
    def __init__(self, genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma,
                 embed_dim=8, K=2, r=32,
                 learnable_freqs=False, proj_seed=1234):
        super().__init__()
        self.base = AJMNIST_Anchored(
            genus, I_plus, Om_plus,
            grid_r, grid_i, branch_pts,
            anchors_xy, mu, sigma,
            embed_dim=embed_dim
        )
        self.K = K
        D = 2*genus
        self.torus = TorusFeatures(D, K=K, learnable=learnable_freqs)
        D_feat = 2 * D * K   # cos+sin per dim per harmonic
        assert r <= D_feat, f"r={r} must be <= D_feat={D_feat} for K={K}, g={genus}"
        P = make_fixed_orthogonal(D_feat, r, seed=proj_seed)
        self.register_buffer("P", P)               # non-trainable projector
        self.scale = nn.Parameter(torch.ones(r))   # tiny learned diagonal scale
        self.classifier = nn.Linear(r, 10)         # small head

    @property
    def genus(self):
        return self.base.genus

    def forward(self, x, return_aux=False):
        B = x.size(0)
        # shared conv + point head (same as axis model)
        h = self.base.conv(x).view(B, -1)
        h_exp = h.unsqueeze(1).expand(-1, self.genus, -1)
        emb   = self.base.embed.unsqueeze(0).expand(B, -1, -1)
        out   = self.base.point_head(torch.cat([h_exp, emb], dim=2)) \
              + self.base.point_bias.unsqueeze(0)
        raw_xy, sheet_logits = out[..., :2], out[..., 2]

        coords_std, aux = self.base.aj(raw_xy, sheet_logits, return_aux=True)  # (B, 2g)
        feats = self.torus(coords_std)                                         # (B, D_feat)
        P = self.P.to(feats.device, dtype=feats.dtype)                         # (D_feat, r)
        z = feats @ P                                                          # (B, r)
        z = z * self.scale                                                     # (B, r)
        logits = self.classifier(z)                                            # (B, 10)
        if return_aux:
            return logits, aux
        return logits

# --------- param counting helpers ---------
def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def heads_only_params(m):
    total = count_params(m)
    conv_params = sum(p.numel() for p in m.base.conv.parameters() if p.requires_grad)
    return total - conv_params

# --------- AMP-aware training helpers (reuse if present) ---------
# Removed the 'if' guard to ensure function definition is always updated.
ce = nn.CrossEntropyLoss()
USE_AMP = (device.type == 'cuda')
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

def train_epoch_amp(model, loader, opt,
                    clip=1.0, lam_branch=1e-3, lam_bound=1e-3):
    model.train()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits, aux = model(x, return_aux=True)
            loss = ce(logits, y)
            loss = loss + lam_branch*aux.get("branch_penalty", 0.0) \
                       + lam_bound *aux.get("bound_penalty",  0.0)
        scaler.scale(loss).backward()
        if clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(opt)
        scaler.update()
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n       += x.size(0)
    return tot/n, 100.0*correct/n

# Removed the 'if' guard to ensure function definition is always updated.
ce_eval = nn.CrossEntropyLoss()
@torch.no_grad()
def eval_epoch(model, loader):
    model.eval()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = ce_eval(logits, y)
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

# Optional AJ diagnostics
def maybe_aj_diags(model, loader, n_batches=2):
    if "aj_diags" in globals():
        aj_diags(model, loader, device, n_batches=n_batches)

# --------- move lookup tables to device if needed ---------
# If you already have device copies, reuse; else create them.
I_plus_dev  = I_plus.to(device)   if I_plus.device.type  == 'cpu' else I_plus
Om_plus_dev = Om_plus.to(device)  if Om_plus.device.type == 'cpu' else Om_plus
grid_r_dev  = grid_r.to(device)
grid_i_dev  = grid_i.to(device)
branch_pts_dev = branch_pts_t.to(device)
anchors_xy_dev = anchors_xy_t.to(device)
mu_dev, sigma_dev = mu_t.to(device), sigma_t.to(device)

# --------- instantiate compact model ---------
K = 2       # harmonics per coordinate (same as axis model)
r = 32      # compact dimension (you can tune this)
EMBED_DIM = 4

aj_axis_compact = AJMNIST_AxisPeriodic_Compact(
    genus,
    I_plus_dev, Om_plus_dev,
    grid_r_dev, grid_i_dev,
    branch_pts_dev,
    anchors_xy_dev,
    mu_dev, sigma_dev,
    embed_dim=EMBED_DIM,
    K=K, r=r,
    learnable_freqs=False,
    proj_seed=1234
).to(device)

total = count_params(aj_axis_compact)
heads = heads_only_params(aj_axis_compact)
conv  = sum(p.numel() for p in aj_axis_compact.base.conv.parameters() if p.requires_grad)
print(f"\nAJ axis-periodic COMPACT (g={genus}, K={K}, r={r})")
print(f"  total params  : {total:,}")
print(f"  conv trunk    : {conv:,}")
print(f"  heads-only    : {heads:,} (total minus conv trunk)")

# --------- optimizer (two-tier) ---------
fast_params = list(aj_axis_compact.base.point_head.parameters()) + \
              [aj_axis_compact.base.point_bias] + \
              list(aj_axis_compact.torus.parameters()) + \
              list(aj_axis_compact.classifier.parameters())

base_params = list(aj_axis_compact.base.conv.parameters()) + \
              list(aj_axis_compact.base.aj.parameters())

opt_compact = torch.optim.AdamW(
    [
        {"params": base_params, "lr": 3e-4},
        {"params": fast_params, "lr": 1e-3},
    ],
    weight_decay=1e-4
)

# --------- train compact model ---------
EPOCHS     = 8
CLIP_NORM  = 1.0
LAM_BRANCH = 1e-3
LAM_BOUND  = 1e-3

USE_AMP = (device.type == 'cuda')
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

print("\nStarting training of COMPACT axis-periodic model (AMP =", USE_AMP, ")")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(
        aj_axis_compact, train_loader_compact, opt_compact,
        clip=CLIP_NORM, lam_branch=LAM_BRANCH, lam_bound=LAM_BOUND
    )
    te_loss, te_acc = eval_epoch(aj_axis_compact, test_loader_compact)
    print(f"[AJ axis COMPACT K={K}, r={r}] Ep {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")
    maybe_aj_diags(aj_axis_compact, train_loader_compact, n_batches=2)


Using device: cuda
Current genus = 30
train_loader_compact batch_size: 64

AJ axis-periodic COMPACT (g=30, K=2, r=32)
  total params  : 20,203
  conv trunk    : 18,816
  heads-only    : 1,387 (total minus conv trunk)

Starting training of COMPACT axis-periodic model (AMP = True )


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[AJ axis COMPACT K=2, r=32] Ep 01 | train 1.9336 / 30.23% | test 1.3924 / 48.78%
[AJ axis COMPACT K=2, r=32] Ep 02 | train 1.1924 / 54.39% | test 1.1871 / 54.16%
[AJ axis COMPACT K=2, r=32] Ep 03 | train 0.9494 / 66.23% | test 1.1348 / 56.43%
[AJ axis COMPACT K=2, r=32] Ep 04 | train 0.7695 / 73.50% | test 0.8765 / 67.23%
[AJ axis COMPACT K=2, r=32] Ep 05 | train 0.6238 / 79.56% | test 0.5535 / 82.19%
[AJ axis COMPACT K=2, r=32] Ep 06 | train 0.5504 / 82.64% | test 0.4156 / 87.85%
[AJ axis COMPACT K=2, r=32] Ep 07 | train 0.4398 / 86.82% | test 0.3649 / 88.90%
[AJ axis COMPACT K=2, r=32] Ep 08 | train 0.3583 / 89.19% | test 0.3125 / 91.04%


In [12]:
# @title RFF-Compact baseline (conv → 2 → RFF(2→4gK) → fixed P→r → 10), AMP-safe

import torch, math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# -------- Sanity --------
assert 'genus' in globals(), "Define `genus` first (e.g., genus = 30)."
assert 'train_loader' in globals() and 'test_loader' in globals(), "Load your dataset loaders first."
print(f"Genus = {genus} → 2g = {2*genus}")

# -------- Small-batch loaders (reuse if present) --------
BASE_BS = 64
orig_train_loader = train_loader
orig_test_loader  = test_loader
orig_bs = getattr(orig_train_loader, "batch_size", None)
small_bs = BASE_BS if (orig_bs is None) else min(BASE_BS, orig_bs)

train_loader_rff = DataLoader(
    orig_train_loader.dataset,
    batch_size=small_bs, shuffle=True,
    num_workers=getattr(orig_train_loader, "num_workers", 0),
    pin_memory=getattr(orig_train_loader, "pin_memory", False),
    drop_last=True
)
test_loader_rff = DataLoader(
    orig_test_loader.dataset,
    batch_size=min(256, small_bs*2), shuffle=False,
    num_workers=getattr(orig_test_loader, "num_workers", 0),
    pin_memory=getattr(orig_test_loader, "pin_memory", False),
    drop_last=False
)
print(f"RFF-compact train batch size: {train_loader_rff.batch_size}")

# -------- Helpers --------
def make_conv_trunk():
    return nn.Sequential(
        nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.AdaptiveAvgPool2d((1, 1))
    )

def init_two_band_(tensor, low=1.0, high=3.0):
    with torch.no_grad():
        sign = (torch.randint(0, 2, tensor.shape, device=tensor.device)*2 - 1).float()
        mag  = torch.empty_like(tensor, dtype=torch.float32).uniform_(low, high)
        tensor.copy_(sign * mag)

def make_fixed_orthogonal(D: int, r: int, seed: int = 1234):
    assert r <= D, f"r={r} must be ≤ D={D}"
    gen = torch.Generator().manual_seed(seed)
    A = torch.randn(D, r, generator=gen)
    Q, _ = torch.linalg.qr(A, mode='reduced')   # D×r
    return Q[:, :r].float()

# -------- RFF-Compact model (mirrors AJ-compact head) --------
class RFFCompact2DTo2g(nn.Module):
    """
    conv(1→64) → Linear(64→2) → BN(no affine) → α·z
      → RFF lift: z ↦ [cos(Wz+b), sin(Wz+b)] with M = 2gK (so D_feat = 2M = 4gK)
      → fixed orthonormal projector P: D_feat→r
      → diag scale (learnable r)
      → Linear(r→10)
    All RFF params (W,b) and projector P are frozen.
    """
    def __init__(self, genus: int, K: int = 2, r: int = 32,
                 omega_std: float = 1.0, proj_seed: int = 1234, rff_seed: int = 1234):
        super().__init__()
        self.genus = genus
        self.K = K
        g = genus

        # Shared trunk + 2D bottleneck
        self.conv = make_conv_trunk()
        self.fc_down = nn.Linear(64, 2)
        init_two_band_(self.fc_down.weight); nn.init.zeros_(self.fc_down.bias)

        # Stabilize the 2D code numerically (no learnable affine)
        self.bn2 = nn.BatchNorm1d(2, affine=False, eps=1e-5, momentum=0.1)
        self.alpha = nn.Parameter(torch.tensor(1.0))   # global scale on z

        # RFF: choose M so that D_feat matches AJ Fourier head (D_feat = 4 g K)
        M = 2 * g * K
        self.M = M
        gen = torch.Generator().manual_seed(rff_seed)
        W = torch.randn(M, 2, generator=gen) * omega_std      # (M,2) rows are ω_i
        b = 2*math.pi * torch.rand(M, generator=gen)          # (M,)
        self.register_buffer("W", W.float())
        self.register_buffer("b", b.float())

        # Fixed projector P: D_feat × r, with D_feat = 2M = 4 g K
        D_feat = 2 * M
        assert r <= D_feat, f"r={r} must be ≤ D_feat={D_feat} for g={g}, K={K}"
        P = make_fixed_orthogonal(D_feat, r, seed=proj_seed)
        self.register_buffer("P", P)                     # (D_feat, r), frozen

        # Tiny trainable head on top
        self.scale = nn.Parameter(torch.ones(r))         # diag scale after P
        self.classifier = nn.Linear(r, 10)

    def forward(self, x, return_aux: bool = False):
        B = x.size(0)
        h = self.conv(x).view(B, -1)                     # (B,64)
        z = self.fc_down(h)                              # (B,2)
        z = self.bn2(z) * self.alpha                     # (B,2)

        # RFF lift: (B,2) → (B, 2M) with cos & sin
        # theta = z @ W^T + b
        W = self.W.to(z.device, dtype=z.dtype)
        b = self.b.to(z.device, dtype=z.dtype)
        theta = z @ W.T + b                              # (B, M)
        feats = torch.cat([torch.cos(theta), torch.sin(theta)], dim=1)  # (B, 2M)

        # Compact head: project to r, scale, classify
        P = self.P.to(feats.device, dtype=feats.dtype)   # (2M, r)
        zc = (feats @ P) * self.scale                    # (B, r)
        logits = self.classifier(zc)                     # (B, 10)

        if return_aux:
            zero = torch.zeros((), device=x.device)
            aux = {"branch_penalty": zero, "bound_penalty": zero}
            return logits, aux
        return logits

# -------- Parameter accounting --------
def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def conv_param_count(m):
    return sum(p.numel() for n,p in m.named_parameters()
               if p.requires_grad and n.startswith("conv"))

def heads_only_params(m):
    return count_params(m) - conv_param_count(m)

# -------- AMP training helpers --------
USE_AMP = (device.type == 'cuda')
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
ce = nn.CrossEntropyLoss()

def train_epoch_amp(model, loader, opt, clip=1.0):
    model.train()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits = model(x)
            loss   = ce(logits, y)
        scaler.scale(loss).backward()
        if clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(opt); scaler.update()
        tot += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

@torch.no_grad()
def eval_epoch(model, loader):
    model.eval()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        tot += ce(logits, y).item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

# -------- Instantiate & train (match AJ-compact defaults) --------
K = 2
r = 32
OMEGA_STD = 1.0         # try 0.5, 1.0, 2.0; α (learned) will also adapt
model = RFFCompact2DTo2g(genus, K=K, r=r, omega_std=OMEGA_STD,
                         proj_seed=1234, rff_seed=1234).to(device)

total = count_params(model)
conv  = conv_param_count(model)
heads = heads_only_params(model)
print(f"\nRFF-Compact (g={genus}, K={K}, r={r})")
print(f"  total params : {total:,}")
print(f"  conv trunk   : {conv:,}")
print(f"  heads-only   : {heads:,}  (total minus conv trunk)")

# Two-tier optimizer, analogous to AJ-compact (base=conv; fast=head)
fast_params = list(model.fc_down.parameters()) + [model.alpha, model.scale] + list(model.classifier.parameters())
base_params = list(model.conv.parameters())
opt = torch.optim.AdamW(
    [{"params": base_params, "lr": 3e-4},
     {"params": fast_params, "lr": 1e-3}],
    weight_decay=1e-4
)

EPOCHS, CLIP = 8, 1.0
print("\nStarting training (AMP =", USE_AMP, ")")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(model, train_loader_rff, opt, clip=CLIP)
    te_loss, te_acc = eval_epoch(model, test_loader_rff)
    print(f"[RFF-Compact K={K}, r={r}] Ep {ep:02d} | "
          f"train {tr_loss:.4f}/{tr_acc:.2f}% | test {te_loss:.4f}/{te_acc:.2f}%")


Using device: cuda
Genus = 30 → 2g = 60
RFF-compact train batch size: 64

RFF-Compact (g=30, K=2, r=32)
  total params : 19,309
  conv trunk   : 18,816
  heads-only   : 493  (total minus conv trunk)

Starting training (AMP = True )


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[RFF-Compact K=2, r=32] Ep 01 | train 1.1453/64.04% | test 1.7621/36.55%
[RFF-Compact K=2, r=32] Ep 02 | train 0.6380/78.74% | test 1.0487/62.77%
[RFF-Compact K=2, r=32] Ep 03 | train 0.5454/81.90% | test 0.7104/74.68%
[RFF-Compact K=2, r=32] Ep 04 | train 0.5025/83.34% | test 0.8979/68.43%
[RFF-Compact K=2, r=32] Ep 05 | train 0.4711/84.51% | test 0.6324/78.21%
[RFF-Compact K=2, r=32] Ep 06 | train 0.4479/85.37% | test 0.5578/80.50%
[RFF-Compact K=2, r=32] Ep 07 | train 0.4231/86.23% | test 0.3735/88.41%
[RFF-Compact K=2, r=32] Ep 08 | train 0.4104/86.58% | test 0.4425/85.79%


In [13]:
# @title Random 2→2g nonlinear baselines (RFF and Frozen-MLP) — definition + training

import torch, math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)
assert 'genus' in globals(), "Please define `genus` first."
assert 'train_loader' in globals() and 'test_loader' in globals(), "Please define train/test loaders."

# -------------------- reuse smaller DataLoaders (like AJ cells) --------------------
BASE_BS = 64
orig_train_loader = train_loader
orig_test_loader  = test_loader
orig_bs = getattr(orig_train_loader, "batch_size", None)
small_bs = BASE_BS if (orig_bs is None) else min(BASE_BS, orig_bs)

train_loader_rnd = DataLoader(
    orig_train_loader.dataset, batch_size=small_bs, shuffle=True,
    num_workers=getattr(orig_train_loader, "num_workers", 0),
    pin_memory=getattr(orig_train_loader, "pin_memory", False),
    drop_last=True
)
test_loader_rnd = DataLoader(
    orig_test_loader.dataset, batch_size=min(256, small_bs*2), shuffle=False,
    num_workers=getattr(orig_test_loader, "num_workers", 0),
    pin_memory=getattr(orig_test_loader, "pin_memory", False),
    drop_last=False
)
print(f"Random baseline train batch size: {train_loader_rnd.batch_size}")

# -------------------- shared conv trunk (same as AJ baselines) --------------------
def make_conv_trunk():
    return nn.Sequential(
        nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.AdaptiveAvgPool2d((1, 1))  # → (B,64,1,1)
    )

def init_two_band_(tensor, low=1.0, high=3.0):
    with torch.no_grad():
        sign = (torch.randint(0, 2, tensor.shape, device=tensor.device)*2 - 1).float()
        mag  = torch.empty_like(tensor, dtype=torch.float32).uniform_(low, high)
        tensor.copy_(sign * mag)

# -------------------- Random Fourier Features 2→2g (frozen) --------------------
class RandomFourier2Dto2g(nn.Module):
    """
    z ∈ R^2  →  [cos(W z + b), sin(W z + b)] ∈ R^{2g}
    W ∈ R^{g×2}, b ∈ R^g are frozen buffers.
    """
    def __init__(self, genus: int, freq_scale: float = 1.0, seed: int = 1234, learn_gain: bool = False):
        super().__init__()
        self.genus = genus
        g = genus
        # Frozen random frequencies & phases
        gen = torch.Generator().manual_seed(seed)
        W = torch.randn(g, 2, generator=gen) * freq_scale
        b = 2*math.pi * torch.rand(g, generator=gen)
        self.register_buffer("W", W.float())
        self.register_buffer("b", b.float())
        # Optional learnable per-feature gain (small, to keep fairness)
        self.gain = nn.Parameter(torch.ones(2*g)) if learn_gain else None

    def forward(self, z):  # z: (B,2)
        theta = z @ self.W.T + self.b           # (B,g)
        feats = torch.cat([torch.cos(theta), torch.sin(theta)], dim=1)  # (B, 2g)
        if self.gain is not None:
            feats = feats * self.gain
        return feats

# -------------------- Frozen random MLP 2→2g --------------------
class FrozenMLP2Dto2g(nn.Module):
    """
    z ∈ R^2 → h → 2g with a GELU, weights frozen at init.
    """
    def __init__(self, genus: int, hidden: int = None, seed: int = 1234):
        super().__init__()
        g = genus
        D = 2*g
        if hidden is None:
            hidden = min(4*g, 128)  # modest hidden width by default
        self.fc1 = nn.Linear(2, hidden)
        self.fc2 = nn.Linear(hidden, D)
        # Init & freeze
        torch.manual_seed(seed)
        nn.init.xavier_uniform_(self.fc1.weight); nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc2.weight); nn.init.zeros_(self.fc2.bias)
        for p in self.fc1.parameters(): p.requires_grad = False
        for p in self.fc2.parameters(): p.requires_grad = False
        self.act = nn.GELU()

    def forward(self, z):  # z: (B,2)
        h  = self.act(self.fc1(z))
        y  = self.fc2(h)              # (B,2g)
        return y

# -------------------- Projection2D → RandomNonlin(2→2g) → 10 --------------------
class Projection2D_RandomLift(nn.Module):
    """
    conv(1→64) → Linear(64→2) → RandomLift(2→2g) [frozen] → Linear(2g→10)
    """
    def __init__(self, genus: int, lift_kind: str = "rff", **lift_kwargs):
        super().__init__()
        self.genus = genus
        self.conv = make_conv_trunk()
        self.fc_down = nn.Linear(64, 2)
        # Lift: choose 'rff' or 'mlp'
        if lift_kind == "rff":
            self.lift = RandomFourier2Dto2g(genus, **lift_kwargs)
        elif lift_kind == "mlp":
            self.lift = FrozenMLP2Dto2g(genus, **lift_kwargs)
        else:
            raise ValueError("lift_kind must be 'rff' or 'mlp'")
        self.classifier = nn.Linear(2*genus, 10)
        # init similar to earlier baselines
        init_two_band_(self.fc_down.weight); nn.init.zeros_(self.fc_down.bias)

    def forward(self, x):
        B = x.size(0)
        h  = self.conv(x).view(B, -1)   # (B,64)
        z2 = self.fc_down(h)            # (B,2)
        z  = self.lift(z2)              # (B,2g) frozen nonlinear lift
        logits = self.classifier(z)     # (B,10)
        return logits

# -------------------- Training helpers (AMP) --------------------
def count_trainable_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

@torch.no_grad()
def eval_epoch(model, loader):
    model.eval()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = ce(logits, y)
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

USE_AMP = (device.type == "cuda")
scaler  = torch.cuda.amp.GradScaler(enabled=USE_AMP)

def train_epoch_amp(model, loader, opt, clip=1.0):
    model.train()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits = model(x)
            loss   = ce(logits, y)
        scaler.scale(loss).backward()
        if clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(opt)
        scaler.update()
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n       += x.size(0)
    return tot/n, 100.0*correct/n

# -------------------- Instantiate & train: choose your lift --------------------
LIFT_KIND   = "rff"     # "rff" or "mlp"
RFF_SCALE   = 1.0       # typical 0.5–2.0; affects oscillation rate of cos/sin
RFF_GAIN    = False     # set True to allow a tiny learnable per-feature scale
MLP_HIDDEN  = None      # if LIFT_KIND="mlp", None→min(4g,128)
EPOCHS      = 8
LR          = 3e-4
WDECAY      = 1e-4
CLIP_NORM   = 1.0

if LIFT_KIND == "rff":
    model = Projection2D_RandomLift(genus, lift_kind="rff", freq_scale=RFF_SCALE, learn_gain=RFF_GAIN).to(device)
else:
    model = Projection2D_RandomLift(genus, lift_kind="mlp", hidden=MLP_HIDDEN).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WDECAY)

trainable = count_trainable_params(model)
conv_params = sum(p.numel() for n,p in model.named_parameters()
                  if p.requires_grad and n.startswith("conv"))
heads_only = trainable - conv_params
print(f"\nRandom 2→2g baseline ({LIFT_KIND}) — trainable params: {trainable:,}")
print(f"  conv trunk (trainable) : {conv_params:,}")
print(f"  heads-only (trainable) : {heads_only:,} (total minus conv trunk)")

print("\nStarting training (AMP =", USE_AMP, ")")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(model, train_loader_rnd, opt, clip=CLIP_NORM)
    te_loss, te_acc = eval_epoch(model, test_loader_rnd)
    print(f"[RANDOM {LIFT_KIND} g={genus}] Ep {ep:02d} | train {tr_loss:.4f}/{tr_acc:.2f}% | test {te_loss:.4f}/{te_acc:.2f}%")


Using device: cuda
Random baseline train batch size: 64

Random 2→2g baseline (rff) — trainable params: 19,556
  conv trunk (trainable) : 18,816
  heads-only (trainable) : 740 (total minus conv trunk)

Starting training (AMP = True )


  scaler  = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[RANDOM rff g=30] Ep 01 | train 1.5800/46.61% | test 1.2498/56.25%
[RANDOM rff g=30] Ep 02 | train 1.0110/66.91% | test 0.8972/68.53%
[RANDOM rff g=30] Ep 03 | train 0.8105/72.76% | test 0.6477/79.20%
[RANDOM rff g=30] Ep 04 | train 0.6922/76.75% | test 0.9236/65.13%
[RANDOM rff g=30] Ep 05 | train 0.6384/78.38% | test 0.6232/77.96%
[RANDOM rff g=30] Ep 06 | train 0.6169/79.11% | test 0.5482/82.09%
[RANDOM rff g=30] Ep 07 | train 0.5643/81.17% | test 0.4730/84.76%
[RANDOM rff g=30] Ep 08 | train 0.5459/81.88% | test 0.5397/82.07%


In [14]:
# @title Random 2→2g nonlinear baselines (RFF and Frozen-MLP) — definition + training

import torch, math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)
assert 'genus' in globals(), "Please define `genus` first."
assert 'train_loader' in globals() and 'test_loader' in globals(), "Please define train/test loaders."

# -------------------- reuse smaller DataLoaders (like AJ cells) --------------------
BASE_BS = 64
orig_train_loader = train_loader
orig_test_loader  = test_loader
orig_bs = getattr(orig_train_loader, "batch_size", None)
small_bs = BASE_BS if (orig_bs is None) else min(BASE_BS, orig_bs)

train_loader_rnd = DataLoader(
    orig_train_loader.dataset, batch_size=small_bs, shuffle=True,
    num_workers=getattr(orig_train_loader, "num_workers", 0),
    pin_memory=getattr(orig_train_loader, "pin_memory", False),
    drop_last=True
)
test_loader_rnd = DataLoader(
    orig_test_loader.dataset, batch_size=min(256, small_bs*2), shuffle=False,
    num_workers=getattr(orig_test_loader, "num_workers", 0),
    pin_memory=getattr(orig_test_loader, "pin_memory", False),
    drop_last=False
)
print(f"Random baseline train batch size: {train_loader_rnd.batch_size}")

# -------------------- shared conv trunk (same as AJ baselines) --------------------
def make_conv_trunk():
    return nn.Sequential(
        nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.AdaptiveAvgPool2d((1, 1))  # → (B,64,1,1)
    )

def init_two_band_(tensor, low=1.0, high=3.0):
    with torch.no_grad():
        sign = (torch.randint(0, 2, tensor.shape, device=tensor.device)*2 - 1).float()
        mag  = torch.empty_like(tensor, dtype=torch.float32).uniform_(low, high)
        tensor.copy_(sign * mag)

# -------------------- Random Fourier Features 2→2g (frozen) --------------------
class RandomFourier2Dto2g(nn.Module):
    """
    z ∈ R^2  →  [cos(W z + b), sin(W z + b)] ∈ R^{2g}
    W ∈ R^{g×2}, b ∈ R^g are frozen buffers.
    """
    def __init__(self, genus: int, freq_scale: float = 1.0, seed: int = 1234, learn_gain: bool = False):
        super().__init__()
        self.genus = genus
        g = genus
        # Frozen random frequencies & phases
        gen = torch.Generator().manual_seed(seed)
        W = torch.randn(g, 2, generator=gen) * freq_scale
        b = 2*math.pi * torch.rand(g, generator=gen)
        self.register_buffer("W", W.float())
        self.register_buffer("b", b.float())
        # Optional learnable per-feature gain (small, to keep fairness)
        self.gain = nn.Parameter(torch.ones(2*g)) if learn_gain else None

    def forward(self, z):  # z: (B,2)
        theta = z @ self.W.T + self.b           # (B,g)
        feats = torch.cat([torch.cos(theta), torch.sin(theta)], dim=1)  # (B, 2g)
        if self.gain is not None:
            feats = feats * self.gain
        return feats

# -------------------- Frozen random MLP 2→2g --------------------
class FrozenMLP2Dto2g(nn.Module):
    """
    z ∈ R^2 → h → 2g with a GELU, weights frozen at init.
    """
    def __init__(self, genus: int, hidden: int = None, seed: int = 1234):
        super().__init__()
        g = genus
        D = 2*g
        if hidden is None:
            hidden = min(4*g, 128)  # modest hidden width by default
        self.fc1 = nn.Linear(2, hidden)
        self.fc2 = nn.Linear(hidden, D)
        # Init & freeze
        torch.manual_seed(seed)
        nn.init.xavier_uniform_(self.fc1.weight); nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc2.weight); nn.init.zeros_(self.fc2.bias)
        for p in self.fc1.parameters(): p.requires_grad = False
        for p in self.fc2.parameters(): p.requires_grad = False
        self.act = nn.GELU()

    def forward(self, z):  # z: (B,2)
        h  = self.act(self.fc1(z))
        y  = self.fc2(h)              # (B,2g)
        return y

# -------------------- Projection2D → RandomNonlin(2→2g) → 10 --------------------
class Projection2D_RandomLift(nn.Module):
    """
    conv(1→64) → Linear(64→2) → RandomLift(2→2g) [frozen] → Linear(2g→10)
    """
    def __init__(self, genus: int, lift_kind: str = "rff", **lift_kwargs):
        super().__init__()
        self.genus = genus
        self.conv = make_conv_trunk()
        self.fc_down = nn.Linear(64, 2)
        # Lift: choose 'rff' or 'mlp'
        if lift_kind == "rff":
            self.lift = RandomFourier2Dto2g(genus, **lift_kwargs)
        elif lift_kind == "mlp":
            self.lift = FrozenMLP2Dto2g(genus, **lift_kwargs)
        else:
            raise ValueError("lift_kind must be 'rff' or 'mlp'")
        self.classifier = nn.Linear(2*genus, 10)
        # init similar to earlier baselines
        init_two_band_(self.fc_down.weight); nn.init.zeros_(self.fc_down.bias)

    def forward(self, x):
        B = x.size(0)
        h  = self.conv(x).view(B, -1)   # (B,64)
        z2 = self.fc_down(h)            # (B,2)
        z  = self.lift(z2)              # (B,2g) frozen nonlinear lift
        logits = self.classifier(z)     # (B,10)
        return logits

# -------------------- Training helpers (AMP) --------------------
def count_trainable_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

@torch.no_grad()
def eval_epoch(model, loader):
    model.eval()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = ce(logits, y)
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)
    return tot/n, 100.0*correct/n

USE_AMP = (device.type == "cuda")
scaler  = torch.cuda.amp.GradScaler(enabled=USE_AMP)

def train_epoch_amp(model, loader, opt, clip=1.0):
    model.train()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits = model(x)
            loss   = ce(logits, y)
        scaler.scale(loss).backward()
        if clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(opt)
        scaler.update()
        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n       += x.size(0)
    return tot/n, 100.0*correct/n

# -------------------- Instantiate & train: choose your lift --------------------
LIFT_KIND   = "mlp"     # "rff" or "mlp"
RFF_SCALE   = 1.0       # typical 0.5–2.0; affects oscillation rate of cos/sin
RFF_GAIN    = False     # set True to allow a tiny learnable per-feature scale
MLP_HIDDEN  = None      # if LIFT_KIND="mlp", None→min(4g,128)
EPOCHS      = 8
LR          = 3e-4
WDECAY      = 1e-4
CLIP_NORM   = 1.0

if LIFT_KIND == "rff":
    model = Projection2D_RandomLift(genus, lift_kind="rff", freq_scale=RFF_SCALE, learn_gain=RFF_GAIN).to(device)
else:
    model = Projection2D_RandomLift(genus, lift_kind="mlp", hidden=MLP_HIDDEN).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WDECAY)

trainable = count_trainable_params(model)
conv_params = sum(p.numel() for n,p in model.named_parameters()
                  if p.requires_grad and n.startswith("conv"))
heads_only = trainable - conv_params
print(f"\nRandom 2→2g baseline ({LIFT_KIND}) — trainable params: {trainable:,}")
print(f"  conv trunk (trainable) : {conv_params:,}")
print(f"  heads-only (trainable) : {heads_only:,} (total minus conv trunk)")

print("\nStarting training (AMP =", USE_AMP, ")")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(model, train_loader_rnd, opt, clip=CLIP_NORM)
    te_loss, te_acc = eval_epoch(model, test_loader_rnd)
    print(f"[RANDOM {LIFT_KIND} g={genus}] Ep {ep:02d} | train {tr_loss:.4f}/{tr_acc:.2f}% | test {te_loss:.4f}/{te_acc:.2f}%")


Using device: cuda
Random baseline train batch size: 64

Random 2→2g baseline (mlp) — trainable params: 19,556
  conv trunk (trainable) : 18,816
  heads-only (trainable) : 740 (total minus conv trunk)

Starting training (AMP = True )


  scaler  = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[RANDOM mlp g=30] Ep 01 | train 1.5637/37.27% | test 1.3217/46.85%
[RANDOM mlp g=30] Ep 02 | train 1.2484/49.17% | test 1.1402/54.48%
[RANDOM mlp g=30] Ep 03 | train 1.1176/55.33% | test 1.1090/55.60%
[RANDOM mlp g=30] Ep 04 | train 1.0324/59.58% | test 0.9748/62.13%
[RANDOM mlp g=30] Ep 05 | train 0.9725/62.51% | test 0.9350/63.36%
[RANDOM mlp g=30] Ep 06 | train 0.9164/65.46% | test 0.8966/65.77%
[RANDOM mlp g=30] Ep 07 | train 0.8798/66.94% | test 0.8737/67.14%
[RANDOM mlp g=30] Ep 08 | train 0.8349/69.60% | test 0.8003/71.42%
