In this notebook we start to scale up, using the pre-computed integrals for genus 30 (but not yet the pre-computed periods; instead we mock up the periodicity along the original axes). This already not only beats the 2D projection model but is also comparable to the 2g × 2g model, with many fewer parameters. [Depending on the AJ version it can in fact do a bit better (maybe periodicity helps, beyond the sheer dimensionality of the layer).]



In [None]:
# @title Load from Drive
!pip install -q torch torchvision numpy

import os, math, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from google.colab import drive

# Prefer the GPU when one is visible; later cells re-derive the same `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ===========================
# 1) Drive + Paths (match Notebook 1)
# ===========================
DRIVE_FOLDER = "AJ_Tables_g30"  # same as Notebook 1
drive.mount('/content/drive', force_remount=True)
SAVE_DIR = f"/content/drive/MyDrive/{DRIVE_FOLDER}"
INTEGRALS_PATH = os.path.join(SAVE_DIR, "aj_integrals_genus30.pt")
OMEGAS_PATH    = os.path.join(SAVE_DIR, "aj_omegas_genus30.pt")
# Fail early, with a readable message, if the tables are missing from Drive.
assert os.path.exists(INTEGRALS_PATH), "Integrals file not found in Drive"
assert os.path.exists(OMEGAS_PATH),    "Omegas file not found in Drive"



# ===========================
# 2) Load tables
# ===========================
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which can
# execute code. Only load files you produced yourself.
ints = torch.load(INTEGRALS_PATH, map_location='cpu', weights_only=False)
omeg = torch.load(OMEGAS_PATH,    map_location='cpu', weights_only=False)

# (An earlier version called torch.load without the weights_only argument.)

# Unpack the saved dictionaries into the module-level names the later cells rely on.
genus       = int(ints["genus"])
grid_r_np   = np.array(ints["grid_r"])
grid_i_np   = np.array(ints["grid_i"])
branch_pts  = np.array(ints["branch_pts"])
I_plus      = ints["I_plus"]            # (g, H, W) complex
Om_plus     = omeg["omega_plus"]        # (g, H, W) complex

H, W = I_plus.shape[-2:]
grid_r = torch.tensor(grid_r_np, dtype=torch.float32)
grid_i = torch.tensor(grid_i_np, dtype=torch.float32)
# torch.tensor keeps the NumPy dtype here (no cast is requested); AJGridActivationNorm
# later takes .real / .imag of this tensor and casts those to float32.
branch_pts_t = torch.tensor(branch_pts)

print(f"Loaded: genus={genus}, grid=({H}x{W})")


Mounted at /content/drive
Loaded: genus=30, grid=(96x96)


In [None]:
# @title Base class AJMNIST_Anchored (needed for PeriodicHead)

import torch
import torch.nn as nn
import torch.nn.functional as F

class AJMNIST_Anchored(nn.Module):
    """
    Anchored AJ model:
      - conv trunk (1→64)
      - learned per-point embeddings (g×embed_dim)
      - shared point_head producing (x,y, sheet_logit) for each point
      - per-point bias initialized to ω-aware anchors
      - AJGridActivationNorm does lookup, standardization, sheet sign, and sum
      - classifier: linear 2g→10

    Args:
        genus: number of points g; the AJ layer output has 2g channels.
        I_plus, Om_plus, grid_r, grid_i, branch_pts, mu, sigma:
            passed through unchanged to AJGridActivationNorm.
        anchors_xy: indexable (g, 2) collection of initial (x0, y0) positions,
            expressed in the same coordinates as grid_r / grid_i.
        embed_dim: width of each learned per-point embedding.
    """
    def __init__(self, genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma, embed_dim=8):
        super().__init__()
        self.genus = genus
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.embed = nn.Parameter(torch.empty(genus, embed_dim))
        nn.init.uniform_(self.embed, -2.0, 2.0)

        # shared point_head (no bias) + per-point bias
        self.point_head = nn.Linear(64 + embed_dim, 3, bias=False)
        nn.init.xavier_uniform_(self.point_head.weight)
        self.point_bias = nn.Parameter(torch.zeros(genus, 3))

        # AJ activation with normalization
        self.aj = AJGridActivationNorm(I_plus, Om_plus, grid_r, grid_i, branch_pts, mu, sigma)
        self.classifier = nn.Linear(2*genus, 10)

        # ω-aware initialization for biases: choose biases so that, with a zero
        # point_head contribution, sigmoid(bias) lands each point on its anchor.
        rmin, rmax = float(grid_r.min()), float(grid_r.max())
        imin, imax = float(grid_i.min()), float(grid_i.max())
        eps = 1e-6

        def logit(p):
            # Inverse sigmoid in float32. The clamp keeps the bias finite when an
            # anchor sits exactly on the edge of the box (p == 0 or p == 1);
            # interior values are unaffected.
            p = torch.tensor(p).clamp(eps, 1.0 - eps)
            return float(torch.log(p/(1-p)))

        with torch.no_grad():
            for i in range(genus):
                x0, y0 = float(anchors_xy[i,0]), float(anchors_xy[i,1])
                px = (x0 - rmin) / (rmax - rmin)
                py = (y0 - imin) / (imax - imin)
                self.point_bias[i, 0] = logit(px)
                self.point_bias[i, 1] = logit(py)
                # Alternate the initial sheet sign (tanh of the bias) between +0.8 and -0.8.
                target_sign = 0.8 if (i % 2 == 0) else -0.8
                self.point_bias[i, 2] = float(torch.atanh(torch.tensor(target_sign)))

    def forward(self, x, return_aux=False):
        """Return class logits, or (logits, aux) when return_aux is True."""
        B = x.size(0)
        h = self.conv(x).view(B, -1)
        # Pair the shared image feature with every per-point embedding.
        h_exp = h.unsqueeze(1).expand(-1, self.genus, -1)
        emb   = self.embed.unsqueeze(0).expand(B, -1, -1)
        inp   = torch.cat([h_exp, emb], dim=2)

        out   = self.point_head(inp) + self.point_bias.unsqueeze(0)  # (B,g,3)
        raw_xy, sheet_logits = out[..., :2], out[..., 2]
        # Only ask the AJ layer for penalties/diagnostics when the caller wants them;
        # it returns aux=None otherwise, so plain inference skips that work.
        coords, aux = self.aj(raw_xy, sheet_logits, return_aux=return_aux)
        logits = self.classifier(coords)
        if return_aux:
            return logits, aux
        return logits


In [None]:
# @title AJ model with ω-aware init + AJ normalization + learnable gain

import torch
import torch.nn as nn
import torch.nn.functional as F

def _pack_complex_table(table_gHW: torch.Tensor) -> torch.Tensor:
    re, im = table_gHW.real, table_gHW.imag
    return torch.cat([re, im], dim=0).unsqueeze(0).contiguous()  # (1, 2g, H, W)

class AJGridActivationNorm(nn.Module):
    """
    Grid-lookup AJ activation with normalization.

    For each of the g points the module
      - squashes the raw (x, y) pair into the table's bounding box with a sigmoid,
      - bilinearly samples the packed integral table at that location,
      - standardizes every channel as (I - mu) / sigma,
      - weights the sample by a continuous sheet sign, tanh(sheet_logit),
    then sums over the g points and scales by the learnable global gain gamma.
    Boundary and branch-point penalties are returned for diagnostics/regularization.
    """
    def __init__(self, I_plus, Om_plus, grid_r, grid_i, branch_pts, mu, sigma):
        super().__init__()
        self.g = I_plus.shape[0]

        # Complex tables are stored as real buffers of shape (1, 2g, H, W).
        # Om_plus is registered but not read by forward().
        for buf_name, table in (("I_plus", I_plus), ("Om_plus", Om_plus)):
            self.register_buffer(buf_name, _pack_complex_table(table))

        # Per-channel statistics, shaped (1, 1, 2g) to broadcast over (B, g, 2g).
        for buf_name, stat in (("mu", mu), ("sigma", sigma)):
            self.register_buffer(buf_name, stat.view(1, 1, -1))

        self.gamma = nn.Parameter(torch.tensor(1.0))  # learnable gain

        # Bounding box of the lookup grid, stored as 0-d buffers.
        box = (("r_min", grid_r.min()), ("r_max", grid_r.max()),
               ("i_min", grid_i.min()), ("i_max", grid_i.max()))
        for buf_name, bound in box:
            self.register_buffer(buf_name, torch.tensor(float(bound)))

        # Branch-point coordinates as float32.
        self.register_buffer("bp_real", branch_pts.real.float())
        self.register_buffer("bp_imag", branch_pts.imag.float())

    def _map_raw_to_bounds(self, raw_xy):
        # The sigmoid puts each raw coordinate in (0, 1); rescale that to the box.
        unit_x = torch.sigmoid(raw_xy[..., 0])
        unit_y = torch.sigmoid(raw_xy[..., 1])
        xr = self.r_min + (self.r_max - self.r_min) * unit_x
        yi = self.i_min + (self.i_max - self.i_min) * unit_y
        return xr, yi

    def _norm_to_grid(self, xr, yi):
        # Box coordinates → the [-1, 1] convention expected by F.grid_sample.
        width = self.r_max - self.r_min
        height = self.i_max - self.i_min
        gx = 2.0 * (xr - self.r_min) / width - 1.0
        gy = 2.0 * (yi - self.i_min) / height - 1.0
        return gx, gy

    def forward(self, raw_xy: torch.Tensor, sheet_logits: torch.Tensor, return_aux=True):
        """Return (coords, aux): coords is (B, 2g); aux is a dict, or None if not requested."""
        B, g, _ = raw_xy.shape
        assert g == self.g

        xr, yi = self._map_raw_to_bounds(raw_xy)
        gx, gy = self._norm_to_grid(xr, yi)

        # One sampling location per (batch, point) pair.
        locations = torch.stack([gx, gy], dim=-1).view(B*g, 1, 1, 2)
        table = self.I_plus.expand(B*g, -1, -1, -1)
        sampled = F.grid_sample(table, locations, mode="bilinear", align_corners=True)
        I_std = (sampled.view(B, g, -1) - self.mu) / self.sigma   # (B, g, 2g)

        # Continuous sheet sign in (-1, 1), then a signed sum over the g points.
        sign = torch.tanh(sheet_logits).unsqueeze(-1)             # (B, g, 1)
        coords = self.gamma * (sign * I_std).sum(dim=1)           # (B, 2g)

        if not return_aux:
            return coords, None

        # Quadratic penalty once a point goes beyond |g| > margin in normalized coordinates.
        margin = 0.95
        over_x = (gx.abs() - margin).clamp_min(0)
        over_y = (gy.abs() - margin).clamp_min(0)
        bpen = (over_x**2 + over_y**2).mean()

        # Gaussian bump (width tau) around every branch point, averaged over
        # all (batch, point, branch point) triples.
        tau = 0.07
        dx = xr.unsqueeze(-1) - self.bp_real
        dy = yi.unsqueeze(-1) - self.bp_imag
        d2 = dx*dx + dy*dy
        rpen = torch.exp(-d2 / (2*tau*tau)).mean()

        aux = {
            "x": xr, "y": yi, "gx": gx, "gy": gy,
            "bound_penalty": bpen, "branch_penalty": rpen,
            "sheet_sign": sign.squeeze(-1),
        }
        return coords, aux

def logit(p):
    """Inverse sigmoid of p, with p first clipped to [1e-6, 1 - 1e-6]; returns a Python float."""
    clipped = np.clip(p, 1e-6, 1-1e-6)
    odds = clipped/(1-clipped)
    return float(np.log(odds))





In [None]:
# AJ periodic head WITHOUT tau: axis-aligned Fourier features (global K)
import torch, torch.nn as nn, torch.nn.functional as F

class TorusFeatures(nn.Module):
    """Axis-aligned Fourier features of a (B, D) input.

    Every coordinate u_d is expanded into cos(f_k * u_d) and sin(f_k * u_d) for
    each frequency f_k.

    Args:
        dim: number of input coordinates D (stored as `self.dim`, not otherwise used).
        K: number of frequencies per coordinate.
        freqs: optional explicit frequencies; only the first K entries are used.
            When omitted, the K harmonics 0.5, 1.0, 1.5, ... are used, which is
            [0.5, 1.0] for the default K=2.
        learnable: store the frequencies as an nn.Parameter instead of a buffer.

    Attributes:
        K: number of frequencies actually in use (may be smaller than the requested
           K when `freqs` has fewer entries). The output width is D * 2 * self.K.
    """
    def __init__(self, dim: int, K: int = 2, freqs=None, learnable: bool = False):
        super().__init__()
        if freqs is None:
            # The default used to be the fixed list [0.5, 1.0] truncated to K, so any
            # K > 2 silently produced only two frequencies. Generate K of them instead.
            freqs = 0.5 * torch.arange(1, K + 1, dtype=torch.float32)
        else:
            # Copy so a learnable module never shares storage with the caller's tensor.
            freqs = torch.as_tensor(freqs, dtype=torch.float32)[:K].detach().clone()
        if learnable:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs)
        self.dim, self.K = dim, len(freqs)

    def forward(self, u):
        """Map u of shape (B, D) to (B, D*2*K): per coordinate, K cosines then K sines."""
        B, D = u.shape
        f = self.freqs.view(1, 1, -1).to(u.device)
        ang = u.unsqueeze(-1) * f                       # (B, D, K)
        return torch.cat([torch.cos(ang), torch.sin(ang)], dim=-1).view(B, D*2*self.K)

class AJMNIST_AxisPeriodic(nn.Module):
    """
    Drop-in head: AJ (anchored+norm) → axis-aligned periodic features → classifier
    Expects you already have AJMNIST_Anchored and AJGridActivationNorm defined.

    The wrapped AJMNIST_Anchored lives at `self.base`; its conv trunk, embeddings,
    point head/bias and AJ layer are reused. `self.base.classifier` is not called by
    this forward pass, but it is still a registered submodule and therefore still
    appears in `parameters()`.

    Args:
        genus .. embed_dim: forwarded to AJMNIST_Anchored.
        K: number of Fourier frequencies per AJ coordinate.
        learnable_freqs: make the frequencies trainable.
    """
    def __init__(self, genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma, embed_dim=8, K=2, learnable_freqs=False):
        super().__init__()
        self.base = AJMNIST_Anchored(genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                                     anchors_xy, mu, sigma, embed_dim=embed_dim)
        D = 2*genus
        self.torus = TorusFeatures(D, K=K, learnable=learnable_freqs)
        # Size the classifier from the number of frequencies TorusFeatures actually
        # holds, not from the requested K: the two can differ, and a mismatch would
        # only surface as a shape error in the first forward pass.
        self.classifier = nn.Linear(D*2*self.torus.K, 10)

    @property
    def genus(self): return self.base.genus

    def forward(self, x, return_aux=False):
        """Return class logits, or (logits, aux) when return_aux is True."""
        B = x.size(0)
        h = self.base.conv(x).view(B, -1)
        h_exp = h.unsqueeze(1).expand(-1, self.genus, -1)
        emb   = self.base.embed.unsqueeze(0).expand(B, -1, -1)
        out   = self.base.point_head(torch.cat([h_exp, emb], dim=2)) + self.base.point_bias.unsqueeze(0)
        raw_xy, sheet_logits = out[..., :2], out[..., 2]
        # Skip the penalty/diagnostic computation unless the caller asked for it
        # (the AJ layer returns aux=None in that case).
        coords, aux = self.base.aj(raw_xy, sheet_logits, return_aux=return_aux)  # (B, 2g)
        feats = self.torus(coords)
        logits = self.classifier(feats)
        if return_aux: return logits, aux
        return logits


In [None]:
# @title Compute ω-aware anchors and AJ normalization
import torch
import numpy as np

# 1) Build a scalar "ω-strength" map on the grid to avoid dead-gradient regions
#    Use sum_k |ω_k| (or |ω_k|^2); both work similarly.
with torch.no_grad():
    # Om_plus: (g, H, W) complex
    Wmap = Om_plus.abs().sum(dim=0)  # (H, W)
    Wmap_np = Wmap.cpu().numpy()

# 2) Mask out boundaries and branch points (stay away from peaks/singularities too)
# NOTE: this rebinds the notebook-level H and W from the ω-map's shape.
H, W = Wmap_np.shape
# Normalized [-1, 1] coordinates per row / column, used only for the edge mask.
gy = np.linspace(-1, 1, H)[:, None]
gx = np.linspace(-1, 1, W)[None, :]
edge_mask = (np.abs(gx) < 0.92) & (np.abs(gy) < 0.92)

# Distance from branch points in grid coordinates
# assumes grid_r_np has length W and grid_i_np has length H — TODO confirm against the saved tables
grid_x = grid_r_np[None, :].repeat(H, axis=0)  # (H,W)
grid_y = grid_i_np[:, None].repeat(W, axis=1)  # (H,W)
bp_real = np.real(branch_pts).reshape(-1, 1, 1)
bp_imag = np.imag(branch_pts).reshape(-1, 1, 1)
d2 = (grid_x - bp_real)**2 + (grid_y - bp_imag)**2  # (P, H, W)
# d2 holds squared distances, so the 0.25 threshold corresponds to distance 0.5.
bp_mask = (d2.min(axis=0) > 0.25)  # keep points with dist > ~0.5

mask = edge_mask & bp_mask

# 3) Pick g anchor points among top-quantile ω regions, well separated
# Masked-out cells get -inf so they can never pass the quantile threshold.
score = np.where(mask, Wmap_np, -np.inf)
th = np.quantile(score[score > -np.inf], 0.85)  # top 15% by ω-strength
cand = np.argwhere(score >= th)                 # list of (iy, ix)

# Farthest-point sampling to pick 'genus' diverse anchors
def farthest_k(points_hw, k):
    """Greedy farthest-point selection among candidate grid cells.

    points_hw is an integer array of (iy, ix) pairs indexing the notebook-level
    `score`, `grid_x` and `grid_y` arrays. The first pick is the candidate with
    the highest score; each further pick is the candidate whose squared distance
    (in grid_x/grid_y coordinates) to the nearest already-picked cell is largest.
    Returns the picked (iy, ix) pairs as an array, k rows for k >= 1.
    """
    candidates = points_hw.copy()
    iy, ix = candidates[:, 0], candidates[:, 1]

    # Seed with the best-scoring candidate.
    seed = candidates[np.argmax(score[iy, ix])]
    picked = [seed]
    if k == 1:
        return np.array(picked)

    # Coordinates of every candidate, and squared distance to the seed.
    xy = np.stack([grid_x[iy, ix], grid_y[iy, ix]], axis=1)
    seed_xy = np.array([grid_x[seed[0], seed[1]], grid_y[seed[0], seed[1]]])[None, :]
    nearest_d2 = np.sum((xy - seed_xy)**2, axis=1)

    for _ in range(k - 1):
        far = np.argmax(nearest_d2)
        picked.append(candidates[far])
        d2_to_new = np.sum((xy - xy[far][None, :])**2, axis=1)
        nearest_d2 = np.minimum(nearest_d2, d2_to_new)
    return np.array(picked)

anchors_hw = farthest_k(cand, genus)  # shape (g, 2) with (iy, ix)

# 4) Convert anchors to (x0, y0) coordinates
# Column index ix selects the real-axis value, row index iy the imaginary-axis value.
x0 = grid_r_np[anchors_hw[:, 1]]
y0 = grid_i_np[anchors_hw[:, 0]]
anchors_xy = np.stack([x0, y0], axis=1)  # (g, 2)

# 5) Compute per-channel mean/std of AJ coordinates for standardization
#    Pack integrals into 2g channels [Re..., Im...], then compute stats over H×W.
I_re = I_plus.real    # (g, H, W)
I_im = I_plus.imag    # (g, H, W)
I_ch = torch.cat([I_re, I_im], dim=0)  # (2g, H, W)
mu = I_ch.mean(dim=(1, 2))             # (2g,)
# Floor the std so the later division (I - mu) / sigma cannot hit zero.
sigma = I_ch.std(dim=(1, 2)).clamp_min(1e-6)

# Save for later use
# float32 copies; the training cell passes these to the model constructor.
anchors_xy_t = torch.tensor(anchors_xy, dtype=torch.float32)
mu_t = mu.float()
sigma_t = sigma.float()

print("Chosen anchors (x0,y0):")
for i, (x, y) in enumerate(anchors_xy):
    print(f"  point {i:2d}: x0={x:+.3f}, y0={y:+.3f}")
print("\nAJ normalization ready: per-channel mean/std computed.")


Chosen anchors (x0,y0):
  point  0: x0=+2.084, y0=-4.232
  point  1: x0=-0.442, y0=+5.495
  point  2: x0=-5.242, y0=-0.947
  point  3: x0=+4.358, y0=+1.453
  point  4: x0=-2.463, y0=-4.611
  point  5: x0=-3.853, y0=+2.968
  point  6: x0=+4.611, y0=-1.958
  point  7: x0=+2.463, y0=+4.105
  point  8: x0=-3.600, y0=-2.589
  point  9: x0=-0.189, y0=-4.989
  point 10: x0=-4.989, y0=+1.200
  point 11: x0=-2.084, y0=+3.979
  point 12: x0=+3.979, y0=-3.853
  point 13: x0=+1.200, y0=+5.116
  point 14: x0=+1.200, y0=-5.242
  point 15: x0=-4.863, y0=-2.211
  point 16: x0=-3.474, y0=-3.853
  point 17: x0=-4.737, y0=+0.063
  point 18: x0=+3.095, y0=-4.737
  point 19: x0=-1.200, y0=+4.611
  point 20: x0=-4.232, y0=+1.958
  point 21: x0=-3.095, y0=+3.726
  point 22: x0=+4.105, y0=-2.842
  point 23: x0=-1.453, y0=-4.484
  point 24: x0=+2.968, y0=-3.726
  point 25: x0=+1.453, y0=+4.232
  point 26: x0=+0.316, y0=+4.863
  point 27: x0=+0.695, y0=-4.484
  point 28: x0=+2.211, y0=-5.116
  point 29: x0=-4.3

In [None]:
# ===========================
# 5) Data
# ===========================
# ToTensor followed by a fixed normalization (presumably the usual MNIST mean/std — verify).
tfm = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
# Datasets are downloaded to /content/data on first use.
train_ds = torchvision.datasets.MNIST(root="/content/data", train=True, download=True, transform=tfm)
test_ds  = torchvision.datasets.MNIST(root="/content/data", train=False, download=True, transform=tfm)
# Default loaders (batch 128 train / 256 test); the AJ training cell rebuilds smaller ones from these.
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = torch.utils.data.DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

# ===========================
# 6) Train / Eval
# ===========================
def train_epoch(model, loader, opt, clip=1.0, lam_branch=1e-3, lam_bound=1e-3):
    """Run one training pass over `loader`.

    The loss is cross-entropy plus lam_branch * aux["branch_penalty"] and
    lam_bound * aux["bound_penalty"], where aux comes from model(x, return_aux=True).
    Gradients are norm-clipped to `clip` unless clip is None.

    Returns (mean loss per sample, accuracy in percent).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    hits = 0
    seen = 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)

        logits, aux = model(images, return_aux=True)
        loss = criterion(logits, labels)
        loss = loss + lam_branch * aux["branch_penalty"] + lam_bound * aux["bound_penalty"]
        loss.backward()

        if clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()

        batch = images.size(0)
        loss_sum += loss.item() * batch
        hits += (logits.argmax(1) == labels).sum().item()
        seen += batch
    return loss_sum/seen, 100.0*hits/seen

@torch.no_grad()
def eval_epoch(model, loader):
    """Evaluate `model` on `loader` with gradients disabled.

    Returns (mean cross-entropy per sample, accuracy in percent).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    hits = 0
    seen = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        batch = images.size(0)
        loss_sum += criterion(logits, labels).item() * batch
        hits += (logits.argmax(1) == labels).sum().item()
        seen += batch
    return loss_sum/seen, 100.0*hits/seen

100%|██████████| 9.91M/9.91M [00:00<00:00, 13.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 340kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.23MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.3MB/s]


In [None]:
# @title Train τ-free axis-periodic AJ model with smaller batch & AMP (OOM-safe)

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# -------------------- sanity: required globals --------------------
# Names that earlier cells must have defined; checked up front so a missing cell
# produces a clear error instead of a NameError somewhere in the training loop.
required = [
    "genus", "I_plus", "Om_plus", "grid_r", "grid_i", "branch_pts_t",
    "anchors_xy_t", "mu_t", "sigma_t",
    "train_loader", "test_loader",
    "AJMNIST_AxisPeriodic"
]
for name in required:
    if name not in globals():
        raise RuntimeError(f"Missing required object `{name}` before training cell.")

print(f"Current genus = {genus}")

# -------------------- build smaller DataLoaders to save memory --------------------
from torch.utils.data import DataLoader

orig_train_loader = train_loader
orig_test_loader  = test_loader

# Use a conservative batch size (e.g. 64) for AJ model
BASE_BS = 64
orig_bs = getattr(orig_train_loader, "batch_size", None)
# Never exceed the original loader's batch size.
small_bs = BASE_BS if (orig_bs is None) else min(BASE_BS, orig_bs)

# Same datasets and worker/pin settings as the original loaders, only a smaller batch.
# drop_last=True on the training loader discards a final incomplete batch.
train_loader_axis = DataLoader(
    orig_train_loader.dataset,
    batch_size=small_bs,
    shuffle=True,
    num_workers=getattr(orig_train_loader, "num_workers", 0),
    pin_memory=getattr(orig_train_loader, "pin_memory", False),
    drop_last=True
)
test_loader_axis = DataLoader(
    orig_test_loader.dataset,
    batch_size=min(256, small_bs*2),
    shuffle=False,
    num_workers=getattr(orig_test_loader, "num_workers", 0),
    pin_memory=getattr(orig_test_loader, "pin_memory", False),
    drop_last=False
)

print(f"Original train batch size: {orig_bs}")
print(f"AJ-axis train batch size : {small_bs}")

# -------------------- helpers --------------------
def count_params(m: nn.Module) -> int:
    """Total number of scalar entries across the parameters of `m` that require grad."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def heads_only_params(m: nn.Module) -> int:
    """Trainable parameter count of `m` minus that of its conv trunk at `m.base.conv`."""
    trunk = sum(p.numel() for p in m.base.conv.parameters() if p.requires_grad)
    return count_params(m) - trunk

@torch.no_grad()
def eval_epoch(model: nn.Module, loader):
    """Evaluate `model` on `loader` with gradients disabled.

    Redefines the `eval_epoch` helper so this cell carries its own copy.
    Returns (mean cross-entropy per sample, accuracy in percent).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    running_loss, n_right, n_total = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        scores = model(inputs)
        batch_size = inputs.size(0)
        running_loss += criterion(scores, targets).item() * batch_size
        n_right += (scores.argmax(1) == targets).sum().item()
        n_total += batch_size
    return running_loss/n_total, 100.0*n_right/n_total

# -------------------- AMP-aware training epoch --------------------
# Mixed precision is only enabled on CUDA; on CPU the scaler and autocast are no-ops.
USE_AMP = (device.type == "cuda")
# torch.cuda.amp.GradScaler(...) is deprecated (it emits a FutureWarning) in favour of
# torch.amp.GradScaler("cuda", ...). Fall back to the old spelling on PyTorch versions
# that do not provide the new one.
if hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

def train_epoch_amp(model: nn.Module, loader, opt,
                    clip: float = 1.0,
                    lam_branch: float = 1e-3,
                    lam_bound:  float = 1e-3):
    """Run one AMP-aware training pass over `loader`.

    The loss is cross-entropy plus lam_branch * aux["branch_penalty"] and
    lam_bound * aux["bound_penalty"]; a missing key counts as 0 and a None aux
    adds nothing. Uses the module-level `device`, `USE_AMP` and `scaler`.

    Args:
        model: called as model(x, return_aux=True) and expected to return (logits, aux).
        loader: iterable of (x, y) batches.
        opt: optimizer stepped through the gradient scaler.
        clip: max gradient norm; None disables clipping.
        lam_branch, lam_bound: weights of the two penalties.

    Returns:
        (mean loss per sample, accuracy in percent).
    """
    model.train()
    ce = nn.CrossEntropyLoss()
    tot, correct, n = 0.0, 0, 0

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type=device.type, enabled=USE_AMP):
            logits, aux = model(x, return_aux=True)
            loss = ce(logits, y)
            if aux is not None:
                loss = loss + lam_branch*aux.get("branch_penalty", 0.0) \
                             + lam_bound *aux.get("bound_penalty",  0.0)
        scaler.scale(loss).backward()
        if clip is not None:
            # Gradients must be unscaled before clipping so `clip` applies to true norms.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(opt)
        scaler.update()

        tot     += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n       += x.size(0)
    return tot/n, 100.0*correct/n

# optional AJ diagnostics hook if you defined one earlier
def maybe_aj_diags(model, loader, n_batches=2):
    """Call the optional module-level `aj_diags` hook if it exists; otherwise do nothing."""
    if "aj_diags" not in globals():
        return
    aj_diags(model, loader, device, n_batches=n_batches)

# -------------------- instantiate axis-periodic AJ model --------------------
K = 2                 # number of Fourier harmonics per coordinate
LEARN_FREQS = False   # fixed frequencies for τ-free head
EMBED_DIM  = 4        # small embedding to keep params modest

aj_axis = AJMNIST_AxisPeriodic(
    genus,
    I_plus, Om_plus,
    grid_r, grid_i,
    branch_pts_t,
    anchors_xy_t.to(device),
    mu_t.to(device), sigma_t.to(device),
    embed_dim=EMBED_DIM,
    K=K,
    learnable_freqs=LEARN_FREQS
).to(device)

total_params = count_params(aj_axis)
head_params  = heads_only_params(aj_axis)
print(f"\nAJ axis-periodic (g={genus}, K={K}) params: {total_params:,}")
print(f"  heads-only (total minus conv trunk): {head_params:,}")

# -------------------- optimizer (two-tier LR) --------------------
# Every tensor that should be trained must appear in one of the two groups.
# The per-point embeddings (`base.embed`) used to be in neither group, so they
# received gradients but were never stepped and stayed at their random init.
# `base.classifier` is intentionally left out: AJMNIST_AxisPeriodic.forward never
# calls it, so it gets no gradient.
fast_params = list(aj_axis.base.point_head.parameters()) + \
              [aj_axis.base.point_bias, aj_axis.base.embed] + \
              list(aj_axis.torus.parameters()) + \
              list(aj_axis.classifier.parameters())

base_params = list(aj_axis.base.conv.parameters()) + \
              list(aj_axis.base.aj.parameters())

opt = torch.optim.AdamW(
    [
        {"params": base_params, "lr": 3e-4},
        {"params": fast_params, "lr": 1e-3},
    ],
    weight_decay=1e-4
)

# -------------------- training loop --------------------
EPOCHS     = 8
CLIP_NORM  = 1.0
LAM_BRANCH = 1e-3
LAM_BOUND  = 1e-3

print("\nStarting training with AMP =", USE_AMP)
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(
        aj_axis, train_loader_axis, opt,
        clip=CLIP_NORM,
        lam_branch=LAM_BRANCH,
        lam_bound=LAM_BOUND
    )
    te_loss, te_acc = eval_epoch(aj_axis, test_loader_axis)
    print(f"[AJ axis K={K}] Epoch {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")
    maybe_aj_diags(aj_axis, train_loader_axis, n_batches=2)


Using device: cuda
Current genus = 30
Original train batch size: 128
AJ-axis train batch size : 64

AJ axis-periodic (g=30, K=2) params: 22,251
  heads-only (total minus conv trunk): 3,435

Starting training with AMP = True


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[AJ axis K=2] Epoch 01 | train 1.9519 / 26.38% | test 1.7120 / 33.90%
[AJ axis K=2] Epoch 02 | train 1.5955 / 39.59% | test 1.3690 / 47.23%
[AJ axis K=2] Epoch 03 | train 1.2184 / 55.03% | test 1.0156 / 63.03%
[AJ axis K=2] Epoch 04 | train 0.9245 / 66.91% | test 1.0187 / 62.21%
[AJ axis K=2] Epoch 05 | train 0.7481 / 73.21% | test 0.6498 / 76.78%
[AJ axis K=2] Epoch 06 | train 0.6925 / 75.47% | test 0.6305 / 77.39%
[AJ axis K=2] Epoch 07 | train 0.6140 / 78.76% | test 0.5700 / 80.53%
[AJ axis K=2] Epoch 08 | train 0.5374 / 81.91% | test 0.4593 / 84.74%


In [None]:
# @title Parameter counts: AJ (current), Projection2D, Full 2g×2g
import torch
import torch.nn as nn

# def count_params(m: nn.Module) -> int:
#     return sum(p.numel() for p in m.parameters() if p.requires_grad)

# def group_counts(m: nn.Module):
#     from collections import defaultdict
#     groups = defaultdict(int)
#     for n, p in m.named_parameters():
#         if not p.requires_grad:
#             continue
#         top = n.split('.')[0]  # top-level module name
#         groups[top] += p.numel()
#     return dict(sorted(groups.items(), key=lambda kv: kv[0]))

# def pretty_mb(n_params: int, bytes_per=4):
#     return f"{n_params:,}  (~{n_params*bytes_per/1e6:.2f} MB @fp32)"

# def print_summary(name, m):
#     total = count_params(m)
#     print(f"\n{name}: {pretty_mb(total)}")
#     gc = group_counts(m)
#     for k, v in gc.items():
#         print(f"  {k:15s}: {v:,}")
#     # heads-only (subtract conv params if present)
#     conv_params = sum(p.numel() for n,p in m.named_parameters()
#                       if p.requires_grad and n.startswith("conv"))
#     heads_only = total - conv_params
#     print(f"  {'[heads only]':15s}: {heads_only:,} (total minus conv trunk)")

# # ---------------------------
# # Find your current AJ model
# # ---------------------------
# aj_candidates = ["aj_fourier_mix", "aj_fourier", "aj_norm_model", "aj_period", "model"]
# aj_model = None
# for nm in aj_candidates:
#     if nm in globals() and isinstance(globals()[nm], nn.Module):
#         aj_model = globals()[nm]
#         aj_name = nm
#         break

# if aj_model is None:
#     raise RuntimeError("Could not find your current AJ model instance. "
#                        "Make sure you've created e.g. `aj_fourier_mix` or `aj_fourier`.")

# device = next(aj_model.parameters()).device
# try:
#     g = int(genus)
# except NameError:
#     g = 6

# ---------------------------
# Ensure Projection2D & Full2g classes exist; define fallbacks if not
# ---------------------------
if 'Projection2DTo2gNet' not in globals():
    class Projection2DTo2gNet(nn.Module):
        """Fallback baseline: conv trunk → Linear(64, 2) → Linear(2, 2g) → Linear(2g, 10).

        Only defined when no class of this name exists yet.
        """
        def __init__(self, genus):
            super().__init__()
            self.genus = genus
            trunk_layers = [
                nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.AdaptiveAvgPool2d((1,1)),
            ]
            self.conv = nn.Sequential(*trunk_layers)
            width = 2*genus
            self.fc_down = nn.Linear(64, 2)
            self.fc_up   = nn.Linear(2, width)
            self.classifier = nn.Linear(width, 10)

        def forward(self, x, return_aux=False):
            """Return logits, or (logits, aux) with zero-valued penalties when return_aux is True."""
            feats = self.conv(x).view(x.size(0), -1)
            logits = self.classifier(self.fc_up(self.fc_down(feats)))
            if not return_aux:
                return logits
            # Zero penalties keep the return shape compatible with the AJ training loops.
            zero = torch.zeros((), device=x.device)
            return logits, {"branch_penalty": zero, "bound_penalty": zero}

if 'FullMix2gNet' not in globals():
    class FullMix2gNet(nn.Module):
        """Fallback baseline: conv trunk → Linear(64, 2g) → optional ReLU → Linear(2g, 2g) → Linear(2g, 10).

        Only defined when no class of this name exists yet.
        """
        def __init__(self, genus: int, use_nonlinearity: bool = False):
            super().__init__()
            self.genus = genus
            self.use_nonlinearity = use_nonlinearity
            trunk_layers = [
                nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.AdaptiveAvgPool2d((1,1)),
            ]
            self.conv = nn.Sequential(*trunk_layers)
            width = 2*genus
            self.to2g = nn.Linear(64, width)
            self.mix2g = nn.Linear(width, width)
            self.classifier = nn.Linear(width, 10)

        def forward(self, x, return_aux=False):
            """Return logits, or (logits, aux) with zero-valued penalties when return_aux is True."""
            feats = self.conv(x).view(x.size(0), -1)
            z = self.to2g(feats)
            z = torch.relu(z) if self.use_nonlinearity else z
            logits = self.classifier(self.mix2g(z))
            if not return_aux:
                return logits
            # Zero penalties keep the return shape compatible with the AJ training loops.
            zero = torch.zeros((), device=x.device)
            return logits, {"branch_penalty": zero, "bound_penalty": zero}

# # ---------------------------
# # Build baseline model *instances* (for counting)
# # ---------------------------
# proj_model = Projection2DTo2gNet(g).to(device)
# full2g_model = FullMix2gNet(g, use_nonlinearity=False).to(device)

# # ---------------------------
# # Print summaries
# # ---------------------------
# print_summary(f"AJ current model [{aj_name}]", aj_model)
# print_summary("Projection2D → 2g", proj_model)
# print_summary("Full 2g×2g (linear)", full2g_model)

# # Optional: also show a 'Full 2g×2g + ReLU' variant
# full2g_relu = FullMix2gNet(g, use_nonlinearity=True).to(device)
# print_summary("Full 2g×2g (with ReLU)", full2g_relu)


In [None]:
# @title 2D projection baseline: conv → 2 → 2g → 10 (no AJ)

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# ---------- sanity ----------
# This cell depends on names created by earlier cells; stop with a clear message if they are missing.
if 'genus' not in globals():
    raise RuntimeError("Please define `genus` first (e.g. genus = 30).")
if 'train_loader' not in globals() or 'test_loader' not in globals():
    raise RuntimeError("Please define `train_loader` and `test_loader` first.")

print(f"Current genus = {genus}  → 2g = {2*genus}")

# ---------- shared conv trunk (same as AJ) ----------
def make_conv_trunk():
    """Build the conv trunk: two 3×3 conv/ReLU/max-pool stages, then global average pooling."""
    stages = []
    for c_in, c_out in ((1, 32), (32, 64)):
        stages += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
    stages.append(nn.AdaptiveAvgPool2d((1, 1)))   # → (B,64,1,1)
    return nn.Sequential(*stages)

# optional init helper (same flavor as earlier experiments)
def init_two_band_(tensor, low=1.0, high=3.0):
    """Fill `tensor` in place with values of random sign and magnitude uniform in [low, high]."""
    with torch.no_grad():
        # Random ±1 per entry (drawn first), then a uniform magnitude per entry.
        coin = torch.randint(0, 2, tensor.shape, device=tensor.device)
        signs = (coin*2 - 1).float()
        magnitudes = torch.empty_like(tensor, dtype=torch.float32).uniform_(low, high)
        tensor.copy_(signs * magnitudes)

# ---------- 2D projection model ----------
class Projection2DTo2gNet(nn.Module):
    """
    2D-bottleneck baseline (no AJ layer).

    Layer stack:
      conv trunk (1×28×28 → 64) → Linear(64 → 2) → Linear(2 → 2g) → Linear(2g → 10)

    Nothing nonlinear sits between the 2-unit layer and the 2g-unit layer, so the
    representation stays *strictly 2D*: a larger genus adds parameters but not
    representational dimension.
    """
    def __init__(self, genus: int):
        super().__init__()
        self.genus = genus
        width = 2*genus

        self.conv = make_conv_trunk()
        self.fc_down = nn.Linear(64, 2)
        self.fc_up   = nn.Linear(2, width)
        self.classifier = nn.Linear(width, 10)

        # Two-band weights and zero biases for both projection layers;
        # the classifier keeps PyTorch's default init for parity with the AJ code.
        for layer in (self.fc_down, self.fc_up):
            init_two_band_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, return_aux=False):
        """Return logits, or (logits, aux) with zero-valued penalties when return_aux is True."""
        feats = self.conv(x).view(x.size(0), -1)   # (B, 64)
        bottleneck = self.fc_down(feats)           # (B, 2)
        lifted = self.fc_up(bottleneck)            # (B, 2g)
        logits = self.classifier(lifted)           # (B, 10)

        if not return_aux:
            return logits
        # Dummy aux so AJ-style training loops can be reused unchanged.
        zero = torch.zeros((), device=x.device)
        return logits, {"branch_penalty": zero, "bound_penalty": zero}

# ---------- training helpers ----------
def count_params(m):
    """Return the total element count of the parameters of `m` that require gradients."""
    n_trainable = 0
    for param in m.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable

@torch.no_grad()
def eval_epoch_plain(model, loader):
    """
    Evaluate `model` on every batch of `loader` without tracking gradients.

    The model is switched to eval mode. Batches are moved to the notebook-level
    `device`. The per-batch mean cross-entropy is weighted by batch size so the
    result is a per-sample average.

    Returns:
        (mean cross-entropy loss, accuracy in percent).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_size = inputs.size(0)
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        loss_sum += batch_loss.item() * batch_size
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, 100.0 * n_correct / n_seen

def train_epoch_plain(model, loader, opt, clip=None):
    """
    Run one training pass over `loader` with plain (non-AMP) cross-entropy.

    Args:
        model: network called as `model(x)` and returning logits.
        loader: iterable of (inputs, targets) batches.
        opt: optimizer stepping the model's parameters.
        clip: if not None, the maximum total gradient norm passed to
            `clip_grad_norm_` before each optimizer step.

    Returns:
        (mean training loss, training accuracy in percent), both measured on
        the batches as they were seen during the pass (model in train mode).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        batch_size = inputs.size(0)

        opt.zero_grad(set_to_none=True)
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        if clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
        opt.step()

        loss_sum += batch_loss.item() * batch_size
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, 100.0 * n_correct / n_seen

# ---------- instantiate & train ----------
# Build the 2D-bottleneck baseline for the current genus and a single-group
# AdamW optimizer over all of its parameters.
proj2d = Projection2DTo2gNet(genus).to(device)
opt_proj = torch.optim.AdamW(proj2d.parameters(), lr=3e-4, weight_decay=1e-4)

# Trainable parameter count, printed for comparison with the other models.
total_params = count_params(proj2d)
print(f"\nProjection2DTo2gNet params (g={genus}): {total_params:,}")

EPOCHS = 8     # number of passes over train_loader
CLIP   = 1.0   # max gradient norm handed to train_epoch_plain

# NOTE(review): this baseline iterates `train_loader` / `test_loader`, while the
# compact AJ cell prefers `train_loader_axis` / `test_loader_axis` when they
# exist — confirm the batch sizes match before comparing accuracies.
print("\nStarting training of 2D projection baseline...")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_plain(proj2d, train_loader, opt_proj, clip=CLIP)
    # The test set is evaluated after every epoch purely for monitoring.
    te_loss, te_acc = eval_epoch_plain(proj2d, test_loader)
    print(f"[PROJ-2D→2g g={genus}] Ep {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")


Using device: cuda
Current genus = 30  → 2g = 60

Projection2DTo2gNet params (g=30): 19,736

Starting training of 2D projection baseline...
[PROJ-2D→2g g=30] Ep 01 | train 1.4776 / 42.54% | test 1.1763 / 53.17%
[PROJ-2D→2g g=30] Ep 02 | train 1.0915 / 56.52% | test 0.9643 / 63.14%
[PROJ-2D→2g g=30] Ep 03 | train 0.9546 / 63.07% | test 0.8682 / 67.89%
[PROJ-2D→2g g=30] Ep 04 | train 0.8876 / 66.87% | test 1.0759 / 58.88%
[PROJ-2D→2g g=30] Ep 05 | train 0.8346 / 69.56% | test 0.7492 / 73.22%
[PROJ-2D→2g g=30] Ep 06 | train 0.7526 / 73.13% | test 0.7266 / 73.31%
[PROJ-2D→2g g=30] Ep 07 | train 0.6960 / 75.66% | test 0.6455 / 78.14%
[PROJ-2D→2g g=30] Ep 08 | train 0.6683 / 76.83% | test 0.5790 / 81.00%


In [None]:
# Quick check of the genus value currently bound in the notebook namespace.
print(genus)

30


In [None]:
# @title Compact τ-free axis-periodic AJ model (definition + training)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math, os, time

# Pick the GPU when one is available; everything in this cell runs on `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# --------- sanity checks for precomputed objects ---------
# Names this cell expects earlier cells to have placed in the notebook
# namespace (lookup tables, normalisation tensors, model classes, loaders).
needed = [
    "genus", "I_plus", "Om_plus", "grid_r", "grid_i", "branch_pts_t",
    "anchors_xy_t", "mu_t", "sigma_t", "AJMNIST_Anchored", "TorusFeatures",
    "train_loader", "test_loader"
]
# Fail fast on the first missing name; the message quotes the current genus
# (or '?' if `genus` itself is the missing name).
for name in needed:
    if name not in globals():
        raise RuntimeError(f"Missing `{name}`; make sure your genus-{globals().get('genus','?')} setup cells ran.")

print(f"Current genus = {genus}")

# Prefer the `*_axis` loaders when an earlier cell defined them; otherwise fall
# back to the default loaders.
train_loader_compact = train_loader_axis if 'train_loader_axis' in globals() else train_loader
test_loader_compact  = test_loader_axis  if 'test_loader_axis'  in globals() else test_loader
print("train_loader_compact batch_size:", train_loader_compact.batch_size)

# --------- helper: fixed orthogonal projector ---------
def make_fixed_orthogonal(D: int, r: int, seed: int = 1234):
    """
    Build a deterministic D x r float32 matrix with orthonormal columns.

    A Gaussian D x r matrix is drawn from a private generator seeded with
    `seed` (the global RNG is not touched) and orthonormalised with a reduced
    QR factorisation. Intended as a fixed, non-trainable projection from a
    D-dimensional feature space down to r dimensions.

    Raises:
        AssertionError: if r > D.
    """
    assert r <= D, f"r={r} must be ≤ D={D}"
    rng = torch.Generator().manual_seed(seed)
    gaussian = torch.randn(D, r, generator=rng)
    # Reduced QR of a D x r matrix (r ≤ D) already yields exactly r columns.
    q_factor = torch.linalg.qr(gaussian, mode='reduced').Q
    return q_factor.float()

# --------- compact axis-periodic model ---------
class AJMNIST_AxisPeriodic_Compact(nn.Module):
    """
    Compact axis-periodic AJ classifier.

    Wraps an `AJMNIST_Anchored` base model, feeds its coordinates through
    `TorusFeatures`, then applies a FIXED orthogonal projector
    P: R^{D_feat} → R^r, a learned per-dimension scale, and a small linear
    classifier.

    D_feat = 2*(2g)*K (cos/sin per coordinate per harmonic, as the feature
    layer is configured here).

    Args:
        genus, I_plus, Om_plus, grid_r, grid_i, branch_pts, anchors_xy, mu,
        sigma: forwarded unchanged to `AJMNIST_Anchored`.
        embed_dim: forwarded to `AJMNIST_Anchored`.
        K: number of harmonics passed to `TorusFeatures`.
        r: width of the projected representation; must satisfy r ≤ D_feat.
        learnable_freqs: forwarded to `TorusFeatures` as `learnable`.
        proj_seed: seed for the fixed orthogonal projector.
    """
    def __init__(self, genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma,
                 embed_dim=8, K=2, r=32,
                 learnable_freqs=False, proj_seed=1234):
        super().__init__()
        self.base = AJMNIST_Anchored(
            genus, I_plus, Om_plus, grid_r, grid_i, branch_pts,
            anchors_xy, mu, sigma, embed_dim=embed_dim,
        )
        self.K = K

        coord_dim = 2 * genus
        self.torus = TorusFeatures(coord_dim, K=K, learnable=learnable_freqs)

        # cos + sin for every coordinate and harmonic
        D_feat = 2 * coord_dim * K
        assert r <= D_feat, f"r={r} must be ≤ D_feat={D_feat} for K={K}, g={genus}"

        # Registered as a buffer: follows .to(device) but is never trained.
        self.register_buffer("P", make_fixed_orthogonal(D_feat, r, seed=proj_seed))
        # Learned diagonal rescaling of the projected features.
        self.scale = nn.Parameter(torch.ones(r))
        # Small linear head on the r-dimensional code.
        self.classifier = nn.Linear(r, 10)

    @property
    def genus(self):
        """Genus of the wrapped base model."""
        return self.base.genus

    def forward(self, x, return_aux=False):
        """
        Compute class logits, optionally with the aux dict returned by the AJ layer.

        Returns:
            logits, or (logits, aux) when `return_aux` is True.
        """
        batch = x.size(0)

        # Shared conv trunk, then one point-head input per genus slot: the
        # image feature is broadcast and concatenated with the slot embedding.
        conv_feat = self.base.conv(x).view(batch, -1)
        per_slot = conv_feat.unsqueeze(1).expand(-1, self.genus, -1)
        slot_emb = self.base.embed.unsqueeze(0).expand(batch, -1, -1)
        head_in = torch.cat([per_slot, slot_emb], dim=2)
        head_out = self.base.point_head(head_in) + self.base.point_bias.unsqueeze(0)

        # First two channels are the raw point coordinates, the third the sheet logit.
        raw_xy = head_out[..., :2]
        sheet_logits = head_out[..., 2]

        coords_std, aux = self.base.aj(raw_xy, sheet_logits, return_aux=True)
        feats = self.torus(coords_std)

        # Match the projector to the feature dtype/device (relevant under autocast).
        projector = self.P.to(feats.device, dtype=feats.dtype)
        code = (feats @ projector) * self.scale
        logits = self.classifier(code)

        return (logits, aux) if return_aux else logits

# --------- param counting helpers ---------
def count_params(m):
    """Number of elements in the parameters of `m` that require gradients."""
    return sum(map(lambda p: p.numel() if p.requires_grad else 0, m.parameters()))

def heads_only_params(m):
    """
    Trainable parameter count of `m` excluding its conv trunk.

    The trunk is taken to be `m.base.conv`; its trainable parameters are
    subtracted from the model total.
    """
    trunk = sum(p.numel() for p in m.base.conv.parameters() if p.requires_grad)
    return count_params(m) - trunk

# --------- AMP-aware training helpers (reuse if present) ---------
if 'train_epoch_amp' not in globals():
    print("Defining train_epoch_amp locally (AMP training).")
    ce = nn.CrossEntropyLoss()
    USE_AMP = (device.type == 'cuda')
    # `torch.cuda.amp.GradScaler(...)` is deprecated in recent PyTorch and emits
    # a FutureWarning; use the device-agnostic `torch.amp.GradScaler("cuda", ...)`
    # when the installed version provides it, else fall back to the old class.
    _GradScaler = getattr(torch.amp, "GradScaler", None)
    scaler = (_GradScaler("cuda", enabled=USE_AMP) if _GradScaler is not None
              else torch.cuda.amp.GradScaler(enabled=USE_AMP))

    def train_epoch_amp(model, loader, opt,
                        clip=1.0, lam_branch=1e-3, lam_bound=1e-3):
        """
        Run one mixed-precision training pass over `loader`.

        The model is called as `model(x, return_aux=True)` and must return
        `(logits, aux)`. The objective is cross-entropy plus
        `lam_branch * aux["branch_penalty"]` and
        `lam_bound * aux["bound_penalty"]` (missing keys count as 0).

        Uses the notebook-level `device`, `USE_AMP`, `scaler` and `ce`.

        Args:
            model: network to train.
            loader: iterable of (inputs, targets) batches.
            opt: optimizer stepped through the gradient scaler.
            clip: max gradient norm (applied after unscaling); None disables.
            lam_branch, lam_bound: penalty weights.

        Returns:
            (mean training objective, training accuracy in percent). The
            reported loss INCLUDES the penalty terms, so it is not directly
            comparable to a cross-entropy-only evaluation loss.
        """
        model.train()
        tot, correct, n = 0.0, 0, 0
        for x, y in loader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            # `torch.autocast(device_type="cuda", ...)` replaces the deprecated
            # `torch.cuda.amp.autocast(...)`; behaviour is the same.
            with torch.autocast(device_type="cuda", enabled=USE_AMP):
                logits, aux = model(x, return_aux=True)
                loss = ce(logits, y)
                loss = loss + lam_branch*aux.get("branch_penalty", 0.0) \
                           + lam_bound *aux.get("bound_penalty",  0.0)
            scaler.scale(loss).backward()
            if clip is not None:
                # Gradients must be unscaled before clipping so `clip` is
                # expressed in true gradient units.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), clip)
            scaler.step(opt)
            scaler.update()
            tot     += loss.item() * x.size(0)
            correct += (logits.argmax(1) == y).sum().item()
            n       += x.size(0)
        return tot/n, 100.0*correct/n

if 'eval_epoch' not in globals():
    ce_eval = nn.CrossEntropyLoss()

    @torch.no_grad()
    def eval_epoch(model, loader):
        """
        Evaluate `model` (called as `model(x)`) on `loader` without gradients.

        Uses the notebook-level `device` and `ce_eval`.

        Returns:
            (mean cross-entropy loss, accuracy in percent).
        """
        model.eval()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)
            logits = model(inputs)
            batch_loss = ce_eval(logits, targets)
            loss_sum += batch_loss.item() * batch_size
            n_correct += (logits.argmax(1) == targets).sum().item()
            n_seen += batch_size
        return loss_sum / n_seen, 100.0 * n_correct / n_seen

# Optional AJ diagnostics
def maybe_aj_diags(model, loader, n_batches=2):
    """
    Run the optional `aj_diags` diagnostics if that function exists in the
    notebook namespace; otherwise do nothing. Always returns None.
    """
    if "aj_diags" not in globals():
        return
    aj_diags(model, loader, device, n_batches=n_batches)

# --------- move lookup tables to device if needed ---------
# If you already have device copies, reuse; else create them.
# Device-resident copies of the lookup tables and normalisation tensors.
# The two large tables are only moved when they currently live on the CPU;
# if they are already on an accelerator they are reused as they are.
I_plus_dev  = I_plus.to(device)   if I_plus.device.type  == 'cpu' else I_plus
Om_plus_dev = Om_plus.to(device)  if Om_plus.device.type == 'cpu' else Om_plus
grid_r_dev  = grid_r.to(device)
grid_i_dev  = grid_i.to(device)
branch_pts_dev = branch_pts_t.to(device)
anchors_xy_dev = anchors_xy_t.to(device)
mu_dev, sigma_dev = mu_t.to(device), sigma_t.to(device)

# --------- instantiate compact model ---------
K = 2       # harmonics per coordinate (same as axis model)
r = 32      # compact dimension (you can tune this)
EMBED_DIM = 4   # passed to the model as embed_dim (the class default is 8)

aj_axis_compact = AJMNIST_AxisPeriodic_Compact(
    genus,
    I_plus_dev, Om_plus_dev,
    grid_r_dev, grid_i_dev,
    branch_pts_dev,
    anchors_xy_dev,
    mu_dev, sigma_dev,
    embed_dim=EMBED_DIM,
    K=K, r=r,
    learnable_freqs=False,
    proj_seed=1234
).to(device)

# Parameter accounting: total trainable, conv trunk only, and the remainder.
# NOTE: `total`, `heads` and `conv` are plain notebook globals and are
# overwritten by later cells that reuse the same names.
total = count_params(aj_axis_compact)
heads = heads_only_params(aj_axis_compact)
conv  = sum(p.numel() for p in aj_axis_compact.base.conv.parameters() if p.requires_grad)
print(f"\nAJ axis-periodic COMPACT (g={genus}, K={K}, r={r})")
print(f"  total params  : {total:,}")
print(f"  conv trunk    : {conv:,}")
print(f"  heads-only    : {heads:,} (total minus conv trunk)")

# --------- optimizer (two-tier) ---------
# Fast tier (lr 1e-3): the point head, the torus features, the classifier and
# the learned diagonal `scale`. `scale` was previously missing from both
# groups, so it was counted as trainable and included in gradient clipping but
# never updated by the optimizer.
fast_params = list(aj_axis_compact.base.point_head.parameters()) + \
              [aj_axis_compact.base.point_bias] + \
              list(aj_axis_compact.torus.parameters()) + \
              list(aj_axis_compact.classifier.parameters()) + \
              [aj_axis_compact.scale]

# Slow tier (lr 3e-4): the conv trunk and the AJ layer.
base_params = list(aj_axis_compact.base.conv.parameters()) + \
              list(aj_axis_compact.base.aj.parameters())

# Surface any other trainable parameter that is in neither tier: it would be
# reported by count_params but stay frozen at its initial value.
_covered_ids = {id(p) for p in base_params + fast_params}
_uncovered = [n for n, p in aj_axis_compact.named_parameters()
              if p.requires_grad and id(p) not in _covered_ids]
if _uncovered:
    print("WARNING: trainable parameters not passed to the optimizer "
          "(they will stay at their initial values):", _uncovered)

opt_compact = torch.optim.AdamW(
    [
        {"params": base_params, "lr": 3e-4},
        {"params": fast_params, "lr": 1e-3},
    ],
    weight_decay=1e-4
)

# --------- train compact model ---------
EPOCHS     = 8      # passes over train_loader_compact
CLIP_NORM  = 1.0    # max gradient norm given to train_epoch_amp
LAM_BRANCH = 1e-3   # weight of aux["branch_penalty"] in the objective
LAM_BOUND  = 1e-3   # weight of aux["bound_penalty"] in the objective

USE_AMP = (device.type == 'cuda')
# Fresh gradient scaler for this model. `torch.cuda.amp.GradScaler(...)` is
# deprecated and emits a FutureWarning on recent PyTorch; prefer
# `torch.amp.GradScaler("cuda", ...)` when the installed version provides it.
_GradScaler = getattr(torch.amp, "GradScaler", None)
scaler = (_GradScaler("cuda", enabled=USE_AMP) if _GradScaler is not None
          else torch.cuda.amp.GradScaler(enabled=USE_AMP))

print("\nStarting training of COMPACT axis-periodic model (AMP =", USE_AMP, ")")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(
        aj_axis_compact, train_loader_compact, opt_compact,
        clip=CLIP_NORM, lam_branch=LAM_BRANCH, lam_bound=LAM_BOUND
    )
    # The test set is evaluated after every epoch purely for monitoring.
    te_loss, te_acc = eval_epoch(aj_axis_compact, test_loader_compact)
    print(f"[AJ axis COMPACT K={K}, r={r}] Ep {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")
    # No-op unless an `aj_diags` function exists in the notebook namespace.
    maybe_aj_diags(aj_axis_compact, train_loader_compact, n_batches=2)


Using device: cuda
Current genus = 30
train_loader_compact batch_size: 64

AJ axis-periodic COMPACT (g=30, K=2, r=32)
  total params  : 20,203
  conv trunk    : 18,816
  heads-only    : 1,387 (total minus conv trunk)

Starting training of COMPACT axis-periodic model (AMP = True )


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
  with torch.cuda.amp.autocast(enabled=USE_AMP):


[AJ axis COMPACT K=2, r=32] Ep 01 | train 2.0593 / 23.81% | test 1.7942 / 32.60%
[AJ axis COMPACT K=2, r=32] Ep 02 | train 1.4766 / 44.19% | test 1.1779 / 55.43%
[AJ axis COMPACT K=2, r=32] Ep 03 | train 1.1128 / 59.24% | test 0.9999 / 63.87%
[AJ axis COMPACT K=2, r=32] Ep 04 | train 0.8361 / 72.98% | test 0.6871 / 79.65%
[AJ axis COMPACT K=2, r=32] Ep 05 | train 0.6344 / 81.71% | test 0.5646 / 85.02%
[AJ axis COMPACT K=2, r=32] Ep 06 | train 0.5767 / 83.33% | test 0.4633 / 87.36%
[AJ axis COMPACT K=2, r=32] Ep 07 | train 0.4887 / 86.04% | test 0.3753 / 90.00%
[AJ axis COMPACT K=2, r=32] Ep 08 | train 0.4015 / 88.65% | test 0.3505 / 90.59%


In [None]:
# @title Full 2g×2g baseline: conv → 2g → 2g → 10 (no AJ)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Pick the GPU when one is available; the model and batches are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---------- sanity ----------
# This cell relies on `genus` and the data loaders from earlier cells; stop
# with a readable message if they are not in the notebook namespace.
if "genus" not in globals():
    raise RuntimeError("`genus` is not defined; run your setup cell first (e.g., genus = 30).")
if "train_loader" not in globals() or "test_loader" not in globals():
    raise RuntimeError("`train_loader` / `test_loader` missing; make sure your dataset loaders exist.")

print(f"Full 2g×2g baseline will use genus = {genus}")
# Width of the 2g-dimensional layers (the model recomputes this from `genus`).
D2g = 2 * genus

# ---------- shared conv trunk (same as AJ models) ----------
def make_conv_trunk():
    """
    Build the conv feature extractor: two Conv(3×3, padding 1) → ReLU →
    MaxPool(2) stages with 32 and 64 channels, followed by global average
    pooling to a 1×1 map (64 features per image).
    """
    stages = []
    for c_in, c_out in ((1, 32), (32, 64)):
        stages += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
    stages.append(nn.AdaptiveAvgPool2d((1, 1)))
    return nn.Sequential(*stages)

# ---------- full 2g×2g model ----------
class Full2gMixerNet(nn.Module):
    """
    "Big" baseline without AJ or torus features.

    Pipeline:
      conv trunk (1×28×28 → 64) →
      Linear(64 → 2g) →
      [optional ReLU] →
      Linear(2g → 2g) →
      Linear(2g → 10)

    The 2g-dimensional features are mixed by a full 2g×2g weight matrix.

    Args:
        genus: g; the hidden width is 2*g.
        use_relu: apply a ReLU between the two 2g-wide layers. When False the
            head after the conv trunk is purely linear.
    """
    def __init__(self, genus: int, use_relu: bool = False):
        super().__init__()
        self.genus = genus
        self.use_relu = use_relu
        width = 2 * genus

        # Creation order is kept (conv, to2g, mix2g, classifier) so the default
        # initialisers consume the RNG in the same sequence.
        self.conv = make_conv_trunk()
        self.to2g = nn.Linear(64, width)
        self.mix2g = nn.Linear(width, width)
        self.classifier = nn.Linear(width, 10)

        # Kaiming-uniform weights and zero biases for the two 2g-wide layers;
        # the classifier keeps PyTorch's default init.
        for layer in (self.to2g, self.mix2g):
            nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return class logits for a batch of images."""
        hidden = self.to2g(self.conv(x).view(x.size(0), -1))
        if self.use_relu:
            hidden = F.relu(hidden)
        return self.classifier(self.mix2g(hidden))

# ---------- helpers ----------
def count_params(m):
    """Number of elements in the parameters of `m` that require gradients."""
    n_trainable = 0
    for param in m.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable

def conv_params(m):
    """
    Trainable element count of the parameters whose qualified name starts
    with "conv" (i.e. a top-level attribute such as `m.conv`).
    """
    n_conv = 0
    for qual_name, param in m.named_parameters():
        if param.requires_grad and qual_name.startswith("conv"):
            n_conv += param.numel()
    return n_conv

def heads_only_params(m):
    """Trainable parameter count of `m` minus its "conv"-prefixed parameters."""
    return count_params(m) - conv_params(m)

if "train_epoch_plain" not in globals():
    def train_epoch_plain(model, loader, opt, clip=None):
        """
        Fallback definition, used only when no earlier cell has defined
        `train_epoch_plain`.

        Runs one plain cross-entropy training pass over `loader`, optionally
        clipping the total gradient norm to `clip`, and returns
        (mean training loss, training accuracy in percent).
        """
        model.train()
        criterion = nn.CrossEntropyLoss()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            opt.zero_grad(set_to_none=True)
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            if clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()

            loss_sum += batch_loss.item() * batch_size
            n_correct += (logits.argmax(1) == targets).sum().item()
            n_seen += batch_size
        return loss_sum / n_seen, 100.0 * n_correct / n_seen

if "eval_epoch_plain" not in globals():
    @torch.no_grad()
    def eval_epoch_plain(model, loader):
        """
        Fallback definition, used only when no earlier cell has defined
        `eval_epoch_plain`.

        Evaluates `model` on `loader` in eval mode without gradients and
        returns (mean cross-entropy loss, accuracy in percent).
        """
        model.eval()
        criterion = nn.CrossEntropyLoss()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            loss_sum += batch_loss.item() * batch_size
            n_correct += (logits.argmax(1) == targets).sum().item()
            n_seen += batch_size
        return loss_sum / n_seen, 100.0 * n_correct / n_seen

# ---------- instantiate & train (linear version by default) ----------
use_relu = False   # set True if you also want the ReLU variant

# Build the baseline for the current genus and a single-group AdamW optimizer.
full2g = Full2gMixerNet(genus, use_relu=use_relu).to(device)
opt_full2g = torch.optim.AdamW(full2g.parameters(), lr=3e-4, weight_decay=1e-4)

# Parameter accounting: total trainable, conv trunk only, and the remainder.
# NOTE: `total`, `conv` and `heads` are plain notebook globals shared with
# other cells that use the same names.
total = count_params(full2g)
conv  = conv_params(full2g)
heads = heads_only_params(full2g)
print(f"\nFull2gMixerNet (g={genus}, use_relu={use_relu}) params: {total:,}")
print(f"  conv trunk : {conv:,}")
print(f"  heads-only : {heads:,} (total minus conv trunk)")

EPOCHS = 8     # number of passes over train_loader
CLIP   = 1.0   # max gradient norm handed to train_epoch_plain

print("\nStarting training of full 2g×2g baseline...")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_plain(full2g, train_loader, opt_full2g, clip=CLIP)
    # The test set is evaluated after every epoch purely for monitoring.
    te_loss, te_acc = eval_epoch_plain(full2g, test_loader)
    print(f"[FULL 2g×2g g={genus}, relu={use_relu}] Ep {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")


Using device: cuda
Full 2g×2g baseline will use genus = 30

Full2gMixerNet (g=30, use_relu=False) params: 26,986
  conv trunk : 18,816
  heads-only : 8,170 (total minus conv trunk)

Starting training of full 2g×2g baseline...
[FULL 2g×2g g=30, relu=False] Ep 01 | train 1.6765 / 37.43% | test 1.2858 / 52.60%
[FULL 2g×2g g=30, relu=False] Ep 02 | train 1.1381 / 58.55% | test 0.9864 / 64.70%
[FULL 2g×2g g=30, relu=False] Ep 03 | train 0.9854 / 64.65% | test 0.8870 / 69.01%
[FULL 2g×2g g=30, relu=False] Ep 04 | train 0.9055 / 68.36% | test 0.8473 / 70.84%
[FULL 2g×2g g=30, relu=False] Ep 05 | train 0.8243 / 72.16% | test 0.7878 / 72.90%
[FULL 2g×2g g=30, relu=False] Ep 06 | train 0.7275 / 76.57% | test 0.6396 / 78.74%
[FULL 2g×2g g=30, relu=False] Ep 07 | train 0.6213 / 81.04% | test 0.5240 / 84.50%
[FULL 2g×2g g=30, relu=False] Ep 08 | train 0.5103 / 84.85% | test 0.4292 / 87.30%


In [None]:

# @title Full 2g×2g baseline: conv → 2g → 2g → 10 (no AJ):  now turn on relu

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---------- sanity ----------
if "genus" not in globals():
    raise RuntimeError("`genus` is not defined; run your setup cell first (e.g., genus = 30).")
if "train_loader" not in globals() or "test_loader" not in globals():
    raise RuntimeError("`train_loader` / `test_loader` missing; make sure your dataset loaders exist.")

# This cell is the ReLU variant of the linear Full 2g×2g baseline cell above.
# It reuses that cell's model class and helpers instead of carrying a verbatim
# second copy of them (duplicate definitions silently shadow the originals and
# can drift apart when only one copy is edited).
_missing = [
    _name for _name in ("Full2gMixerNet", "count_params", "conv_params",
                        "train_epoch_plain", "eval_epoch_plain")
    if _name not in globals()
]
if _missing:
    raise RuntimeError(
        f"Missing {_missing}; run the linear Full 2g×2g baseline cell above first."
    )

print(f"Full 2g×2g baseline will use genus = {genus}")
# Width of the 2g-dimensional layers (the model recomputes this from `genus`).
D2g = 2 * genus

# ---------- instantiate & train (ReLU variant) ----------
use_relu = True   # ReLU between the two 2g-wide layers

# NOTE: `full2g` / `opt_full2g` are rebound here, replacing the linear model
# trained in the previous cell.
full2g = Full2gMixerNet(genus, use_relu=use_relu).to(device)
opt_full2g = torch.optim.AdamW(full2g.parameters(), lr=3e-4, weight_decay=1e-4)

# Parameter accounting. The remainder is computed inline rather than through
# `heads_only_params`, because another cell defines a helper of that name with
# different assumptions about where the conv trunk lives.
total = count_params(full2g)
conv  = conv_params(full2g)
heads = total - conv
print(f"\nFull2gMixerNet (g={genus}, use_relu={use_relu}) params: {total:,}")
print(f"  conv trunk : {conv:,}")
print(f"  heads-only : {heads:,} (total minus conv trunk)")

EPOCHS = 8     # number of passes over train_loader
CLIP   = 1.0   # max gradient norm handed to train_epoch_plain

print("\nStarting training of full 2g×2g baseline...")
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_plain(full2g, train_loader, opt_full2g, clip=CLIP)
    # The test set is evaluated after every epoch purely for monitoring.
    te_loss, te_acc = eval_epoch_plain(full2g, test_loader)
    print(f"[FULL 2g×2g g={genus}, relu={use_relu}] Ep {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")



Using device: cuda
Full 2g×2g baseline will use genus = 30

Full2gMixerNet (g=30, use_relu=True) params: 26,986
  conv trunk : 18,816
  heads-only : 8,170 (total minus conv trunk)

Starting training of full 2g×2g baseline...
[FULL 2g×2g g=30, relu=True] Ep 01 | train 1.7115 / 35.73% | test 1.2891 / 50.89%
[FULL 2g×2g g=30, relu=True] Ep 02 | train 1.2209 / 54.13% | test 1.0996 / 59.92%
[FULL 2g×2g g=30, relu=True] Ep 03 | train 1.0737 / 61.08% | test 0.9570 / 66.27%
[FULL 2g×2g g=30, relu=True] Ep 04 | train 0.9104 / 69.00% | test 0.7928 / 73.39%
[FULL 2g×2g g=30, relu=True] Ep 05 | train 0.7682 / 74.64% | test 0.6654 / 78.48%
[FULL 2g×2g g=30, relu=True] Ep 06 | train 0.6521 / 78.90% | test 0.5723 / 81.22%
[FULL 2g×2g g=30, relu=True] Ep 07 | train 0.5691 / 81.76% | test 0.5146 / 83.17%
[FULL 2g×2g g=30, relu=True] Ep 08 | train 0.5064 / 84.00% | test 0.4534 / 85.70%
