
# Faithful Abel–Jacobi lookup tables (two-sheet, cut-aware)

This notebook generates lookup tables for a genus **g** hyperelliptic curve

\[
w^2 = \prod_{i=1}^{2g+2}(z-a_i),
\qquad \omega_k = \frac{z^k\,dz}{w},\; k=0,\dots,g-1,
\]

and precomputes Abel–Jacobi integrals on a rectangular grid:

- **Sheet 0** integrals: `I0[k, iy, ix]`
- **Sheet 1** integrals: `I1[k, iy, ix]`
- A constant **bridge vector** `B` such that **`I0 + I1 = B`** everywhere.

Key points:

- We use the **straight segment** from `base_point → z` in the base plane.
- We detect crossings with the user-chosen **branch cuts** and **flip the sheet** each time (i.e. the integrand changes sign).
- The straight segment ends on either sheet depending on the parity of cut crossings. Using the bridge `B`, we convert this into **tables for both sheets**.
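
A sketch of why a single constant bridge suffices: the hyperelliptic involution \(\iota:(z,w)\mapsto(z,-w)\) satisfies \(\iota^*\omega_k=-\omega_k\). Write \(b\) for the base point on sheet 0, \(Q\) for the point over \(z\) on sheet 0, and \(\Lambda\) for the period lattice (AJ integrals depend on the path only up to \(\Lambda\)). Then

\[
I_1(z)=\int_b^{\iota Q}\omega
=\int_b^{\iota b}\omega+\int_{\iota b}^{\iota Q}\omega
=\int_b^{\iota b}\omega-\int_b^{Q}\omega
\equiv B-I_0(z)\pmod{\Lambda},
\qquad
B=\int_b^{\iota b}\omega=2\int_b^{e}\omega,
\]

where \(e\) is any branch point: one path from \(b\) to \(\iota b\) runs to \(e\) on sheet 0 and returns along the same segment on sheet 1, where both the integrand's sign and the direction are reversed, so each half contributes \(\int_b^e\omega\).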

> **About `ix` and `iy`:**  
> `ix` indexes the **x / real** grid (`grid_r[ix]`), and `iy` indexes the **y / imaginary** grid (`grid_i[iy]`).  
> The complex coordinate is `z = grid_r[ix] + 1j * grid_i[iy]`.

The tables are saved to Google Drive in the same `.pt` format as your current pipeline, with extra keys:
- `I0`, `I1`, `B`, and `sheet_parity` (the legacy keys `I_plus`/`I_minus` are also written, as aliases of `I0`/`I1`).
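
The `ix`/`iy` convention above is all a consumer needs to sample a table at an arbitrary point. The sketch below is a hypothetical helper (`bilinear_lookup` is not part of this notebook or any library); it assumes uniform, increasing grids and a point inside the grid box, and works on any array whose last two axes are `(iy, ix)`.

```python
import numpy as np

def bilinear_lookup(table, grid_r, grid_i, z):
    """Bilinearly interpolate table[..., iy, ix] at the complex point z.

    Assumes uniform, increasing grids (x = grid_r[ix], y = grid_i[iy]) and that
    z lies inside the grid box; leading axes (e.g. k) are carried along.
    """
    fx = (z.real - grid_r[0]) / (grid_r[1] - grid_r[0])   # fractional ix
    fy = (z.imag - grid_i[0]) / (grid_i[1] - grid_i[0])   # fractional iy
    ix0 = int(np.clip(np.floor(fx), 0, len(grid_r) - 2))
    iy0 = int(np.clip(np.floor(fy), 0, len(grid_i) - 2))
    tx, ty = fx - ix0, fy - iy0
    return ((1 - tx) * (1 - ty) * table[..., iy0, ix0]
            + tx * (1 - ty) * table[..., iy0, ix0 + 1]
            + (1 - tx) * ty * table[..., iy0 + 1, ix0]
            + tx * ty * table[..., iy0 + 1, ix0 + 1])

# Toy (2, H, W) table holding functions that bilinear interpolation reproduces exactly
grid_r = np.linspace(-6.0, 6.0, 96)
grid_i = np.linspace(-6.0, 6.0, 96)
X, Y = np.meshgrid(grid_r, grid_i)        # X[iy, ix] = grid_r[ix], Y[iy, ix] = grid_i[iy]
table = np.stack([X + 1j * Y, 3 * X - Y])
vals = bilinear_lookup(table, grid_r, grid_i, 1.234 - 0.567j)
```

With the saved tensors one would pass e.g. `I0.numpy()`. Note that the tables jump across cuts (and wherever the straight path from the base point sweeps past a branch point), so interpolating across such a jump mixes unrelated values.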


In [None]:

# @title Setup: installs, imports, Drive mount, config
!pip install -q torch numpy mpmath tqdm

import os, time, math, json, numpy as np, torch, mpmath as mp
from tqdm import tqdm
from google.colab import drive

# ===========================
# 0) User config
# ===========================
genus = 30          # default genus (g)
H = W = 96          # grid resolution (H rows in y, W columns in x)

r_min, r_max = -6.0, 6.0
i_min, i_max = -6.0, 6.0
grid_r = np.linspace(r_min, r_max, W).astype(np.float64)
grid_i = np.linspace(i_min, i_max, H).astype(np.float64)

# Numerical precision (mpmath) and quadrature
mp.mp.dps = 50      # raise if you need more precision
N_GAUSS = 32        # Gauss–Legendre nodes per segment (32 is a decent default)

# Intersection tolerance for segment–segment tests
EPS_INTERSECT = 1e-12

# Base point (choose outside the grid box)
base_point = complex(r_min - 2.0, i_min - 2.0)

# Cut generation parameters
CUT_RADIUS = 4.0
CUT_JITTER = 0.25
CUT_SEED   = 123

# ===========================
# 1) Drive & save directory
# ===========================
DRIVE_FOLDER = f"AJ_Tables_g{genus}_faithful"   # existing files here are resumed, not overwritten; change to start in a new folder
drive.mount('/content/drive', force_remount=True)

SAVE_DIR = f"/content/drive/MyDrive/{DRIVE_FOLDER}"
os.makedirs(SAVE_DIR, exist_ok=True)

# File paths (ω and I saved separately, resume-safe)
OMEGAS_PATH    = os.path.join(SAVE_DIR, f"aj_omegas_genus{genus}.pt")
INTEGRALS_PATH = os.path.join(SAVE_DIR, f"aj_integrals_genus{genus}.pt")

print("Will save to:")
print("  ω  :", OMEGAS_PATH)
print("  I  :", INTEGRALS_PATH)


In [None]:

# @title Helper: atomic save (resume-safe)
def atomic_torch_save(obj, path):
    tmp = path + ".tmp"
    torch.save(obj, tmp)
    os.replace(tmp, path)
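

The build cells below checkpoint after every row via `progress["iy_done"]` and `atomic_torch_save`. A toy, self-contained sketch of that resume pattern follows; it uses JSON instead of `torch.save`, and `atomic_write_json` / `build_rows` are hypothetical names. Python documents `os.replace` as atomic on POSIX filesystems; we have not verified that guarantee on the Colab Drive mount.

```python
import json
import os
import tempfile

def atomic_write_json(obj, path):
    """Write to a temporary file next to `path`, then swap it in with os.replace."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(obj, f)
    os.replace(tmp, path)

def build_rows(n_rows, path, stop_after=None):
    """Toy row-by-row build that checkpoints after every row and resumes from iy_done.

    Each 'row' is just iy**2; `stop_after` simulates an interrupted session.
    """
    if os.path.exists(path):
        with open(path) as f:
            state = json.load(f)
    else:
        state = {"rows": [], "iy_done": 0}
    for iy in range(state["iy_done"], n_rows):
        if stop_after is not None and iy >= stop_after:
            return state                  # pretend the runtime died here
        state["rows"].append(iy * iy)
        state["iy_done"] = iy + 1
        atomic_write_json(state, path)    # checkpoint after each row
    return state

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "progress.json")
    partial = build_rows(5, path, stop_after=3)
    final = build_rows(5, path)           # picks up at iy_done == 3
```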


In [None]:

# @title Cuts: greedy nearest-neighbor pairing on a jittered circle (same style as your existing notebook)
def make_hyperelliptic_cuts(g, radius=4.0, jitter=0.25, seed=123,
                            r_min=-6.0, r_max=6.0, i_min=-6.0, i_max=6.0):
    """
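    Returns g+1 cuts [(a_0, b_0), ..., (a_g, b_g)] for a hyperelliptic curve
    with 2g+2 branch points; the cut endpoints are the branch points. We place
    2g+2 points on a radially jittered circle and greedily pair each point with
    its nearest still-unpaired neighbor. The cuts are short, but they are not
    guaranteed to be pairwise disjoint.

    This mirrors the construction in your original notebook, but parameterized.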
    """
    rng = np.random.RandomState(seed)
    m = 2*g + 2
    thetas = np.linspace(0, 2*np.pi, m, endpoint=False)
    rng.shuffle(thetas)
    radii = radius * (1.0 + jitter * (rng.rand(m) - 0.5))
    pts = radii * np.exp(1j * thetas)

    # Ensure within grid box (shrink if needed)
    scale = max((pts.real.max()-pts.real.min())/(r_max-r_min+1e-6),
                (pts.imag.max()-pts.imag.min())/(i_max-i_min+1e-6))
    if scale > 0.85:
        pts = pts / (scale/0.85)

    remaining = list(range(m))
    cuts = []
    while remaining:
        i = remaining.pop(0)
        pi = pts[i]
        dists = [(j, abs(pi - pts[j])) for j in remaining]
        j = min(dists, key=lambda t: t[1])[0]
        remaining.remove(j)

        # Push each endpoint 5% further from the midpoint (this lengthens the segment).
        # branch_pts is built from these endpoints below, so they are the branch points.
        a, b = pi, pts[j]
        mid = 0.5*(a+b)
        a = a + 0.05*(a - mid)
        b = b + 0.05*(b - mid)
        cuts.append((complex(a), complex(b)))

    if len(cuts) > g+1:
        cuts.sort(key=lambda ab: -abs(ab[0]-ab[1]))
        cuts = cuts[:g+1]
    assert len(cuts) == g+1
    return cuts

branch_cuts = make_hyperelliptic_cuts(
    genus, radius=CUT_RADIUS, jitter=CUT_JITTER, seed=CUT_SEED,
    r_min=r_min, r_max=r_max, i_min=i_min, i_max=i_max
)
branch_pts = [a for ab in branch_cuts for a in ab]

print(f"genus={genus} → cuts={len(branch_cuts)} ; total branch points={len(branch_pts)}")
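

Greedy nearest-neighbor pairing does not guarantee that the cuts are pairwise disjoint, and crossing cuts would make the "one crossing = one sheet flip" bookkeeping ambiguous. A minimal, self-contained check (the helper names are hypothetical) uses the standard orientation test; touching or collinear configurations are treated as non-crossing, which is adequate for randomly jittered points.

```python
import itertools

def _cross(u, v):
    """z-component of the 2D cross product of complex numbers u and v."""
    return u.real * v.imag - u.imag * v.real

def segments_cross(a, b, c, d):
    """Proper crossing test for segments [a,b] and [c,d] with complex endpoints."""
    d1 = _cross(b - a, c - a)
    d2 = _cross(b - a, d - a)
    d3 = _cross(d - c, a - c)
    d4 = _cross(d - c, b - c)
    return d1 * d2 < 0 and d3 * d4 < 0

def crossing_cut_pairs(cuts):
    """Return index pairs (i, j) of cuts that properly cross each other."""
    return [(i, j)
            for (i, (a, b)), (j, (c, d)) in itertools.combinations(enumerate(cuts), 2)
            if segments_cross(a, b, c, d)]

# Toy cuts: #0 and #1 cross at the origin, #2 is collinear with #0 but disjoint
cuts = [(-1 + 0j, 1 + 0j), (0 - 1j, 0 + 1j), (2 + 0j, 3 + 0j)]
pairs = crossing_cut_pairs(cuts)
```

In the notebook one would call `crossing_cut_pairs(branch_cuts)`; an empty list means this particular seed produced disjoint cuts.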



## Geometry utilities: segment–segment intersection and cut crossings

We parameterize the straight segment from base point \(p\) to target \(q\) as
\[
p(t) = p + t(q-p),\quad t\in[0,1].
\]
Each branch cut is a segment \([a,b]\). We detect all intersections \(t\in(0,1)\) and sort them to split the path into subsegments.

Each time we cross a cut, we **flip the sheet** (equivalently \(w\to -w\) and \(\omega_k \to -\omega_k\)).
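
The intersection test below uses the 2D cross product on real coordinates. The same algebra reads compactly in complex arithmetic, with \(\operatorname{cross}(u,v)=\operatorname{Im}(\bar u\,v)\): writing \(r=q-p\), \(s=b-a\) and solving \(p+tr=a+us\), taking the cross product with \(s\) eliminates \(u\) and taking it with \(r\) eliminates \(t\). The sketch below (hypothetical helper names, strict open-interval test instead of the notebook's `eps` margin) works one example through, including the crossing parity.

```python
def cross(u, v):
    """2D cross product of complex numbers viewed as vectors: u.x*v.y - u.y*v.x."""
    return (u.conjugate() * v).imag

def intersection_params(p, q, a, b):
    """Solve p + t(q - p) = a + u(b - a) for (t, u); None if the segments are parallel."""
    r, s = q - p, b - a
    denom = cross(r, s)
    if denom == 0:
        return None
    ap = a - p
    return cross(ap, s) / denom, cross(ap, r) / denom

# Vertical path 0 -> 2i against two horizontal cuts it crosses and one it misses
p, q = 0j, 2j
cuts = [(-1 + 1j, 1 + 1j), (-1 + 1.5j, 1 + 1.5j), (5 + 0j, 6 + 0j)]
hits = []
for j, (a, b) in enumerate(cuts):
    tu = intersection_params(p, q, a, b)
    if tu is not None and 0 < tu[0] < 1 and 0 < tu[1] < 1:
        hits.append((tu[0], j))
hits.sort()
parity = len(hits) % 2   # two crossings: the path ends back on sheet 0
```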


In [None]:

# @title Segment intersection + cut crossing list
def _cross2(ax, ay, bx, by):
    return ax*by - ay*bx

def segment_intersection_t(p, q, a, b, eps=1e-12):
    """
    Return the parameter t in (0,1) such that p + t(q-p) intersects the segment [a,b],
    or None if no proper intersection.

    We exclude intersections extremely close to segment endpoints using `eps`.
    """
    px, py = float(p.real), float(p.imag)
    qx, qy = float(q.real), float(q.imag)
    ax, ay = float(a.real), float(a.imag)
    bx, by = float(b.real), float(b.imag)

    rx, ry = qx - px, qy - py
    sx, sy = bx - ax, by - ay

    denom = _cross2(rx, ry, sx, sy)
    if abs(denom) < eps:
        return None  # parallel or nearly parallel

    apx, apy = ax - px, ay - py
    t = _cross2(apx, apy, sx, sy) / denom
    u = _cross2(apx, apy, rx, ry) / denom

    if (eps < t < 1.0 - eps) and (eps < u < 1.0 - eps):
        return float(t)
    return None

def cut_crossings(base, z, cuts, eps=1e-12):
    """
    Returns sorted list of (t, cut_index) for intersections of segment base->z with cuts.
    """
    hits = []
    for j, (a, b) in enumerate(cuts):
        t = segment_intersection_t(base, z, a, b, eps=eps)
        if t is not None:
            hits.append((t, j))
    hits.sort(key=lambda x: x[0])
    return hits


In [None]:

# @title Quadrature nodes: Gauss–Legendre on [0,1]
# Gauss–Legendre nodes/weights on [-1,1]
x_gl, w_gl = np.polynomial.legendre.leggauss(N_GAUSS)
# Map to [0,1]
s_gl = (x_gl + 1.0) / 2.0
w01 = w_gl / 2.0

# store as Python floats (mpmath will upcast as needed)
S_GL = [float(s) for s in s_gl]
W_GL = [float(w) for w in w01]

print(f"Gauss–Legendre ready: N_GAUSS={N_GAUSS}, sum(weights)={sum(W_GL):.6f}")
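

Two properties of the mapped Gauss–Legendre rule are worth keeping in mind: with `n` nodes it is exact for polynomials of degree up to `2n - 1`, but it converges slowly for integrable endpoint singularities such as the \(1/\sqrt{z-e}\) behavior of \(\omega_k\) at a branch point \(e\). The sketch below checks both on \([0,1]\) and shows one standard remedy, a graded mesh whose sub-intervals halve toward the singular endpoint (`gauss_legendre_01` is a hypothetical helper mirroring the cell above).

```python
import numpy as np

def gauss_legendre_01(n):
    """Gauss–Legendre nodes and weights mapped from [-1, 1] to [0, 1]."""
    x, w = np.polynomial.legendre.leggauss(n)
    return (x + 1.0) / 2.0, w / 2.0

s, w = gauss_legendre_01(32)

# Exact (to rounding) for polynomials up to degree 2n - 1 = 63: ∫_0^1 s^63 ds = 1/64
poly_err = abs(np.sum(w * s**63) - 1.0 / 64.0)

# ∫_0^1 s^(-1/2) ds = 2 has an endpoint singularity: the plain rule is inaccurate
plain = np.sum(w / np.sqrt(s))

# Graded mesh: sub-intervals [2^-(j+1), 2^-j] for j = 0..59; the dropped piece
# [0, 2^-60] contributes 2 * 2^-30, about 1.9e-9
graded = 0.0
for j in range(60):
    lo, hi = 2.0 ** -(j + 1), 2.0 ** -j
    graded += (hi - lo) * np.sum(w / np.sqrt(lo + (hi - lo) * s))
```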



## Cut-aware integrator

We evaluate \(w(z)=\sqrt{P(z)}\) using the principal `mp.sqrt`, but **enforce continuity** along each segment by flipping the sign at quadrature nodes whenever \(|-w - w_{\text{prev}}| < |w - w_{\text{prev}}|\). This makes \(w\) the analytic continuation of \(\sqrt{P}\) along the path, starting from the principal value at the base point (which defines sheet 0).

Because \(w\) is continued analytically, the lift passes to the other sheet on its own each time the path crosses a cut: relative to the cut-defined sheet-0 branch, \(w\) has changed sign there. Crossings therefore need no manual flip (setting \(w_{\text{prev}} \leftarrow -w_{\text{prev}}\) by hand would cancel the continuation and keep the integrand on the sheet-0 branch); they are only counted, to record which sheet the path ends on.

### Integrating **all** \(k=0,\dots,g-1\) at once

To avoid recomputing \(P(z)\) and \(\sqrt{P(z)}\) for each \(k\), we compute the full vector
\[
(\int \omega_0,\dots,\int \omega_{g-1})
\]
in one pass along the quadrature nodes.


In [None]:

# @title Polynomial P(z), sqrt continuation, segment integrator (all k)
# Convert branch points to mpmath types once
BRANCH_PTS_MP = [mp.mpc(a.real, a.imag) for a in branch_pts]

def P_of_z(z_mp):
    """P(z)=∏(z-a_i) as an mpmath complex."""
    prod = mp.mpc(1)
    for a in BRANCH_PTS_MP:
        prod *= (z_mp - a)
    return prod

# Base-sheet choice: pick w(base_point) via principal sqrt
BASE_MP = mp.mpc(base_point.real, base_point.imag)
W_BASE = mp.sqrt(P_of_z(BASE_MP))

def integrate_segment_allk(z0, z1, w_prev):
    """
    Integrate ω_k = z^k dz / w along the straight segment z(s)=z0+s(z1-z0), s∈[0,1],
    for all k=0..g-1 in one pass.
    w_prev is the continued value of w at (or just before) z0.
    Returns (I_vec, w_end) where I_vec is list length g (mpmath complex) and w_end is
    the continued value of w at the last node, to be passed to the next segment.
    """
    z0 = mp.mpc(z0.real, z0.imag)
    z1 = mp.mpc(z1.real, z1.imag)
    dz = z1 - z0

    if abs(dz) == 0:
        return [mp.mpc(0) for _ in range(genus)], w_prev

    seg_acc = [mp.mpc(0) for _ in range(genus)]
    w_last = w_prev

    for s, wt in zip(S_GL, W_GL):
        z = z0 + dz * s
        w = mp.sqrt(P_of_z(z))  # principal sqrt

        # enforce continuity: keep the sign of the root closest to the previous value
        if w_last is not None and abs(w - w_last) > abs(-w - w_last):
            w = -w

        inv_w = 1 / w
        z_pow = mp.mpc(1)
        for k in range(genus):
            seg_acc[k] += wt * (z_pow * inv_w)
            z_pow *= z

        w_last = w

    # multiply by dz because dz/ds is constant on this segment
    for k in range(genus):
        seg_acc[k] *= dz

    return seg_acc, w_last

def integrate_straight_with_cuts_allk(z):
    """
    Integrate along the straight segment base_point -> z, splitting at cut intersections.

    The continuity tracking in integrate_segment_allk continues w analytically across
    each split point, so every cut crossing moves the lift to the other sheet without
    any manual sign flip. The crossings are counted only to record the end sheet.

    Returns:
      I_vec: list length g of mp.mpc
      parity: (# crossings mod 2), the sheet the straight lift ends on (0=sheet0, 1=sheet1)
      n_cross: number of crossings
    """
    hits = cut_crossings(base_point, z, branch_cuts, eps=EPS_INTERSECT)
    n_cross = len(hits)
    parity = n_cross % 2

    # Build split points: base, each crossing (sorted by t), then z
    pts = [base_point]
    for t, _ in hits:
        pts.append(base_point + t*(z - base_point))
    pts.append(z)

    I_total = [mp.mpc(0) for _ in range(genus)]
    w_prev = W_BASE

    for i in range(len(pts) - 1):
        # w_prev carries the continued value of w into the next sub-segment
        I_seg, w_prev = integrate_segment_allk(pts[i], pts[i+1], w_prev)
        for k in range(genus):
            I_total[k] += I_seg[k]

    return I_total, parity, n_cross
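

A toy sketch of the continuity tracking used in `integrate_segment_allk`, on the genus-0 curve \(w^2=z^2-1\) with the single cut \([-1,1]\). Picking, at each sample, the sign of the principal root closest to the previous value continues \(w\) analytically; after one cut crossing the continued value equals minus the sheet-0 branch, i.e. the lift has moved to sheet 1. The curve, path, and helper names here are illustrative assumptions, and NumPy's principal `sqrt` stands in for `mp.sqrt`.

```python
import numpy as np

def continue_sqrt(P, path, w_start):
    """Continue w = sqrt(P(z)) along sampled points: at each point keep the
    principal root or its negative, whichever is closer to the previous value."""
    ws, w_last = [], w_start
    for z in path:
        w = np.sqrt(P(z))                    # principal branch
        if abs(w - w_last) > abs(-w - w_last):
            w = -w
        ws.append(w)
        w_last = w
    return np.array(ws)

def P(z):
    """Toy curve w^2 = z^2 - 1 with one cut along [-1, 1]."""
    return z * z - 1

def f0(z):
    """Sheet-0 branch: analytic off [-1, 1], ~ z at infinity, jumps sign across the cut."""
    return z * np.sqrt(1 - 1 / (z * z))

path = 1j * np.linspace(-2.0, 2.0, 400)     # straight segment -2i -> 2i, crosses the cut at 0
ws = continue_sqrt(P, path, f0(path[0]))
```

Along this path the principal root itself jumps sign at the origin, and the nearest-sign rule removes that jump; in the notebook the same rule runs at every Gauss–Legendre node.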


In [None]:

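# @title Compute bridge vector B (constant): B = 2 * ∫_{base→e} ω for a branch point e
# ι:(z,w)->(z,-w) has ι*ω_k = -ω_k, so I1(z) = B - I0(z) (mod periods) with B = ∫_{base}^{ι(base)} ω.
# Going to a branch point e on sheet 0 and back on sheet 1 gives B = 2 * ∫_{base→e} ω
# (an interior point of a cut is not fixed by ι, so it does not give this identity).
e_bridge = branch_cuts[0][0]  # endpoint of cut #0, i.e. a branch point
print("Bridge branch point e_bridge =", e_bridge)

hits = cut_crossings(base_point, e_bridge, branch_cuts, eps=EPS_INTERSECT)
n_cross_p, parity_p = len(hits), len(hits) % 2

# Path: base, the cut crossings, then points halving the distance to e. The integrand
# ~ 1/sqrt(z - e) near e; the graded pieces keep Gauss–Legendre accurate, and the
# dropped tail [e + d*2^-80, e] is ~2^-40 of the integral over [e + d, e].
path = [base_point] + [base_point + t*(e_bridge - base_point) for t, _ in hits]
path = [mp.mpc(p.real, p.imag) for p in path]
e_mp = mp.mpc(e_bridge.real, e_bridge.imag)
d_mp = path[-1] - e_mp
path += [e_mp + d_mp / mp.mpf(2)**j for j in range(1, 81)]

I_p = [mp.mpc(0) for _ in range(genus)]
w_prev = W_BASE
for z_a, z_b in zip(path[:-1], path[1:]):
    I_seg, w_prev = integrate_segment_allk(z_a, z_b, w_prev)  # w continued across pieces
    for k in range(genus):
        I_p[k] += I_seg[k]
B_mp = [2 * val for val in I_p]  # B is a length-g vector (mpmath complex)

print(f"Computed B via a branch point of cut #0.  crossings = {n_cross_p}, parity = {parity_p}")
print("Sample entries of B:")
for k in range(min(3, genus)):
    print(f"  B[{k}] = {complex(B_mp[k])}")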


In [None]:

# @title Common metadata for files + config compatibility checks
COMMON_META = {
    "genus": genus,
    "branch_cuts": branch_cuts,
    "branch_pts": np.array(branch_pts, dtype=np.complex128),
    "grid_r": grid_r,
    "grid_i": grid_i,
    "meta": {
        "mp_dps": int(mp.mp.dps),
        "base_point": base_point,
        "grid_shape": (H, W),
        "ranges": (r_min, r_max, i_min, i_max),
        "N_GAUSS": int(N_GAUSS),
        "EPS_INTERSECT": float(EPS_INTERSECT),
        "cut_seed": int(CUT_SEED),
        "cut_radius": float(CUT_RADIUS),
        "cut_jitter": float(CUT_JITTER),
    }
}

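def ensure_config_compatible(payload):
    """Refuse to resume from a file built for a different genus, grid, or set of branch points."""
    assert int(payload["genus"]) == genus
    assert tuple(payload["meta"]["grid_shape"]) == (H, W)
    assert np.allclose(np.array(payload["grid_r"]), grid_r)
    assert np.allclose(np.array(payload["grid_i"]), grid_i)
    if "branch_pts" in payload:
        saved_pts = np.asarray(payload["branch_pts"], dtype=np.complex128)
        curr_pts = np.asarray(branch_pts, dtype=np.complex128)
        assert saved_pts.shape == curr_pts.shape and np.allclose(saved_pts, curr_pts), \
            "Saved file was built with different branch cuts; do not resume."
    # If you want stricter checks, uncomment:
    # assert payload["meta"]["base_point"] == base_point
    # assert int(payload["meta"]["N_GAUSS"]) == int(N_GAUSS)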


In [None]:

# @title (Optional) Check/lock branch cuts against saved files (resume safety)
import os, numpy as np, torch

AUTO_ADOPT_SAVED_CUTS = False  # set True if you want to overwrite current cuts from saved file

def pick_saved_path():
    cand = []
    if os.path.exists(INTEGRALS_PATH):
        cand.append(INTEGRALS_PATH)
    if os.path.exists(OMEGAS_PATH):
        cand.append(OMEGAS_PATH)
    if not cand:
        return None
    cand.sort(key=lambda p: os.path.getmtime(p), reverse=True)
    return cand[0]

def np_c128_from_list(lst):
    return np.array([complex(z) for z in lst], dtype=np.complex128)

saved_path = pick_saved_path()
if saved_path is None:
    print("No saved ω/I file found: fresh run.")
else:
    pkg = torch.load(saved_path, map_location="cpu", weights_only=False)
    saved_genus = int(pkg.get("genus"))
    saved_grid_r = np.array(pkg.get("grid_r"))
    saved_grid_i = np.array(pkg.get("grid_i"))
    saved_pts = np.array(pkg.get("branch_pts", []), dtype=np.complex128)
    saved_cuts = pkg.get("branch_cuts", None)

    curr_pts = np_c128_from_list(branch_pts)

    print(f"Comparing current session to: {os.path.basename(saved_path)}")
    ok_genus = (saved_genus == genus)
    ok_grid  = np.allclose(saved_grid_r, grid_r) and np.allclose(saved_grid_i, grid_i)
    ok_len   = (saved_pts.shape == curr_pts.shape)
    ok_pts   = ok_len and np.allclose(saved_pts, curr_pts)

    print(f"  genus match : {ok_genus} (saved={saved_genus}, current={genus})")
    print(f"  grid match  : {ok_grid}")
    if ok_len:
        max_dev = float(np.max(np.abs(saved_pts - curr_pts))) if saved_pts.size else 0.0
        print(f"  branch pts  : {'MATCH' if ok_pts else 'MISMATCH'} (max |Δ| = {max_dev:.3e})")
    else:
        print(f"  branch pts  : count differs (saved {saved_pts.size}, current {curr_pts.size})")

    if ok_genus and ok_grid and ok_pts:
        print("\n✅ SAFE to resume: cuts/points match what’s in Drive.")
    else:
        print("\n❌ MISMATCH detected — do NOT resume building tables with different cuts.")
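        if AUTO_ADOPT_SAVED_CUTS and (saved_cuts is not None):
            branch_cuts = [(complex(a), complex(b)) for a, b in saved_cuts]
            branch_pts  = [a for ab in branch_cuts for a in ab]
            BRANCH_PTS_MP = [mp.mpc(a.real, a.imag) for a in branch_pts]
            # Quantities derived from the cuts in earlier cells must follow them.
            W_BASE = mp.sqrt(P_of_z(BASE_MP))
            COMMON_META["branch_cuts"] = branch_cuts
            COMMON_META["branch_pts"] = np.array(branch_pts, dtype=np.complex128)
            print("→ Adopted saved branch cuts/points into memory.")
            print("→ Re-run the bridge-vector (B) cell before building tables.")
        else:
            print("Tip: set AUTO_ADOPT_SAVED_CUTS=True to adopt the saved cuts.")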



## Build ω table (optional; used for anchor selection in training)

This is the same idea as your original notebook: evaluate

\[
\omega_k(z) = \frac{z^k}{\sqrt{P(z)}}
\]

pointwise on the grid. (This is not an integral; it's just the differential value.)
We keep the key name `omega_plus` for backward compatibility.

> Note: the **magnitude** \(|\omega_k|\) is independent of the sign of \(\sqrt{P(z)}\), so this remains useful for your anchor heuristic even if the sign convention differs.
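
Since the ω table is stored as `complex64`, evaluating it in double precision is likely adequate. The sketch below is a hypothetical vectorized NumPy variant of `omega_allk_at_z` (not used by the notebook); it uses NumPy's principal `sqrt`, matching the principal-branch convention of `mp.sqrt`. For genus 30 on this grid, \(|P(z)|\) stays far below the double overflow threshold; for much larger genus or domains a log-space product would be safer.

```python
import numpy as np

def omega_table_numpy(zs, branch_pts, g):
    """omega[k, ...] = zs**k / sqrt(P(zs)) for k = 0..g-1 (principal sqrt), vectorized."""
    zs = np.asarray(zs, dtype=np.complex128)
    P = np.ones_like(zs)
    for a in branch_pts:
        P = P * (zs - a)
    inv_w = 1.0 / np.sqrt(P)                       # principal branch, as with mp.sqrt
    k = np.arange(g).reshape((g,) + (1,) * zs.ndim)
    return zs[np.newaxis, ...] ** k * inv_w[np.newaxis, ...]

# Toy genus-2 example: 2g + 2 = 6 branch points on a 1 x 2 "grid"
pts = [1, -1, 2j, -2j, 3, -3]
zs = np.array([[0.5 + 0.5j, -1.5 + 0.25j]])
out = omega_table_numpy(zs, pts, g=2)
```

On the notebook grid one would build `zs = grid_r[None, :] + 1j * grid_i[:, None]`, giving an array of shape `(genus, H, W)` in the same layout as `omega_plus`.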


In [None]:

# @title Build ω (differentials) → Drive (resume-safe, computes all k per grid point)
import torch

def omega_allk_at_z(z):
    """Return [omega_0(z), ..., omega_{g-1}(z)] with principal sqrt."""
    z_mp = mp.mpc(z.real, z.imag)
    w = mp.sqrt(P_of_z(z_mp))
    inv_w = 1 / w
    out = []
    z_pow = mp.mpc(1)
    for k in range(genus):
        out.append(z_pow * inv_w)
        z_pow *= z_mp
    return out

# Initialize or resume
if os.path.exists(OMEGAS_PATH):
    pkg = torch.load(OMEGAS_PATH, map_location='cpu', weights_only=False)
    ensure_config_compatible(pkg)
    Om_plus = pkg["omega_plus"]  # (g,H,W) complex
    progress = pkg.get("progress", {"iy_done": 0})
    print("Resuming ω from Drive.")
else:
    Om_plus = torch.zeros(genus, H, W, dtype=torch.cfloat)
    progress = {"iy_done": 0}
    print("Starting fresh ω build.")

SAVE_EVERY_N_ROWS = 1  # safest; raise to reduce I/O

t0 = time.time()
for iy in range(int(progress["iy_done"]), H):
    y = grid_i[iy]
    row = np.zeros((genus, W), dtype=np.complex64)
    for ix in range(W):
        x = grid_r[ix]
        z = complex(x, y)
        vals = omega_allk_at_z(z)
        row[:, ix] = np.array([complex(v) for v in vals], dtype=np.complex64)

    Om_plus[:, iy, :] = torch.from_numpy(row)
    progress["iy_done"] = iy + 1

    if (iy % SAVE_EVERY_N_ROWS) == (SAVE_EVERY_N_ROWS - 1):
        payload = {**COMMON_META, "omega_plus": Om_plus, "progress": progress}
        atomic_torch_save(payload, OMEGAS_PATH)

# Final save
payload = {**COMMON_META, "omega_plus": Om_plus, "progress": progress}
atomic_torch_save(payload, OMEGAS_PATH)

t1 = time.time()
print(f"ω table saved to {OMEGAS_PATH}  | elapsed {t1 - t0:.1f}s")



## Build AJ integrals for both sheets

For each grid point \(z\):

1. Compute the straight-path lift integral (starting from the base point on sheet 0):
   - `I_reached(z)` and the **sheet parity** `parity(z) ∈ {0,1}`.
2. Using the bridge vector `B`, define the two-sheet tables:
   - If `parity(z)=0`, then `I0(z)=I_reached(z)` and `I1(z)=B-I0(z)`.
   - If `parity(z)=1`, then `I1(z)=I_reached(z)` and `I0(z)=B-I1(z)`.

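This makes `I0(z) + I1(z) = B` hold everywhere by construction (up to `complex64` round-off). The mathematical content is the relation `I1 = B - I0` between the two lifts of `z`, which holds modulo the period lattice of the curve.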
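
The per-point rule above can also be written as a whole-array operation. A minimal NumPy sketch, assuming `I_reached` of shape `(g, H, W)`, `parity` of shape `(H, W)` with values in {0, 1}, and `B` of shape `(g,)` (`assemble_two_sheets` is a hypothetical name; the notebook does this inside its row loop):

```python
import numpy as np

def assemble_two_sheets(I_reached, parity, B):
    """Split straight-path integrals into sheet-0 and sheet-1 tables using parity and B."""
    B = np.asarray(B).reshape(-1, 1, 1)            # broadcast over (H, W)
    on_sheet1 = parity.astype(bool)[np.newaxis]    # broadcast over k
    I0 = np.where(on_sheet1, B - I_reached, I_reached)
    I1 = np.where(on_sheet1, I_reached, B - I_reached)
    return I0, I1

rng = np.random.default_rng(0)
I_reached = rng.normal(size=(3, 4, 5)) + 1j * rng.normal(size=(3, 4, 5))
parity = rng.integers(0, 2, size=(4, 5))
B = np.array([1 + 2j, -0.5j, 3.0])
I0, I1 = assemble_two_sheets(I_reached, parity, B)
```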


In [None]:

# @title Build I0/I1 (integrals) → Drive (resume-safe)
import torch

# Convert B to torch once for storage; keep python list for arithmetic
B_list = [complex(v) for v in B_mp]
B_torch = torch.tensor(B_list, dtype=torch.cfloat)

# Initialize or resume
if os.path.exists(INTEGRALS_PATH):
    pkg = torch.load(INTEGRALS_PATH, map_location='cpu', weights_only=False)
    ensure_config_compatible(pkg)

    I0 = pkg.get("I0", None)
    I1 = pkg.get("I1", None)

    # Backward compat: older runs might have I_plus only
    if I0 is None and "I_plus" in pkg:
        I0 = pkg["I_plus"]
    if I1 is None and "I_minus" in pkg:
        I1 = pkg["I_minus"]

    if I0 is None:
        I0 = torch.zeros(genus, H, W, dtype=torch.cfloat)
    if I1 is None:
        I1 = torch.zeros(genus, H, W, dtype=torch.cfloat)

    # Prefer saved B if present
    if "B" in pkg:
        B_torch = pkg["B"].to(torch.cfloat)
        B_list = [complex(v) for v in B_torch]
        print("Loaded B from saved file.")
    else:
        print("No B in saved file; using newly computed B (will be saved).")

    sheet_parity = pkg.get("sheet_parity", torch.zeros(H, W, dtype=torch.int8))
    progress = pkg.get("progress", {"iy_done": 0})
    print("Resuming integrals from Drive.")
else:
    I0 = torch.zeros(genus, H, W, dtype=torch.cfloat)
    I1 = torch.zeros(genus, H, W, dtype=torch.cfloat)
    sheet_parity = torch.zeros(H, W, dtype=torch.int8)
    progress = {"iy_done": 0}
    print("Starting fresh integrals build.")

SAVE_EVERY_N_ROWS = 1  # safest

B_np = np.array(B_list, dtype=np.complex64)

t0 = time.time()
start_iy = int(progress["iy_done"])
for iy in range(start_iy, H):
    y = grid_i[iy]
    row0 = np.zeros((genus, W), dtype=np.complex64)
    row1 = np.zeros((genus, W), dtype=np.complex64)

    for ix in range(W):
        x = grid_r[ix]
        z = complex(x, y)

        I_reached, parity, n_cross = integrate_straight_with_cuts_allk(z)
        sheet_parity[iy, ix] = int(parity)

        I_reached_c = np.array([complex(v) for v in I_reached], dtype=np.complex64)

        if parity == 0:
            I0_c = I_reached_c
            I1_c = (B_np - I0_c)
        else:
            I1_c = I_reached_c
            I0_c = (B_np - I1_c)

        row0[:, ix] = I0_c
        row1[:, ix] = I1_c

    I0[:, iy, :] = torch.from_numpy(row0)
    I1[:, iy, :] = torch.from_numpy(row1)

    progress["iy_done"] = iy + 1

    if (iy % SAVE_EVERY_N_ROWS) == (SAVE_EVERY_N_ROWS - 1):
        payload = {
            **COMMON_META,
            "I0": I0,
            "I1": I1,
            "B": B_torch,
            "sheet_parity": sheet_parity,
            # Backward-compat keys:
            "I_plus": I0,
            "I_minus": I1,
            "progress": progress,
        }
        atomic_torch_save(payload, INTEGRALS_PATH)

# Final save
payload = {
    **COMMON_META,
    "I0": I0,
    "I1": I1,
    "B": B_torch,
    "sheet_parity": sheet_parity,
    "I_plus": I0,
    "I_minus": I1,
    "progress": progress,
}
atomic_torch_save(payload, INTEGRALS_PATH)

t1 = time.time()
print(f"I0/I1 tables saved to {INTEGRALS_PATH}  | elapsed {t1 - t0:.1f}s")

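# Sanity check over the whole table: I0 + I1 = B holds by construction, so this only measures complex64 round-off (it does not validate the integrals)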
with torch.no_grad():
    B_view = B_torch.view(genus, 1, 1)
    err = (I0 + I1 - B_view).abs().max().item()
print("Max |I0 + I1 - B| over table:", err)
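

Because `I0 + I1 = B` is enforced by construction, the check above cannot detect a wrong integral. A non-tautological check uses holomorphy: \(dI/dz=\partial I/\partial x\), and \(dI_0/dz=\omega_k\) up to the sign of \(\sqrt{P}\), so \(|\partial I_0/\partial x|\) should match \(|\omega_k|\) away from cuts and period jumps. The helper below is a hypothetical sketch assuming a uniform `grid_r`; it is verified here only on a synthetic table, not on the notebook's output.

```python
import numpy as np

def derivative_mismatch(I, omega, grid_r):
    """Relative mismatch between |∂I/∂x| (central differences along ix) and |omega|.

    Assumes a uniform grid_r; returns an array of shape (g, H, W - 2) for interior columns.
    """
    dx = grid_r[1] - grid_r[0]
    dI_dx = (I[..., 2:] - I[..., :-2]) / (2 * dx)
    ref = np.abs(omega[..., 1:-1])
    return np.abs(np.abs(dI_dx) - ref) / np.maximum(ref, np.finfo(float).tiny)

# Synthetic check: I = z^2 / 2 has dI/dz = z, and central differences are exact for quadratics
grid_r = np.linspace(1.0, 3.0, 41)
grid_i = np.linspace(1.0, 2.0, 11)
Z = grid_r[np.newaxis, :] + 1j * grid_i[:, np.newaxis]   # Z[iy, ix]
mismatch = derivative_mismatch((Z**2 / 2)[np.newaxis], Z[np.newaxis], grid_r)
```

On the saved tables, something like `np.median(derivative_mismatch(I0.numpy(), Om_plus.numpy(), grid_r))` would be the useful summary, since columns adjacent to cuts or period jumps are expected to show large values.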



## Notes / knobs you may want to tune

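- **`mp.mp.dps`**: raise precision if you see instability near branch points. Results are stored as `complex64` (about 7 significant digits) and the Gauss–Legendre nodes are double precision, so extra digits only protect intermediate arithmetic such as the product \(P(z)\).
- **`N_GAUSS`**: raise the number of quadrature nodes per sub-segment if you want more accurate integrals; paths passing close to a branch point converge slowest.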
- **Grid size** (`H`,`W`): larger grids cost more but improve bilinear sampling fidelity.
- **Cut geometry** (`CUT_RADIUS`, `CUT_JITTER`, `CUT_SEED`): keep fixed across runs if you want reproducibility.
