# Faithful Abel–Jacobi (AJ) Training Notebook (Two Sheets)

This notebook updates your genus-30 AJ MNIST training pipeline to use the **faithful two-sheet** lookup tables:

- `I0(z)` = AJ integrals to the lift on sheet 0  
- `I1(z)` = AJ integrals to the lift on sheet 1  
- `B`     = constant bridge vector (used in table generation; the loader treats it as `I0 + I1`, so a missing `I1` is rebuilt as `B - I0`)

**Main change:** we replace the old “single-table + learned sign” sheet kludge with a **two-table sheet selector**:
- `sign = tanh(sheet_logits)` (keeps your initialization trick)
- `w0 = (1+sign)/2`  
- `I = w0 * I0 + (1-w0) * I1`

Everything else stays in the same spirit: anchored point initialization, per-channel normalization, a learnable global gain `gamma`, boundary/branch penalties, and axis-aligned Fourier features (used as mock periods).

**If you don’t have an ω-table** (`aj_omegas_genus30.pt`), the notebook will compute a fast **ω-magnitude proxy** from the branch points for anchor selection.

In [None]:
# @title Setup: installs, imports, Drive mount, paths
!pip install -q torch torchvision numpy

import os, math, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from google.colab import drive

# Prefer the GPU when Colab provides one; everything below also runs on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# ===========================
# Drive + Paths (match lookup-table notebook)
# ===========================
DRIVE_FOLDER = "AJ_Tables_g30"   # <-- change if you used a different folder
drive.mount('/content/drive', force_remount=True)
SAVE_DIR = f"/content/drive/MyDrive/{DRIVE_FOLDER}"

# Faithful integrals (I0/I1/B) file:
INTEGRALS_PATH = os.path.join(SAVE_DIR, "aj_integrals_genus30.pt")

# Optional omegas file (only used to compute ω-aware anchors):
OMEGAS_PATH    = os.path.join(SAVE_DIR, "aj_omegas_genus30.pt")

print("SAVE_DIR        :", SAVE_DIR)
print("INTEGRALS_PATH  :", INTEGRALS_PATH)
print("OMEGAS_PATH     :", OMEGAS_PATH)

# Fail early with the offending path; an `assert` would be stripped under -O
# and would not say which file is missing.
if not os.path.isfile(INTEGRALS_PATH):
    raise FileNotFoundError(f"Integrals file not found in Drive: {INTEGRALS_PATH}")

In [None]:
# @title Load faithful tables: I0, I1, B (+ geometry)
# NOTE: weights_only=False unpickles arbitrary Python objects; only load table
# files you generated yourself.
ints = torch.load(INTEGRALS_PATH, map_location='cpu', weights_only=False)


def _as_cfloat(table):
    """Return *table* as a complex64 torch tensor (accepts tensors or array-likes)."""
    if not torch.is_tensor(table):
        table = torch.tensor(table)
    return table.to(torch.cfloat)


# Required metadata
grid_r_np  = np.array(ints["grid_r"])
grid_i_np  = np.array(ints["grid_i"])
branch_pts = np.array(ints["branch_pts"])
branch_cuts = ints.get("branch_cuts", None)

# Faithful integrals (older files may use the I_plus / I_minus key names)
I0 = ints.get("I0", ints.get("I_plus", None))
I1 = ints.get("I1", ints.get("I_minus", None))
B  = ints.get("B", None)

if I0 is None:
    raise KeyError("Expected I0 (faithful) or at least I_plus in integrals file.")

# Ensure torch tensors and complex dtype
I0 = _as_cfloat(I0)
if B is not None:
    B = _as_cfloat(B)

if I1 is None:
    # Second sheet is reconstructed from the bridge vector: I1 = B - I0.
    if B is None:
        raise KeyError("I1 missing and B missing; cannot construct second sheet.")
    I1 = B[:, None, None] - I0
else:
    I1 = _as_cfloat(I1)

if I1.shape != I0.shape:
    raise ValueError(f"I0 and I1 must have the same shape, got {tuple(I0.shape)} vs {tuple(I1.shape)}.")

# Ensure B exists (estimate if missing)
if B is None:
    B = (I0 + I1).mean(dim=(1,2))
    print("B not found in file → estimated B as mean(I0+I1) over grid.")

genus = int(ints.get("genus", I0.shape[0]))
H, W = I0.shape[-2:]

# Later cells index the tables as [iy, ix] with grid_i along rows and grid_r
# along columns, so the axes must match the table's spatial shape.
if grid_r_np.shape != (W,) or grid_i_np.shape != (H,):
    raise ValueError(
        f"Grid axes {grid_r_np.shape}/{grid_i_np.shape} do not match table grid ({H}x{W})."
    )

# Torch versions of axes / branch points
grid_r = torch.tensor(grid_r_np, dtype=torch.float32)
grid_i = torch.tensor(grid_i_np, dtype=torch.float32)
branch_pts_t = torch.tensor(branch_pts)
if not branch_pts_t.is_complex():
    # Downstream code reads .real/.imag; .imag is undefined on real tensors.
    branch_pts_t = branch_pts_t.to(torch.complex128)

print(f"Loaded: genus={genus}, grid=({H}x{W})")
print("Integrals keys present:", [k for k in ["I0","I1","B","I_plus","I_minus","sheet_parity","branch_cuts"] if k in ints])

In [None]:
# @title Optional: load ω-table (if you generated/saved it)
# Om_plus stays None when no usable ω-table is available; the anchor cell then
# falls back to a proxy score computed from the branch points.
Om_plus = None
if os.path.exists(OMEGAS_PATH):
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # files you generated yourself.
    omeg = torch.load(OMEGAS_PATH, map_location='cpu', weights_only=False)
    Om_plus = omeg.get("omega_plus", None)
    if Om_plus is not None:
        # Accept tensors and array-likes alike, matching how I0/I1 are loaded.
        Om_plus = torch.as_tensor(Om_plus).to(torch.cfloat)
        print("Loaded omega_plus:", tuple(Om_plus.shape))
    else:
        print("OMEGAS_PATH exists but no key 'omega_plus' found; will use proxy ω-map.")
else:
    print("No omegas file found; will use proxy ω-map from branch points.")

In [None]:
# @title Compute ω-aware anchors and AJ normalization (mu/sigma)
import numpy as np
import torch

# -----------------------------
# 1) ω-strength map for anchors
# -----------------------------
# Only the ranking of the map matters below (quantile + argmax), so the
# linear-scale ω sum and the log-scale proxy are interchangeable here.
with torch.no_grad():
    if Om_plus is not None:
        Wmap = Om_plus.abs().sum(dim=0)          # (H,W)
        Wmap_np = Wmap.cpu().numpy()
        print("Anchor map: using ω-table (sum_k |ω_k|).")
    else:
        # Proxy (magnitude-only): score(z) = log sum_k |z|^k / |sqrt(P(z))|
        # Enough for choosing anchors (we only need a monotone proxy).
        print("Anchor map: using proxy score from branch points (no ω-table).")

        X = torch.tensor(grid_r_np[None, :].repeat(H, axis=0), dtype=torch.float64)
        Y = torch.tensor(grid_i_np[:, None].repeat(W, axis=1), dtype=torch.float64)

        bp = torch.tensor(branch_pts, dtype=torch.complex128)
        bp_real = bp.real.view(1,1,-1).to(torch.float64)
        bp_imag = bp.imag.view(1,1,-1).to(torch.float64)

        # log|P(z)| = sum_j log|z - e_j|, clamped so grid points that coincide
        # with a branch point do not produce -inf.
        dx = X.unsqueeze(-1) - bp_real
        dy = Y.unsqueeze(-1) - bp_imag
        d  = torch.sqrt(dx*dx + dy*dy).clamp_min(1e-12)
        log_absP = torch.log(d).sum(dim=-1)
        log_abs_sqrtP = 0.5 * log_absP

        absZ = torch.sqrt(X*X + Y*Y).clamp_min(1e-12)
        log_absZ = torch.log(absZ)

        # log sum_k exp(k*log|z| - log|sqrt(P)|) for k = 0..genus-1
        k = torch.arange(genus, dtype=torch.float64).view(1,1,-1)
        terms = k * log_absZ.unsqueeze(-1) - log_abs_sqrtP.unsqueeze(-1)
        Wmap_log = torch.logsumexp(terms, dim=-1)    # (H,W)
        Wmap_np = Wmap_log.cpu().numpy()

# -----------------------------
# 2) Mask edges + avoid branch points
# -----------------------------
gy = np.linspace(-1, 1, H)[:, None]
gx = np.linspace(-1, 1, W)[None, :]
edge_mask = (np.abs(gx) < 0.92) & (np.abs(gy) < 0.92)

grid_x = grid_r_np[None, :].repeat(H, axis=0)
grid_y = grid_i_np[:, None].repeat(W, axis=1)

bp_real = np.real(branch_pts).reshape(-1, 1, 1)
bp_imag = np.imag(branch_pts).reshape(-1, 1, 1)
d2 = (grid_x - bp_real)**2 + (grid_y - bp_imag)**2
bp_mask = (d2.min(axis=0) > 0.25)     # keep points with dist > ~0.5

mask = edge_mask & bp_mask

# -----------------------------
# 3) Candidate set = top quantile by ω-strength
# -----------------------------
score = np.where(mask, Wmap_np, -np.inf)
valid = score[score > -np.inf]
if valid.size == 0:
    raise ValueError("Mask removed all points; relax edge/bp masks.")

# Relax the quantile in 0.05 steps until there are at least `genus` candidates.
q = 0.85
cand = np.argwhere(score >= np.quantile(valid, q))
while cand.shape[0] < genus and q > 0.50:
    q -= 0.05
    cand = np.argwhere(score >= np.quantile(valid, q))

if cand.shape[0] < genus:
    # Farthest-point sampling repeats points once the candidate pool is
    # exhausted, so widen the pool to every unmasked grid point first.
    q = 0.0
    cand = np.argwhere(score > -np.inf)
    if cand.shape[0] < genus:
        print(f"WARNING: only {cand.shape[0]} usable grid points for {genus} anchors; "
              "some anchors will coincide. Relax the edge/bp masks.")

print(f"Anchor candidates: {cand.shape[0]} points (quantile={q:.2f})")

# -----------------------------
# 4) Farthest-point sampling for diverse anchors
# -----------------------------
def farthest_k(points_hw, k):
    """Pick k mutually distant candidates by greedy farthest-point sampling.

    points_hw holds (iy, ix) grid indices, one per row. The first pick is the
    candidate with the highest value in the notebook-level `score` map; each
    further pick is the candidate whose squared distance (in `grid_x`/`grid_y`
    coordinates) to the already-chosen set is largest.

    Returns an array of the chosen (iy, ix) rows, in selection order.
    """
    pts = points_hw.copy()
    rows, cols = pts[:, 0], pts[:, 1]

    # Highest-scoring candidate seeds the selection.
    seed = np.argmax(score[rows, cols])
    picked = [pts[seed]]
    if k == 1:
        return np.array(picked)

    # Physical (x, y) position of every candidate, shape (N, 2).
    xy = np.column_stack((grid_x[rows, cols], grid_y[rows, cols]))

    # Squared distance from each candidate to its nearest chosen point.
    nearest_sq = ((xy - xy[seed][None, :]) ** 2).sum(axis=1)

    for _ in range(k - 1):
        nxt = np.argmax(nearest_sq)
        picked.append(pts[nxt])
        to_new = ((xy - xy[nxt][None, :]) ** 2).sum(axis=1)
        nearest_sq = np.minimum(nearest_sq, to_new)
    return np.array(picked)

# Chosen anchors as grid indices, then as physical (x, y) coordinates.
anchors_hw = farthest_k(cand, genus)  # (g,2) with (iy,ix)
y0 = grid_i_np[anchors_hw[:, 0]]      # rows index the imaginary axis
x0 = grid_r_np[anchors_hw[:, 1]]      # columns index the real axis
anchors_xy = np.column_stack((x0, y0))           # (g,2)
anchors_xy_t = torch.tensor(anchors_xy, dtype=torch.float32)

print("Example anchors (first 8):")
for i in range(min(8, genus)):
    print(f"  {i:02d}: x0={anchors_xy[i,0]:+.3f}, y0={anchors_xy[i,1]:+.3f}")

# -----------------------------
# 5) AJ normalization stats (use BOTH sheets)
# -----------------------------
# Channels are laid out as [Re_0..Re_{g-1}, Im_0..Im_{g-1}] for each sheet.
I0_ch = torch.cat((I0.real, I0.imag))          # (2g,H,W)
I1_ch = torch.cat((I1.real, I1.imag))          # (2g,H,W)
I_cat = torch.stack((I0_ch, I1_ch))            # (2,2g,H,W)

# Per-channel statistics pooled over both sheets and the whole grid.
mu = I_cat.mean(dim=(0,2,3))                                  # (2g,)
sigma = torch.clamp(I_cat.std(dim=(0,2,3)), min=1e-6)         # (2g,)

mu_t, sigma_t = mu.float(), sigma.float()

print("mu/sigma computed from both sheets.")

In [None]:
# @title AJ core: faithful two-sheet lookup + normalization + penalties
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def _pack_complex_table(table_gHW: torch.Tensor) -> torch.Tensor:
    # Input:  (g,H,W) complex
    # Output: (1,2g,H,W) real with channels [Re..., Im...]
    re, im = table_gHW.real, table_gHW.imag
    return torch.cat([re, im], dim=0).unsqueeze(0).contiguous()

class AJGridActivationNorm(nn.Module):
    """Two-sheet Abel–Jacobi lookup with normalization and geometric penalties.

    Each of the g points is squashed into the table's rectangle, both sheet
    tables are sampled bilinearly at that location, and the two samples are
    blended with w0 = (1 + tanh(sheet_logit)) / 2. The blended values are
    standardized per channel, summed over the g points and scaled by the
    learnable gain `gamma`.

    Args:
        I0, I1: complex (g,H,W) tables for sheet 0 and sheet 1.
        grid_r, grid_i: real/imaginary axis values; only their min/max are used.
        branch_pts: complex branch points used by the repulsion penalty.
        mu, sigma: per-channel (2g,) standardization statistics.
        branch_cuts: optional iterable of (a, b) complex segment endpoints;
            when given, a cut-proximity penalty is also reported.
    """

    def __init__(self,
                 I0: torch.Tensor,
                 I1: torch.Tensor,
                 grid_r: torch.Tensor,
                 grid_i: torch.Tensor,
                 branch_pts: torch.Tensor,
                 mu: torch.Tensor,
                 sigma: torch.Tensor,
                 branch_cuts=None):
        super().__init__()
        self.g = I0.shape[0]

        # Two sheets, packed as real channels [Re..., Im...]
        self.register_buffer("I0", _pack_complex_table(I0))   # (1,2g,H,W)
        self.register_buffer("I1", _pack_complex_table(I1))   # (1,2g,H,W)

        # stats
        self.register_buffer("mu",    mu.view(1,1,-1))        # (1,1,2g)
        self.register_buffer("sigma", sigma.view(1,1,-1))     # (1,1,2g)
        self.gamma = nn.Parameter(torch.tensor(1.0))          # learnable gain

        # bounds
        self.register_buffer("r_min", torch.tensor(float(grid_r.min())))
        self.register_buffer("r_max", torch.tensor(float(grid_r.max())))
        self.register_buffer("i_min", torch.tensor(float(grid_i.min())))
        self.register_buffer("i_max", torch.tensor(float(grid_i.max())))

        # branch points (repulsion)
        self.register_buffer("bp_real", branch_pts.real.float())
        self.register_buffer("bp_imag", branch_pts.imag.float())

        # optional cut penalty: each cut is a segment from a to b
        self.has_cuts = (branch_cuts is not None)
        if self.has_cuts:
            a = np.array([complex(ab[0]) for ab in branch_cuts], dtype=np.complex128)
            b = np.array([complex(ab[1]) for ab in branch_cuts], dtype=np.complex128)
            self.register_buffer("cut_ax", torch.tensor(a.real, dtype=torch.float32))
            self.register_buffer("cut_ay", torch.tensor(a.imag, dtype=torch.float32))
            self.register_buffer("cut_bx", torch.tensor(b.real, dtype=torch.float32))
            self.register_buffer("cut_by", torch.tensor(b.imag, dtype=torch.float32))

    def _map_raw_to_bounds(self, raw_xy):
        """Squash unconstrained (…,2) values into the table rectangle via sigmoid."""
        xr = self.r_min + (self.r_max - self.r_min) * torch.sigmoid(raw_xy[..., 0])
        yi = self.i_min + (self.i_max - self.i_min) * torch.sigmoid(raw_xy[..., 1])
        return xr, yi

    def _norm_to_grid(self, xr, yi):
        """Map physical coordinates to grid_sample's [-1, 1] convention."""
        gx = 2.0 * (xr - self.r_min) / (self.r_max - self.r_min) - 1.0
        gy = 2.0 * (yi - self.i_min) / (self.i_max - self.i_min) - 1.0
        return gx, gy

    def _cut_penalty(self, xr, yi, tau_cut=0.07):
        """Mean Gaussian bump of each point's distance to its nearest cut segment."""
        px = xr.unsqueeze(-1)  # (B,g,1)
        py = yi.unsqueeze(-1)
        ax = self.cut_ax.view(1,1,-1)
        ay = self.cut_ay.view(1,1,-1)
        bx = self.cut_bx.view(1,1,-1)
        by = self.cut_by.view(1,1,-1)

        # Project each point onto each segment and clamp to the endpoints.
        vx = bx - ax
        vy = by - ay
        wx = px - ax
        wy = py - ay
        vv = (vx*vx + vy*vy).clamp_min(1e-12)   # guards zero-length cuts
        t = ((wx*vx + wy*vy) / vv).clamp(0.0, 1.0)
        cx = ax + t*vx
        cy = ay + t*vy
        d2seg = (px - cx)**2 + (py - cy)**2
        d2min = d2seg.min(dim=-1).values
        return torch.exp(-d2min / (2*tau_cut*tau_cut)).mean()

    def forward(self, raw_xy: torch.Tensor, sheet_logits: torch.Tensor, return_aux=True):
        """Evaluate the blended, standardized AJ coordinates.

        Args:
            raw_xy: (B,g,2) unconstrained point coordinates.
            sheet_logits: (B,g) sheet selector logits.
            return_aux: when False, penalties are skipped and aux is None.

        Returns:
            (coords, aux): coords is (B,2g); aux is a dict with the physical
            and normalized point positions, the three penalties and the sheet
            sign/weight, or None.

        Raises:
            ValueError: if raw_xy does not carry exactly g points.
        """
        B, g, _ = raw_xy.shape
        if g != self.g:
            raise ValueError(f"Expected {self.g} points per sample, got {g}.")

        # Under autocast the upstream Linear emits fp16; do the geometry in
        # fp32 so sample positions and penalties are not quantized.
        raw_xy = raw_xy.float()
        sheet_logits = sheet_logits.float()

        xr, yi = self._map_raw_to_bounds(raw_xy)
        gx, gy = self._norm_to_grid(xr, yi)
        grid = torch.stack([gx, gy], dim=-1).view(B*g, 1, 1, 2)

        # Sample both sheets at one location per (sample, point) pair
        I0_s = F.grid_sample(self.I0.expand(B*g, -1, -1, -1),
                             grid, mode="bilinear", align_corners=True).view(B, g, -1)
        I1_s = F.grid_sample(self.I1.expand(B*g, -1, -1, -1),
                             grid, mode="bilinear", align_corners=True).view(B, g, -1)

        # Sheet selector: sign=tanh(logit), w0=(1+sign)/2
        sign = torch.tanh(sheet_logits).unsqueeze(-1)   # (B,g,1)
        w0   = 0.5 * (1.0 + sign)                       # (B,g,1)
        I    = w0 * I0_s + (1.0 - w0) * I1_s            # (B,g,2g)

        # Standardize and sum over points
        I_std  = (I - self.mu) / self.sigma
        coords = self.gamma * I_std.sum(dim=1)          # (B,2g)

        aux = None
        if return_aux:
            # boundary penalty in normalized coords: active beyond |g| = 0.95
            margin = 0.95
            bpen = ((gx.abs() - margin).clamp_min(0)**2 +
                    (gy.abs() - margin).clamp_min(0)**2).mean()

            # branch-point repulsion in physical coords (Gaussian bumps)
            dx = xr.unsqueeze(-1) - self.bp_real
            dy = yi.unsqueeze(-1) - self.bp_imag
            d2 = dx*dx + dy*dy
            tau = 0.07
            rpen = torch.exp(-d2 / (2*tau*tau)).mean()

            # optional: cut-distance penalty
            if self.has_cuts:
                cpen = self._cut_penalty(xr, yi)
            else:
                cpen = torch.zeros((), device=xr.device)

            aux = {
                "x": xr, "y": yi, "gx": gx, "gy": gy,
                "bound_penalty": bpen,
                "branch_penalty": rpen,
                "cut_penalty": cpen,
                "sheet_sign": sign.squeeze(-1),
                "sheet_w0": w0.squeeze(-1),
            }

        return coords, aux

In [None]:
# @title AJMNIST_Anchored + axis-periodic head (structure preserved)
import torch
import torch.nn as nn
import torch.nn.functional as F

class AJMNIST_Anchored(nn.Module):
    """MNIST classifier that routes CNN features through the two-sheet AJ layer.

    A small CNN produces a 64-d image feature. For each of the `genus` points
    it is concatenated with a per-point embedding and mapped by a shared linear
    head plus a per-point bias to (raw_x, raw_y, sheet_logit). The AJ layer
    turns these into 2*genus coordinates, which a linear classifier maps to 10
    logits.

    Args:
        genus: number of points (must equal the first dim of I0/I1).
        I0, I1, grid_r, grid_i, branch_pts, mu, sigma, branch_cuts: forwarded
            to AJGridActivationNorm.
        anchors_xy: (>=genus, 2) physical (x, y) positions used to initialize
            the per-point biases so each point starts at its anchor.
        embed_dim: size of the per-point embedding.

    Raises:
        ValueError: if anchors_xy provides fewer than `genus` (x, y) rows.
    """

    def __init__(self, genus, I0, I1, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma, embed_dim=8, branch_cuts=None):
        super().__init__()
        self.genus = genus

        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1,1))
        )

        self.embed = nn.Parameter(torch.empty(genus, embed_dim))
        nn.init.uniform_(self.embed, -2.0, 2.0)

        # shared head (no bias) + per-point bias
        self.point_head = nn.Linear(64 + embed_dim, 3, bias=False)
        nn.init.xavier_uniform_(self.point_head.weight)
        self.point_bias = nn.Parameter(torch.zeros(genus, 3))

        # Faithful AJ activation
        self.aj = AJGridActivationNorm(I0, I1, grid_r, grid_i, branch_pts, mu, sigma, branch_cuts=branch_cuts)
        self.classifier = nn.Linear(2*genus, 10)

        # ω-aware initialization for biases: the AJ layer maps raw -> position
        # with a sigmoid, so the bias is the logit of the anchor's fractional
        # position inside the grid rectangle.
        anchors = torch.as_tensor(anchors_xy).detach().to(device="cpu", dtype=torch.float64)
        if anchors.ndim != 2 or anchors.shape[0] < genus or anchors.shape[1] < 2:
            raise ValueError(
                f"anchors_xy must provide at least {genus} (x, y) rows, got shape {tuple(anchors.shape)}."
            )

        lo = torch.tensor([float(grid_r.min()), float(grid_i.min())], dtype=torch.float64)
        hi = torch.tensor([float(grid_r.max()), float(grid_i.max())], dtype=torch.float64)

        # Clamp away from 0/1 so an anchor on the grid boundary yields a finite
        # logit instead of ±inf.
        eps = 1e-4
        frac = ((anchors[:genus, :2] - lo) / (hi - lo)).clamp(eps, 1.0 - eps)

        with torch.no_grad():
            self.point_bias[:, :2] = torch.log(frac / (1.0 - frac)).to(self.point_bias.dtype)

            # Near-binary sheet init: sign=tanh(logit) ≈ ±0.8 ⇒ w0≈0.9 or 0.1,
            # alternating sheets between even and odd point indices.
            sheet_logit = torch.atanh(torch.tensor(0.8))
            self.point_bias[0::2, 2] = sheet_logit
            self.point_bias[1::2, 2] = -sheet_logit

    def forward(self, x, return_aux=False):
        """Return class logits, plus the AJ layer's aux dict when return_aux is True."""
        B = x.size(0)
        h = self.conv(x).view(B, -1)
        h_exp = h.unsqueeze(1).expand(-1, self.genus, -1)
        emb   = self.embed.unsqueeze(0).expand(B, -1, -1)

        out   = self.point_head(torch.cat([h_exp, emb], dim=2)) + self.point_bias.unsqueeze(0)
        raw_xy, sheet_logits = out[..., :2], out[..., 2]
        # Penalties are only computed when the caller actually wants them.
        coords, aux = self.aj(raw_xy, sheet_logits, return_aux=return_aux)
        logits = self.classifier(coords)
        if return_aux:
            return logits, aux
        return logits

# Axis-aligned Fourier torus features (mock periods)
class TorusFeatures(nn.Module):
    """Axis-aligned Fourier features: cos/sin of each coordinate at K frequencies.

    Args:
        dim: number of input coordinates D (stored for reference).
        K: number of frequencies. With the default frequencies this yields
            0.5, 1.0, 1.5, ... (0.5 * k for k = 1..K).
        freqs: optional explicit frequencies; at most the first K are used.
        learnable: if True the frequencies are a trainable parameter,
            otherwise a buffer.

    `self.K` is the number of frequencies actually in use; the output width is
    D * 2 * self.K.
    """

    def __init__(self, dim: int, K: int = 2, freqs=None, learnable: bool = False):
        super().__init__()
        if freqs is None:
            # Generated rather than sliced from a fixed two-entry list, so any
            # K yields exactly K frequencies (K <= 2 gives the same [0.5, 1.0]).
            freqs = 0.5 * torch.arange(1, K + 1, dtype=torch.float32)
        else:
            # clone() so a learnable parameter never aliases the caller's tensor
            freqs = torch.as_tensor(freqs, dtype=torch.float32)[:K].clone()
        if learnable:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs)
        self.dim, self.K = dim, len(freqs)

    def forward(self, u):
        """Map (B, D) coordinates to (B, D*2*K) features.

        Per coordinate the layout is [cos(f_1 u)..cos(f_K u), sin(f_1 u)..sin(f_K u)].
        """
        B, D = u.shape
        f = self.freqs.view(1,1,-1).to(u.device)
        ang = u.unsqueeze(-1) * f
        return torch.cat([torch.cos(ang), torch.sin(ang)], dim=-1).view(B, D*2*self.K)

class AJMNIST_AxisPeriodic(nn.Module):
    """AJMNIST_Anchored backbone with a Fourier (torus) feature head.

    Reuses the base model's CNN, embeddings, point head and AJ layer, then
    expands the 2*genus AJ coordinates with TorusFeatures before the final
    linear classifier. The base model's own `classifier` is not used by this
    forward pass.

    Args:
        K: number of Fourier frequencies per coordinate.
        learnable_freqs: whether the frequencies are trainable.
        (remaining arguments are forwarded to AJMNIST_Anchored)
    """

    def __init__(self, genus, I0, I1, grid_r, grid_i, branch_pts,
                 anchors_xy, mu, sigma, embed_dim=8, K=2, learnable_freqs=False, branch_cuts=None):
        super().__init__()
        self.base = AJMNIST_Anchored(genus, I0, I1, grid_r, grid_i, branch_pts,
                                     anchors_xy, mu, sigma, embed_dim=embed_dim, branch_cuts=branch_cuts)
        D = 2*genus
        self.torus = TorusFeatures(D, K=K, learnable=learnable_freqs)
        # Size the head from the number of frequencies the torus layer really
        # has, so the two can never disagree.
        self.classifier = nn.Linear(D*2*self.torus.K, 10)

    @property
    def genus(self): return self.base.genus

    def forward(self, x, return_aux=False):
        """Return class logits, plus the AJ layer's aux dict when return_aux is True."""
        B = x.size(0)
        h = self.base.conv(x).view(B, -1)
        h_exp = h.unsqueeze(1).expand(-1, self.genus, -1)
        emb   = self.base.embed.unsqueeze(0).expand(B, -1, -1)
        out   = self.base.point_head(torch.cat([h_exp, emb], dim=2)) + self.base.point_bias.unsqueeze(0)
        raw_xy, sheet_logits = out[..., :2], out[..., 2]

        # Penalties are only computed when the caller actually wants them.
        coords, aux = self.base.aj(raw_xy, sheet_logits, return_aux=return_aux)
        feats = self.torus(coords)
        logits = self.classifier(feats)
        if return_aux:
            return logits, aux
        return logits

In [None]:
# @title MNIST data loaders
# Standard MNIST preprocessing: to tensor, then normalize with the usual
# dataset mean/std constants.
tfm = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

# Train and test splits share the root directory and the transform.
train_ds, test_ds = (
    torchvision.datasets.MNIST(root="/content/data", train=is_train, download=True, transform=tfm)
    for is_train in (True, False)
)

# Only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True
)

print("Train batches:", len(train_loader), " Test batches:", len(test_loader))

In [None]:
# @title AMP-aware training helpers
from torch.utils.data import DataLoader

# Shared classification loss used by both the train and eval helpers.
ce = nn.CrossEntropyLoss()

# Mixed precision is only enabled on CUDA; on CPU the scaler is a no-op.
USE_AMP = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

def count_params(m: nn.Module) -> int:
    """Return the total number of elements across m's trainable parameters."""
    total = 0
    for p in m.parameters():
        if p.requires_grad:
            total += p.numel()
    return total

@torch.no_grad()
def eval_epoch(model: nn.Module, loader):
    """Evaluate model over loader; return (mean cross-entropy, accuracy in %).

    Puts the model in eval mode and uses the notebook-level `device` and `ce`.
    The loss is averaged per sample, weighting each batch by its size.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        batch = x.size(0)
        loss_sum += ce(logits, y).item() * batch
        hits += (logits.argmax(1) == y).sum().item()
        seen += batch
    return loss_sum / seen, 100.0 * hits / seen

def train_epoch_amp(model: nn.Module, loader, opt,
                    clip: float = 1.0,
                    lam_branch: float = 1e-3,
                    lam_bound:  float = 1e-3,
                    lam_cut:    float = 0.0):
    """Train for one epoch with AMP; return (mean loss, accuracy in %).

    The loss is cross-entropy plus the weighted branch, bound and cut
    penalties reported in the model's aux dict. Uses the notebook-level
    `device`, `USE_AMP`, `scaler` and `ce`.

    Args:
        model: called as model(x, return_aux=True) -> (logits, aux dict).
        loader: iterable of (x, y) batches.
        opt: optimizer stepped through the grad scaler.
        clip: max gradient norm, or None to disable clipping.
        lam_branch, lam_bound, lam_cut: penalty weights.
    """
    model.train()
    tot, correct, n = 0.0, 0, 0

    # scaler.unscale_(opt) only rescales gradients of parameters the optimizer
    # owns, so the norm must be taken over exactly those tensors; a parameter
    # outside the optimizer would contribute a still-scaled gradient.
    opt_params = [p for group in opt.param_groups for p in group["params"]]

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        # Also clear parameters that are not in the optimizer; otherwise their
        # gradients accumulate across every step of the run.
        model.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits, aux = model(x, return_aux=True)
            loss = ce(logits, y)
            loss = (loss
                    + lam_branch * aux.get("branch_penalty", 0.0)
                    + lam_bound  * aux.get("bound_penalty",  0.0)
                    + lam_cut    * aux.get("cut_penalty",    0.0))

        scaler.scale(loss).backward()
        if clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(opt_params, clip)
        scaler.step(opt)
        scaler.update()

        tot += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        n += x.size(0)

    return tot/n, 100.0*correct/n

In [None]:
# @title Train faithful axis-periodic AJ model (GPU + AMP)
# Move tables/stats to device
I0_dev = I0.to(device)
I1_dev = I1.to(device)
grid_r_dev = grid_r.to(device)
grid_i_dev = grid_i.to(device)
branch_pts_dev = branch_pts_t.to(device)
anchors_xy_dev = anchors_xy_t.to(device)
mu_dev = mu_t.to(device)
sigma_dev = sigma_t.to(device)

# Smaller-batch loaders (OOM-safe)
BASE_BS = 64
train_loader_small = DataLoader(train_ds, batch_size=BASE_BS, shuffle=True,
                                num_workers=2, pin_memory=True, drop_last=True)
test_loader_small  = DataLoader(test_ds,  batch_size=256, shuffle=False,
                                num_workers=2, pin_memory=True)

# Instantiate model
K = 2
EMBED_DIM = 4
LEARN_FREQS = False

aj_axis = AJMNIST_AxisPeriodic(
    genus,
    I0_dev, I1_dev,
    grid_r_dev, grid_i_dev,
    branch_pts_dev,
    anchors_xy_dev,
    mu_dev, sigma_dev,
    embed_dim=EMBED_DIM,
    K=K,
    learnable_freqs=LEARN_FREQS,
    branch_cuts=branch_cuts
).to(device)

print(f"AJ axis-periodic faithful (g={genus}, K={K}) params: {count_params(aj_axis):,}")

# Two-tier LR (same spirit as your existing code)
# The per-point embeddings feed the point head, so they train with the fast
# group; leaving them out of the optimizer would freeze them at their random
# initialization. aj_axis.base.classifier is deliberately omitted: the
# axis-periodic forward pass never calls it.
fast_params = (
    list(aj_axis.base.point_head.parameters())
    + [aj_axis.base.point_bias, aj_axis.base.embed]
    + list(aj_axis.torus.parameters())
    + list(aj_axis.classifier.parameters())
)

base_params = (
    list(aj_axis.base.conv.parameters())
    + list(aj_axis.base.aj.parameters())
)

opt = torch.optim.AdamW(
    [
        {"params": base_params, "lr": 3e-4},
        {"params": fast_params, "lr": 1e-3},
    ],
    weight_decay=1e-4
)

# Train
EPOCHS     = 8
CLIP_NORM  = 1.0
LAM_BRANCH = 1e-3
LAM_BOUND  = 1e-3
LAM_CUT    = 0.0   # keep 0 by default; set ~1e-4..1e-3 if you want seam avoidance

print("Starting training | AMP =", USE_AMP)
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch_amp(
        aj_axis, train_loader_small, opt,
        clip=CLIP_NORM,
        lam_branch=LAM_BRANCH,
        lam_bound=LAM_BOUND,
        lam_cut=LAM_CUT
    )
    te_loss, te_acc = eval_epoch(aj_axis, test_loader_small)
    print(f"[Faithful AJ axis K={K}] Epoch {ep:02d} | "
          f"train {tr_loss:.4f} / {tr_acc:.2f}% | "
          f"test {te_loss:.4f} / {te_acc:.2f}%")

In [None]:
# @title Quick diagnostics: sheet usage + penalties
@torch.no_grad()
def batch_diag(model, loader, n_batches=3):
    """Print mean sheet weight and penalty values over the first few batches.

    Runs the model in eval mode on up to n_batches batches from loader (on the
    notebook-level `device`), then restores the model's previous train/eval
    mode so a diagnostic call does not change later behavior.
    """
    was_training = model.training
    model.eval()
    sheets = []
    bpen = []
    rpen = []
    cpen = []
    try:
        for i, (x, _) in enumerate(loader):
            if i >= n_batches:
                break
            x = x.to(device)
            _, aux = model(x, return_aux=True)
            sheets.append(aux["sheet_w0"].mean().item())
            bpen.append(aux["bound_penalty"].item())
            rpen.append(aux["branch_penalty"].item())
            cpen.append(aux.get("cut_penalty", torch.tensor(0.0)).item())
    finally:
        model.train(was_training)

    # Report the number of batches actually seen; the loader may be shorter
    # than n_batches.
    seen = len(sheets)
    if seen == 0:
        print("No batches evaluated; nothing to report.")
        return
    print(f"Mean sheet w0 over {seen} batches: {np.mean(sheets):.3f} (0→sheet1, 1→sheet0)")
    print(f"Bound penalty:  {np.mean(bpen):.6f}")
    print(f"Branch penalty: {np.mean(rpen):.6f}")
    print(f"Cut penalty:    {np.mean(cpen):.6f}")

batch_diag(aj_axis, train_loader_small, n_batches=3)