<a href="https://colab.research.google.com/github/gift-framework/GIFT/blob/research/notebooks/colab_atlas_g2_metric.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GIFT Atlas G₂ Metric — 3-Chart TCS Construction + Spectral Bridge

**Version**: A1 (2026-02-13)
**Hardware**: A100 GPU (Colab)
**Duration**: ~60 min total

## What This Notebook Does
1. **Phase 1**: Train 3 chart PINNs independently (Neck + Bulk_L + Bulk_R)
2. **Phase 2**: Schwarz alternating iterations (match at interfaces)
3. **Phase 3**: Joint fine-tuning with global torsion penalty
4. **Phase 4**: Spectral bridge — Laplacian eigenvalues on K₇

## The Atlas Architecture
```
K₇ = Bulk_L ∪ Neck ∪ Bulk_R

Bulk_L: ACyl CY₃(Quintic) × S¹    b₂=11, b₃=40
Neck:   K3 × T² × [δ, L-δ]        matching region
Bulk_R: ACyl CY₃(CI(2,2,2)) × S¹  b₂=10, b₃=37
```

## Key GIFT Predictions Tested
- det(g) = 65/32 everywhere (all charts)
- Torsion ≪ Joyce threshold (0.0288) — now on K₇, not T⁷
- λ₁ × H* = dim(G₂) = 14 (master equation)
- λₙ ∝ γₙ (Riemann zeta zero correspondence)

## What's New vs P1/P2
- **3 separate coordinate charts** with explicit TCS transition maps
- **Kovalev twist** at right interface (circle swap + HK rotation)
- **Asymptotic decay constraints** in bulk regions
- **Cholesky-level blending** preserving positive definiteness
- Topology is now K₇ (b₃=77), not T⁷ (b₃=35)
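
The "Cholesky-level blending" item can be illustrated in isolation. The sketch below is a pure-Python toy, not the notebook's implementation: it blends two 2×2 lower-triangular factors with a smoothstep weight and checks that the resulting metric stays positive definite. The helper names (`smoothstep`, `blend_factors`, `gram`) and the sample matrices are assumptions made for illustration.

```python
# Toy illustration (assumed names and values): blending Cholesky factors,
# rather than the metrics themselves, keeps the result positive definite.

def smoothstep(s):
    """C¹ blend weight: 0 for s <= 0, 1 for s >= 1."""
    s = min(max(s, 0.0), 1.0)
    return 3 * s**2 - 2 * s**3

def blend_factors(L_a, L_b, w):
    """Entrywise convex combination (1 - w) * L_a + w * L_b."""
    n = len(L_a)
    return [[(1 - w) * L_a[i][j] + w * L_b[i][j] for j in range(n)]
            for i in range(n)]

def gram(L):
    """g = L @ L^T for a square matrix stored as nested lists."""
    n = len(L)
    return [[sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]

# Two lower-triangular factors with positive diagonals (hypothetical charts).
L_neck = [[1.0, 0.0], [0.3, 1.2]]
L_bulk = [[1.1, 0.0], [-0.2, 0.9]]

blended_metrics = []
for step in range(11):
    w = smoothstep(step / 10)
    g = gram(blend_factors(L_neck, L_bulk, w))
    blended_metrics.append(g)
    det = g[0][0] * g[1][1] - g[0][1] * g[1][0]
    # Sylvester's criterion for a symmetric 2×2 matrix.
    print(f"w={w:.3f}  g00={g[0][0]:.4f}  det={det:.4f}  PD={g[0][0] > 0 and det > 0}")
```

A convex combination of two positive diagonals is again positive, so the blended factor is a valid Cholesky factor and `L @ L^T` is positive definite for every weight in [0, 1].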

In [1]:
# Setup
import os, sys, time, math, json
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64  # double precision for metric accuracy

print(f"PyTorch {torch.__version__}")
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    print(f"Memory: {props.total_memory / 1e9:.1f} GB")

torch.manual_seed(42)
np.random.seed(42)

PyTorch 2.9.0+cu128
Device: cuda
GPU: NVIDIA A100-SXM4-80GB
Memory: 85.1 GB


## 1. GIFT Topological Constants (Lean-verified)

In [2]:
# === Topological constants (zero free parameters) ===
B2 = 21                        # b₂(K₇)
B3 = 77                        # b₃(K₇)
DIM_G2 = 14                   # dim(G₂) holonomy
DIM_E8 = 248                  # dim(E₈)
RANK_E8 = 8                   # rank(E₈)
H_STAR = B2 + B3 + 1          # = 99
DET_G_TARGET = 65 / 32        # = 2.03125
TORSION_THRESHOLD = 0.0288    # Joyce ε₀
SCALE_FACTOR = (65 / 32) ** (1 / 14)  # ≈ 1.0519

# === TCS building block Betti numbers ===
B2_M1 = 11                    # b₂(Quintic)
B3_M1 = 40                    # b₃(Quintic)
B2_M2 = 10                    # b₂(CI(2,2,2))
B3_M2 = 37                    # b₃(CI(2,2,2))

# === Spectral predictions ===
LAMBDA_1_PREDICTED = DIM_G2 / H_STAR  # = 14/99 ≈ 0.14141

# === Domain layout (coord[0] = t parameter) ===
# Overlap widths chosen for smooth Schwarz convergence
DOMAIN_BULK_L  = (0.00, 0.40)
DOMAIN_NECK    = (0.25, 0.75)
DOMAIN_BULK_R  = (0.60, 1.00)
OVERLAP_LEFT   = (0.25, 0.40)  # Bulk_L ∩ Neck
OVERLAP_RIGHT  = (0.60, 0.75)  # Neck ∩ Bulk_R

# === ACyl decay rate ===
MU_DECAY = 0.8  # exponential decay rate for ACyl CY metrics

print("GIFT Atlas Constants:")
print(f"  b₂={B2}, b₃={B3}, H*={H_STAR}, dim(G₂)={DIM_G2}")
print(f"  det(g) target = {DET_G_TARGET}")
print(f"  M₁ (Quintic): b₂={B2_M1}, b₃={B3_M1}")
print(f"  M₂ (CI(2,2,2)): b₂={B2_M2}, b₃={B3_M2}")
print(f"  Mayer-Vietoris check: {B2_M1}+{B2_M2}={B2_M1+B2_M2} = b₂={B2} ✓")
print(f"  Mayer-Vietoris check: {B3_M1}+{B3_M2}={B3_M1+B3_M2} = b₃={B3} ✓")
print(f"  Predicted λ₁ = {DIM_G2}/{H_STAR} = {LAMBDA_1_PREDICTED:.6f}")
print(f"  Domain layout: L=[{DOMAIN_BULK_L}] N=[{DOMAIN_NECK}] R=[{DOMAIN_BULK_R}]")

GIFT Atlas Constants:
  b₂=21, b₃=77, H*=99, dim(G₂)=14
  det(g) target = 2.03125
  M₁ (Quintic): b₂=11, b₃=40
  M₂ (CI(2,2,2)): b₂=10, b₃=37
  Mayer-Vietoris check: 11+10=21 = b₂=21 ✓
  Mayer-Vietoris check: 40+37=77 = b₃=77 ✓
  Predicted λ₁ = 14/99 = 0.141414
  Domain layout: L=[(0.0, 0.4)] N=[(0.25, 0.75)] R=[(0.6, 1.0)]
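
The role of the exponent 1/14 in `SCALE_FACTOR` can be checked with a few lines of arithmetic. This is a minimal sketch, assuming the metric is quadratic in φ (as in `g_ij = (1/6) φ_ikl φ_jkl`), so that scaling φ by `s` scales `g` by `s²` and, in 7 dimensions, `det(g)` by `s¹⁴`. The variable names are illustrative only.

```python
import math
from fractions import Fraction

DET_TARGET = Fraction(65, 32)
DIM = 7

# Scaling φ by s scales g by s² and det(g) by s^(2 * DIM) = s¹⁴.
s = float(DET_TARGET) ** (1 / (2 * DIM))
det_after_scaling = s ** (2 * DIM)

# Spectral prediction: λ₁ = dim(G₂) / H*, with H* = b₂ + b₃ + 1.
h_star = 21 + 77 + 1
lambda_1 = Fraction(14, h_star)

print(f"s = {s:.6f}, s^14 = {det_after_scaling:.6f}")
print(f"H* = {h_star}, λ₁ = {lambda_1} ≈ {float(lambda_1):.6f}")
```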


## 2. G₂ 3-Form, Generators, and TCS Coordinate System

In [3]:
# Standard associative 3-form on R⁷ (from G2Holonomy.lean)
# φ₀ = e¹²³ + e¹⁴⁵ + e¹⁶⁷ + e²⁴⁶ − e²⁵⁷ − e³⁴⁷ − e³⁵⁶
#
# TCS coordinate convention:
#   coord[0] = t   (neck/cylindrical parameter, maps to radial direction)
#   coord[1] = θ   (fiber S¹)
#   coord[2:6] = (u₁, u₂, u₃, u₄)  (K3 fiber coordinates)
#   coord[6] = ψ   (outer S¹)
#
# The G₂ 3-form in TCS product form:
#   φ = dψ ∧ ω_K3 + dψ ∧ dr ∧ dθ + Re(Ω_K3) ∧ dr − Im(Ω_K3) ∧ dθ
# In frame: φ = e⁰¹² + e⁰³⁴ + e⁰⁵⁶ + e¹³⁵ − e¹⁴⁶ − e²³⁶ − e²⁴⁵
# (with our coordinate ordering: 0=t, 1=θ, 2=u₁, 3=u₂, 4=u₃, 5=u₄, 6=ψ)

STANDARD_G2_FORM = [
    ((0, 1, 2), +1.0),
    ((0, 3, 4), +1.0),
    ((0, 5, 6), +1.0),
    ((1, 3, 5), +1.0),
    ((1, 4, 6), -1.0),
    ((2, 3, 6), -1.0),
    ((2, 4, 5), -1.0),
]

FANO_LINES = [(0, 1, 3), (1, 2, 4), (2, 3, 5), (3, 4, 6), (4, 5, 0), (5, 6, 1), (6, 0, 2)]


def build_phi0_tensor():
    """Full 7×7×7 antisymmetric tensor for the standard G₂ form."""
    phi0 = np.zeros((7, 7, 7), dtype=np.float64)
    for (indices, sign) in STANDARD_G2_FORM:
        i, j, k = indices
        phi0[i, j, k] = sign;  phi0[j, k, i] = sign;  phi0[k, i, j] = sign
        phi0[j, i, k] = -sign; phi0[i, k, j] = -sign; phi0[k, j, i] = -sign
    return phi0


PHI0_TENSOR = build_phi0_tensor()


def phi0_components(normalize=True):
    """Standard G₂ 3-form as 35 independent components."""
    phi0 = np.zeros(35, dtype=np.float64)
    idx = 0
    for i in range(7):
        for j in range(i + 1, 7):
            for k in range(j + 1, 7):
                phi0[idx] = PHI0_TENSOR[i, j, k]
                idx += 1
    if normalize:
        phi0 *= SCALE_FACTOR
    return phi0


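def g2_generators():
    """14 antisymmetric 7×7 matrices (elements of so(7)) used as the G₂ adjoint basis.

    Each matrix is normalized to unit Frobenius norm. Note that these matrices
    act non-trivially on φ₀ in general (see compute_lie_derivatives), whereas
    elements of the exact stabilizer algebra g₂ ⊂ so(7) would annihilate φ₀.
    """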
    gens = np.zeros((14, 7, 7), dtype=np.float64)
    for idx, (i, j, k) in enumerate(FANO_LINES):
        gens[idx, i, j] = 1;  gens[idx, j, i] = -1
    for idx in range(7):
        i, j, k = idx, (idx + 1) % 7, (idx + 3) % 7
        gens[7 + idx, i, k] = 1;    gens[7 + idx, k, i] = -1
        gens[7 + idx, j, k] = 0.5;  gens[7 + idx, k, j] = -0.5
    for idx in range(14):
        norm = np.linalg.norm(gens[idx])
        if norm > 1e-10:
            gens[idx] /= norm
    return gens


def compute_lie_derivatives():
    """Lie derivatives (L_Xa φ₀) for each G₂ generator. Shape: (14, 35)."""
    generators = g2_generators()
    lie_derivs = np.zeros((14, 35), dtype=np.float64)
    for a in range(14):
        X = generators[a]
        idx = 0
        for i in range(7):
            for j in range(i + 1, 7):
                for k in range(j + 1, 7):
                    val = 0.0
                    for l in range(7):
                        val += X[i, l] * PHI0_TENSOR[l, j, k]
                        val += X[j, l] * PHI0_TENSOR[i, l, k]
                        val += X[k, l] * PHI0_TENSOR[i, j, l]
                    lie_derivs[a, idx] = val
                    idx += 1
    return lie_derivs


LIE_DERIVATIVES = compute_lie_derivatives()
print(f"G₂ generators: 14 matrices in so(7)")
print(f"Lie derivative matrix: {LIE_DERIVATIVES.shape}")

G₂ generators: 14 matrices in so(7)
Lie derivative matrix: (14, 35)
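
As a sanity check on the sign convention, the standard 3-form should induce the flat metric through `g_ij = (1/6) Σ φ_ikl φ_jkl`. The sketch below repeats the seven signed triples from `STANDARD_G2_FORM` in pure Python so it runs on its own. The helper names `build_phi` and `metric_from_phi` are assumptions for this example, not notebook functions.

```python
# Illustrative check (pure Python): the 7 signed triples of STANDARD_G2_FORM,
# copied here so the snippet is self-contained.
G2_TERMS = [
    ((0, 1, 2), +1.0), ((0, 3, 4), +1.0), ((0, 5, 6), +1.0),
    ((1, 3, 5), +1.0), ((1, 4, 6), -1.0), ((2, 3, 6), -1.0),
    ((2, 4, 5), -1.0),
]

def build_phi(terms):
    """Fully antisymmetric 7×7×7 tensor as nested lists."""
    phi = [[[0.0] * 7 for _ in range(7)] for _ in range(7)]
    for (i, j, k), sign in terms:
        for a, b, c in ((i, j, k), (j, k, i), (k, i, j)):   # even permutations
            phi[a][b][c] = sign
        for a, b, c in ((j, i, k), (i, k, j), (k, j, i)):   # odd permutations
            phi[a][b][c] = -sign
    return phi

def metric_from_phi(phi):
    """g_ij = (1/6) * sum over k, l of phi_ikl * phi_jkl."""
    return [[sum(phi[i][k][l] * phi[j][k][l]
                 for k in range(7) for l in range(7)) / 6.0
             for j in range(7)] for i in range(7)]

phi0 = build_phi(G2_TERMS)
g0 = metric_from_phi(phi0)
is_identity = all(abs(g0[i][j] - (1.0 if i == j else 0.0)) < 1e-12
                  for i in range(7) for j in range(7))
print("metric of φ₀ is the identity:", is_identity)
```

The diagonal works out because every index appears in exactly three triples, each contributing two ordered pairs `(k, l)`. The off-diagonal entries vanish because any two of the seven triples share exactly one index.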


## 3. TCS Transition Maps (Kovalev Twist)

The Kovalev twist at the right interface swaps the two S¹ factors
and applies a hyper-Kähler rotation (I ↦ J) on the K3 fiber.

**Left overlap** (Bulk_L ∩ Neck): identity on all coordinates.
**Right overlap** (Neck → Bulk_R transition):
  - t: shared global coordinate (no reversal needed)
  - θ_bulk = ψ_neck      (circle swap)
  - ψ_bulk = θ_neck      (circle swap)
  - u_bulk = R⁻¹ · u_neck (inverse HK rotation on K3 fiber)

For interface matching, we transform neck coordinates to bulk_R
coordinates using R⁻¹, then pull back the bulk_R metric through
the transition Jacobian J to compare in neck coordinates.
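
The transition map described above can be written out on a single point to confirm that the stated `R⁻¹` really inverts `R` and that the circle swap is an involution. This is a minimal pure-Python sketch: the function names are hypothetical, and the mod-1 wrap of the K3 coordinates is deliberately left out so the round trip is exact.

```python
# Sketch of the neck → bulk_R coordinate map on one point
# (assumed helper names; no periodic wrap applied).

def hk_rotation(u):
    """R: (u1, u2, u3, u4) -> (u3, -u1, u4, u2)."""
    u1, u2, u3, u4 = u
    return (u3, -u1, u4, u2)

def hk_rotation_inverse(u):
    """R⁻¹: (u1, u2, u3, u4) -> (-u2, u4, u1, u3)."""
    u1, u2, u3, u4 = u
    return (-u2, u4, u1, u3)

def neck_to_bulk_r(x):
    """(t, θ, u1..u4, ψ)_neck -> (t, ψ, R⁻¹·u, θ)_bulk."""
    t, theta, u1, u2, u3, u4, psi = x
    return (t, psi, *hk_rotation_inverse((u1, u2, u3, u4)), theta)

def bulk_r_to_neck(x):
    """Inverse map: swap the circles back and apply R to the K3 block."""
    t, theta, u1, u2, u3, u4, psi = x
    return (t, psi, *hk_rotation((u1, u2, u3, u4)), theta)

x_neck = (0.65, 0.10, 0.20, 0.30, 0.40, 0.50, 0.90)
x_bulk = neck_to_bulk_r(x_neck)
print("bulk_R coords:", x_bulk)
print("round trip ok:", bulk_r_to_neck(x_bulk) == x_neck)
```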

In [4]:
def build_kovalev_twist_jacobian():
    """
    7×7 Jacobian of the Neck → Bulk_R transition map.

    Coordinate ordering: (t, θ, u₁, u₂, u₃, u₄, ψ)
    Indices:             (0, 1,  2,  3,  4,  5,  6)

    The transition map (neck chart → bulk_R chart):
      t_bulk = t_neck          → ∂t_bulk/∂t_neck = +1  (shared global t)
      θ_bulk = ψ_neck          → ∂θ_bulk/∂ψ_neck = +1  (circle swap)
      u_bulk = R⁻¹_HK · u_neck → 4×4 inverse HK rotation
      ψ_bulk = θ_neck          → ∂ψ_bulk/∂θ_neck = +1  (circle swap)

    Note: t is a shared global coordinate — no reversal needed.
    The Kovalev twist only affects transverse coordinates (θ, ψ, K3).
    """
    J = np.zeros((7, 7), dtype=np.float64)

    # t → t (shared global coordinate, no reversal)
    J[0, 0] = 1.0

    # θ_bulk = ψ_neck (circle swap: coord 1 ← coord 6)
    J[1, 6] = 1.0

    # K3 inverse HK rotation: R⁻¹ where R: (u₁,u₂,u₃,u₄) → (u₃,−u₁,u₄,u₂)
    # R⁻¹: (u₁,u₂,u₃,u₄)_neck → (−u₂,u₄,u₁,u₃)_bulk
    J[2, 3] = -1.0   # u₁_bulk = −u₂_neck
    J[3, 5] = 1.0    # u₂_bulk = u₄_neck
    J[4, 2] = 1.0    # u₃_bulk = u₁_neck
    J[5, 4] = 1.0    # u₄_bulk = u₃_neck

    # ψ_bulk = θ_neck (circle swap: coord 6 ← coord 1)
    J[6, 1] = 1.0

    return J


KOVALEV_JACOBIAN = build_kovalev_twist_jacobian()
KOVALEV_JACOBIAN_T = torch.from_numpy(KOVALEV_JACOBIAN).to(dtype)

# Verify it's orthogonal (det = ±1, J^T J = I)
_JTJ = KOVALEV_JACOBIAN.T @ KOVALEV_JACOBIAN
_det_J = np.linalg.det(KOVALEV_JACOBIAN)
print(f"\nKovalev twist Jacobian:")
print(f"  det(J) = {_det_J:.0f} (should be ±1)")
print(f"  ||J^T J - I|| = {np.linalg.norm(_JTJ - np.eye(7)):.2e} (should be ~0)")


def kovalev_twist_coords(x_neck):
    """
    Apply Kovalev transition map: Neck coordinates → Bulk_R coordinates.

    x_neck: (N, 7) tensor in Neck chart
    Returns: (N, 7) tensor in Bulk_R chart coordinates

    The transition map applies:
      t: unchanged (shared global coordinate)
      circles: θ_bulk = ψ_neck, ψ_bulk = θ_neck (swap)
      K3: u_bulk = R⁻¹ · u_neck (inverse hyper-Kähler rotation)
    """
    x_bulk = x_neck.clone()
    # t unchanged (both charts use global t)
    # x_bulk[:, 0] already correct from clone

    # Circle swap: θ_bulk ← ψ_neck, ψ_bulk ← θ_neck
    x_bulk[:, 1] = x_neck[:, 6]
    x_bulk[:, 6] = x_neck[:, 1]

    # K3 inverse HK rotation: (u₁,u₂,u₃,u₄)_neck → (−u₂,u₄,u₁,u₃)_bulk
    x_bulk[:, 2] = torch.fmod(-x_neck[:, 3] + 1.0, 1.0)  # u₁_bulk = −u₂_neck (mod 1)
    x_bulk[:, 3] = x_neck[:, 5]                            # u₂_bulk = u₄_neck
    x_bulk[:, 4] = x_neck[:, 2]                            # u₃_bulk = u₁_neck
    x_bulk[:, 5] = x_neck[:, 4]                            # u₄_bulk = u₃_neck

    return x_bulk


def transform_metric_kovalev(g_bulk_R, device_ref=None):
    """
    Pull back bulk_R metric to neck coordinates via the transition Jacobian.
    g_neck_ab = J^T_ai g_bulk_ij J_jb

    Since J is orthogonal (|det|=1), this preserves eigenvalues and det.

    g_bulk_R: (N, 7, 7) metric in bulk_R chart
    Returns: (N, 7, 7) in neck coordinates
    """
    dev = device_ref or g_bulk_R.device
    J = KOVALEV_JACOBIAN_T.to(dev)
    # g' = J^T @ g @ J (batched)
    return torch.einsum('ai,nij,jb->nab', J.T, g_bulk_R, J)


# Quick test
with torch.no_grad():
    _x_test = torch.rand(10, 7, device=device, dtype=dtype)
    _x_twisted = kovalev_twist_coords(_x_test)
    print(f"  Twist test: input range [{_x_test.min():.2f}, {_x_test.max():.2f}]")
    print(f"             output range [{_x_twisted.min():.2f}, {_x_twisted.max():.2f}]")
    del _x_test, _x_twisted


Kovalev twist Jacobian:
  det(J) = -1 (should be ±1)
  ||J^T J - I|| = 0.00e+00 (should be ~0)
  Twist test: input range [0.01, 0.98]
             output range [0.01, 0.98]
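
Because the Jacobian is a signed permutation matrix, the pull-back `Jᵀ g J` only reshuffles the entries of a diagonal metric. The following sketch shows this on a hand-picked diagonal `g_bulk` in pure Python. The Jacobian entries mirror `build_kovalev_twist_jacobian`, while the function names and the sample diagonal are assumptions for illustration.

```python
# Pull-back of a diagonal metric through the signed permutation Jacobian
# (entries copied from build_kovalev_twist_jacobian; sample metric is made up).

def kovalev_jacobian():
    J = [[0.0] * 7 for _ in range(7)]
    J[0][0] = 1.0    # t shared
    J[1][6] = 1.0    # θ_bulk = ψ_neck
    J[2][3] = -1.0   # u1_bulk = -u2_neck
    J[3][5] = 1.0    # u2_bulk = u4_neck
    J[4][2] = 1.0    # u3_bulk = u1_neck
    J[5][4] = 1.0    # u4_bulk = u3_neck
    J[6][1] = 1.0    # ψ_bulk = θ_neck
    return J

def pull_back(J, g):
    """(Jᵀ g J)_ab = sum over i, j of J[i][a] * g[i][j] * J[j][b]."""
    n = len(J)
    return [[sum(J[i][a] * g[i][j] * J[j][b]
                 for i in range(n) for j in range(n))
             for b in range(n)] for a in range(n)]

diag_bulk = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6]
g_bulk = [[diag_bulk[i] if i == j else 0.0 for j in range(7)] for i in range(7)]

g_neck = pull_back(kovalev_jacobian(), g_bulk)
diag_neck = [g_neck[a][a] for a in range(7)]
print("bulk diagonal:", diag_bulk)
print("neck diagonal:", diag_neck)
```

The diagonal entries are permuted but their product, and hence `det(g)`, is unchanged, which is the property the interface loss relies on.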


## 4. Neural Network Architectures

Three PINNs, one per chart:
- **NeckPINN**: G₂ adjoint parameterization (14 DOF → 35 3-form components)
- **BulkCYPINN_L**: Cholesky-parameterized for ACyl CY₃(Quintic) × S¹
- **BulkCYPINN_R**: Cholesky-parameterized for ACyl CY₃(CI(2,2,2)) × S¹
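
The Cholesky parameterization used for the bulk charts can be seen on a 3×3 toy before reading the 7×7 class. In this hedged sketch, six raw numbers fill a lower-triangular factor in row-major order, softplus keeps the diagonal positive, and the determinant of `g = L Lᵀ` equals the squared product of the diagonal. All names and the raw values are assumptions for this example.

```python
import math

def softplus(v):
    """Smooth positive map: log(1 + exp(v))."""
    return math.log1p(math.exp(v))

def factor_from_raw(raw, n=3):
    """Fill an n×n lower-triangular factor row by row; softplus on the diagonal."""
    L = [[0.0] * n for _ in range(n)]
    idx = 0
    for r in range(n):
        for c in range(r + 1):
            L[r][c] = softplus(raw[idx]) if r == c else raw[idx]
            idx += 1
    return L

def gram(L):
    n = len(L)
    return [[sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]

def det3(m):
    """Cofactor expansion along the first row."""
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))

raw = [-2.0, 0.5, 0.1, -0.3, 0.7, 1.0]   # arbitrary, including a negative diagonal input
L = factor_from_raw(raw)
g = gram(L)

det_direct = det3(g)
det_from_diag = (L[0][0] * L[1][1] * L[2][2]) ** 2
print(f"det(g) direct = {det_direct:.10f}")
print(f"(prod L_ii)^2 = {det_from_diag:.10f}")
```

Computing the determinant from the diagonal avoids a general determinant call and is always positive, even when the raw diagonal input is negative.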

In [5]:
class FourierFeatures(nn.Module):
    """Random Fourier features with periodic-aware encoding.

    For S¹ coordinates (θ, ψ = coords 1, 6), uses deterministic harmonics.
    For remaining coords, uses random Fourier features.
    """
    def __init__(self, input_dim=7, num_random_freq=24, num_periodic_freq=8, scale=1.0):
        super().__init__()
        # Random features for non-periodic coords (t, u₁..u₄)
        n_nonperiodic = input_dim - 2  # 5 non-periodic coords
        B_rand = torch.randn(num_random_freq, n_nonperiodic, dtype=dtype) * scale
        self.register_buffer('B_rand', B_rand)

        # Deterministic harmonics for S¹ coords (θ, ψ)
        k = torch.arange(1, num_periodic_freq + 1, dtype=dtype).unsqueeze(1)
        self.register_buffer('k_periodic', k)

        self.output_dim = 2 * num_random_freq + 4 * num_periodic_freq
        self.num_periodic_freq = num_periodic_freq

    def forward(self, x):
        """x: (N, 7) → features: (N, output_dim)."""
        # Non-periodic coords: t, u₁, u₂, u₃, u₄ (indices 0, 2, 3, 4, 5)
        x_nonper = torch.stack([x[:, 0], x[:, 2], x[:, 3], x[:, 4], x[:, 5]], dim=1)
        proj = 2 * math.pi * x_nonper @ self.B_rand.T
        feat_rand = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)

        # Periodic coords: θ = x[:,1], ψ = x[:,6]
        theta = x[:, 1:2] * 2 * math.pi  # scale to [0, 2π)
        psi = x[:, 6:7] * 2 * math.pi

        # Harmonics: sin(k·θ), cos(k·θ), sin(k·ψ), cos(k·ψ)
        k = self.k_periodic.T  # (1, num_freq)
        feat_periodic = torch.cat([
            torch.sin(theta * k), torch.cos(theta * k),
            torch.sin(psi * k), torch.cos(psi * k),
        ], dim=-1)

        return torch.cat([feat_rand, feat_periodic], dim=-1)


class NeckPINN(nn.Module):
    """
    G₂-native PINN for the TCS neck region.

    Architecture:
        7D coords → Fourier → MLP(256, 256, 256, 128) → 14 G₂ adjoint params
        → Lie derivatives → 35 3-form components

    Ansatz: φ(x) = φ₀ + perturbation_scale × δφ(x)
    """
    def __init__(self, num_random_freq=24, num_periodic_freq=8,
                 hidden_dims=None, perturbation_scale=0.2):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [256, 256, 256, 128]
        self.perturbation_scale = perturbation_scale

        self.register_buffer('phi0', torch.from_numpy(phi0_components(True)).to(dtype))
        self.register_buffer('lie_derivs', torch.from_numpy(LIE_DERIVATIVES).to(dtype))

        self.fourier = FourierFeatures(7, num_random_freq, num_periodic_freq)
        layers = []
        in_dim = self.fourier.output_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h, dtype=dtype), nn.SiLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, 14, dtype=dtype))
        self.mlp = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """x: (N,7) → φ: (N,35) 3-form components."""
        adj = self.mlp(self.fourier(x))
        delta_phi = adj @ self.lie_derivs
        return self.phi0.unsqueeze(0) + self.perturbation_scale * delta_phi

    def phi_tensor(self, x):
        """x: (N,7) → full antisymmetric 3-tensor (N,7,7,7)."""
        comp = self.forward(x)
        N = comp.shape[0]
        phi = torch.zeros(N, 7, 7, 7, device=x.device, dtype=x.dtype)
        idx = 0
        for i in range(7):
            for j in range(i + 1, 7):
                for k in range(j + 1, 7):
                    v = comp[:, idx]
                    phi[:, i, j, k] = v;  phi[:, j, k, i] = v;  phi[:, k, i, j] = v
                    phi[:, j, i, k] = -v; phi[:, i, k, j] = -v; phi[:, k, j, i] = -v
                    idx += 1
        return phi

    def metric(self, x):
        """g_ij = (1/6) Σ_{k,l} φ_ikl φ_jkl. Returns (N,7,7)."""
        phi = self.phi_tensor(x)
        return torch.einsum('nikl,njkl->nij', phi, phi) / 6.0

    def det_g(self, x):
        return torch.linalg.det(self.metric(x))

    def param_count(self):
        return sum(p.numel() for p in self.parameters())


class BulkCYPINN(nn.Module):
    """
    Cholesky-parameterized PINN for ACyl CY₃ × S¹ bulk regions.

    Metric: g = (L₀ + perturbation_scale × δL(x)) @ (...)ᵀ
    Guaranteed positive-definite by construction.

    Warm-start: L₀ encodes the asymptotic CY product metric.
    """
    def __init__(self, num_random_freq=24, num_periodic_freq=8,
                 hidden_dims=None, perturbation_scale=0.2,
                 base_diag=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [256, 256, 256, 128]
        self.perturbation_scale = perturbation_scale

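        # Base (pre-softplus) diagonal of the Cholesky factor: diag(1, 1, α, α, α, α, 1)
        # with α = (65/32)^(1/14). The diagonal later passes through softplus in
        # cholesky_factor, so det(g) is not exactly 65/32 at initialization; the
        # supervised and det losses pull it toward the target during training.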
        if base_diag is None:
            base_diag = [1.0, 1.0, SCALE_FACTOR, SCALE_FACTOR,
                         SCALE_FACTOR, SCALE_FACTOR, 1.0]
        L0 = torch.zeros(7, 7, dtype=dtype)
        for i in range(7):
            L0[i, i] = base_diag[i]

        tril_idx = torch.tril_indices(7, 7)
        self.register_buffer('tril_row', tril_idx[0])
        self.register_buffer('tril_col', tril_idx[1])
        self.register_buffer('L0_flat', L0[tril_idx[0], tril_idx[1]])

        self.fourier = FourierFeatures(7, num_random_freq, num_periodic_freq)
        layers = []
        in_dim = self.fourier.output_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h, dtype=dtype), nn.SiLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, 28, dtype=dtype))  # 28 lower-tri elements
        self.mlp = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.01)  # small init
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def cholesky_factor(self, x):
        """x: (N,7) → L: (N,7,7) lower triangular with positive diagonal."""
        N = x.shape[0]
        delta_L = self.mlp(self.fourier(x))  # (N, 28)
        # Base factor plus learned perturbation for all 28 lower-tri entries at once
        vals = self.L0_flat.unsqueeze(0) + self.perturbation_scale * delta_L
        # Softplus on the diagonal entries only, to enforce a positive diagonal
        is_diag = (self.tril_row == self.tril_col).unsqueeze(0)
        vals = torch.where(is_diag, F.softplus(vals), vals)
        L = torch.zeros(N, 7, 7, device=x.device, dtype=x.dtype)
        L[:, self.tril_row, self.tril_col] = vals.to(L.dtype)
        return L

    def metric(self, x):
        """g = L @ Lᵀ — guaranteed positive definite."""
        L = self.cholesky_factor(x)
        return torch.bmm(L, L.transpose(1, 2))

    def det_g(self, x):
        """det(g) = (∏ L_ii)²."""
        L = self.cholesky_factor(x)
        log_det = 2 * torch.sum(torch.log(torch.diagonal(L, dim1=1, dim2=2)), dim=1)
        return torch.exp(log_det)

    def param_count(self):
        return sum(p.numel() for p in self.parameters())


# Instantiate and count parameters
_neck = NeckPINN().to(device)
_bulk = BulkCYPINN().to(device)
print(f"\nArchitectures:")
print(f"  NeckPINN:   {_neck.param_count():,} params")
print(f"  BulkCYPINN: {_bulk.param_count():,} params")
print(f"  Total (3 charts): ~{(_neck.param_count() + 2 * _bulk.param_count()):,} params")
del _neck, _bulk


Architectures:
  NeckPINN:   187,022 params
  BulkCYPINN: 188,828 params
  Total (3 charts): ~564,678 params
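
The parameter counts above follow directly from the layer sizes. This short sketch reproduces them by hand, assuming the default `FourierFeatures` settings (24 random frequencies, 8 periodic harmonics) and the default hidden widths. `mlp_param_count` is a helper introduced here for illustration.

```python
def mlp_param_count(layer_dims):
    """Weights plus biases of a fully connected stack with the given widths."""
    return sum(d_in * d_out + d_out
               for d_in, d_out in zip(layer_dims, layer_dims[1:]))

# cos/sin of 24 random projections + sin/cos harmonics for θ and ψ
fourier_dim = 2 * 24 + 4 * 8
hidden = [256, 256, 256, 128]

neck_params = mlp_param_count([fourier_dim, *hidden, 14])   # 14 adjoint outputs
bulk_params = mlp_param_count([fourier_dim, *hidden, 28])   # 28 lower-tri outputs
total_params = neck_params + 2 * bulk_params

print(f"NeckPINN:   {neck_params:,}")
print(f"BulkCYPINN: {bulk_params:,}")
print(f"Total:      {total_params:,}")
```

Registered buffers such as `B_rand`, `phi0`, and `lie_derivs` are not trainable parameters, so they do not enter the count.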


## 5. Analytical Warm-Start Targets

Each chart has its own analytical metric target for supervised warm-start.
This breaks the trivial-solution trap (where constant φ₀ satisfies all losses at zero).

- **Neck**: K3 Kummer bumps + neck warping + torus coupling
- **Bulk_L**: ACyl CY product metric with exponential approach to cylinder
- **Bulk_R**: Same ansatz as Bulk_L, mirrored in t; the Kovalev twist enters through the interface transition map rather than through this target
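
The "Kummer bumps" are Gaussian bumps centred on the 16 half-lattice points of the unit 4-torus, with distances measured periodically. The pure-Python sketch below evaluates that conformal factor at a single K3 coordinate, using the neck values (amplitude 0.3, `a_sq = 0.005`). The function name `kummer_factor` is an assumption for this example.

```python
import itertools
import math

def kummer_factor(u, amplitude=0.3, a_sq=0.005):
    """1 + sum of Gaussian bumps at the 16 points {0, 1/2}^4, periodic in each coordinate."""
    factor = 1.0
    for bits in itertools.product((0, 1), repeat=4):
        r2 = 0.0
        for coord, bit in zip(u, bits):
            dx = coord - 0.5 * bit
            dx -= round(dx)          # wrap to the nearest periodic image
            r2 += dx * dx
        factor += amplitude * math.exp(-r2 / (2 * a_sq))
    return factor

at_fixed_point = kummer_factor((0.0, 0.0, 0.0, 0.0))
between_points = kummer_factor((0.25, 0.25, 0.25, 0.25))
print(f"at a fixed point:     {at_fixed_point:.6f}")
print(f"between fixed points: {between_points:.6f}")
```

With this width the bumps barely overlap, so the factor is close to `1 + amplitude` at a fixed point and close to 1 midway between them.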

In [6]:
def analytical_neck_metric(x):
    """
    Analytical TCS metric for the neck region [0,1]⁷.

    Encodes K₇ = (K3 × T²) × I structure:
    - x[:,0] = t (neck parameter)
    - x[:,1] = θ (fiber S¹)
    - x[:,2:6] = (u₁..u₄) K3 coords with Kummer bumps
    - x[:,6] = ψ (outer S¹)
    """
    N = x.shape[0]
    g = torch.zeros(N, 7, 7, device=x.device, dtype=x.dtype)

    t = 2 * x[:, 0] - 1  # map [0,1] → [-1,1]

    # dt² with neck warping (Gaussian bump at center = gluing region)
    g[:, 0, 0] = 1.0 + 0.15 * torch.exp(-t ** 2 / 0.18)

    # K3 Kummer metric (coords 2-5): identity + bumps at 16 orbifold fixed points
    a_sq = 0.005
    k3_factor = torch.ones(N, device=x.device, dtype=x.dtype)
    for b0 in range(2):
        for b1 in range(2):
            for b2 in range(2):
                for b3 in range(2):
                    fp = torch.tensor([b0 * 0.5, b1 * 0.5, b2 * 0.5, b3 * 0.5],
                                      device=x.device, dtype=x.dtype)
                    dx = x[:, 2:6] - fp.unsqueeze(0)
                    dx = dx - torch.round(dx)  # periodic BC
                    r2 = (dx ** 2).sum(dim=1)
                    k3_factor = k3_factor + 0.3 * torch.exp(-r2 / (2 * a_sq))

    for i in range(4):
        g[:, 2 + i, 2 + i] = k3_factor

    # S¹ factors (θ, ψ) with mild t-dependence
    g[:, 1, 1] = 1.0 + 0.05 * torch.cos(2 * math.pi * x[:, 0])
    g[:, 6, 6] = 1.0 + 0.05 * torch.sin(2 * math.pi * x[:, 0])

    # Off-diagonal: dt-dθ coupling near neck ends (TCS signature)
    coupling = 0.03 * torch.exp(-((t.abs() - 0.8) ** 2) / 0.1)
    g[:, 0, 1] = coupling;  g[:, 1, 0] = coupling

    # Normalize det(g) = 65/32
    det_raw = torch.linalg.det(g)
    alpha = (DET_G_TARGET / det_raw.abs().clamp(min=1e-10)) ** (1.0 / 14)
    g = g * (alpha ** 2).unsqueeze(1).unsqueeze(2)

    return g


def analytical_bulk_metric(x, decay_rate=MU_DECAY, side='left'):
    """
    Analytical ACyl CY₃ × S¹ product metric for a bulk region.

    In the cylindrical end, the metric approaches:
      g → dt² + dθ² + g_K3(u) + dψ²
    with exponential corrections O(e^{-μ|t-t_interface|}).

    side='left':  t close to DOMAIN_BULK_L, interface at upper boundary
    side='right': t close to DOMAIN_BULK_R, interface at lower boundary
    """
    N = x.shape[0]
    g = torch.zeros(N, 7, 7, device=x.device, dtype=x.dtype)

    t = x[:, 0]  # in [0, 1] domain

    # Distance from interface (where bulk meets neck)
    if side == 'left':
        dist_from_interface = (OVERLAP_LEFT[1] - t).clamp(min=0)
    else:
        dist_from_interface = (t - OVERLAP_RIGHT[0]).clamp(min=0)

    # Exponential approach to product metric as we move away from interface
    decay = torch.exp(-decay_rate * dist_from_interface * 10)  # scale by 10 for unit-interval

    # dt² (cylindrical radial)
    g[:, 0, 0] = 1.0 + 0.08 * decay

    # dθ² (fiber S¹)
    g[:, 1, 1] = 1.0 + 0.04 * decay * torch.cos(2 * math.pi * x[:, 1])

    # K3 fiber (coords 2-5): simplified Eguchi-Hanson-like structure
    # In the bulk, K3 has full Ricci-flat metric (approximated here)
    a_sq = 0.008
    k3_base = torch.ones(N, device=x.device, dtype=x.dtype)
    # 16 Kummer centers (same as neck for matching)
    for b0 in range(2):
        for b1 in range(2):
            for b2 in range(2):
                for b3 in range(2):
                    fp = torch.tensor([b0 * 0.5, b1 * 0.5, b2 * 0.5, b3 * 0.5],
                                      device=x.device, dtype=x.dtype)
                    dx = x[:, 2:6] - fp.unsqueeze(0)
                    dx = dx - torch.round(dx)
                    r2 = (dx ** 2).sum(dim=1)
                    k3_base = k3_base + 0.25 * torch.exp(-r2 / (2 * a_sq))

    for i in range(4):
        g[:, 2 + i, 2 + i] = k3_base

    # dψ² (outer S¹)
    g[:, 6, 6] = 1.0 + 0.03 * decay * torch.sin(2 * math.pi * x[:, 6])

    # Mild off-diagonal coupling (decays away from interface)
    od_coupling = 0.02 * decay
    g[:, 0, 1] = od_coupling;  g[:, 1, 0] = od_coupling

    # Normalize per-sample to det = 65/32
    det_raw = torch.linalg.det(g)
    alpha = (DET_G_TARGET / det_raw.abs().clamp(min=1e-10)) ** (1.0 / 14)
    g = g * (alpha ** 2).unsqueeze(1).unsqueeze(2)

    return g


# Verify analytical metrics
with torch.no_grad():
    _x = torch.rand(500, 7, device=device, dtype=dtype)
    for name, fn in [("Neck", analytical_neck_metric),
                     ("Bulk_L", lambda x: analytical_bulk_metric(x, side='left')),
                     ("Bulk_R", lambda x: analytical_bulk_metric(x, side='right'))]:
        _g = fn(_x)
        _d = torch.linalg.det(_g)
        _e = torch.linalg.eigvalsh(_g)
        print(f"  {name}: det={_d.mean():.6f}  min_eig={_e.min():.4f}  "
              f"max_eig={_e.max():.4f}  PD={(_e.min() > 0).item()}")
    del _x, _g, _d, _e

  Neck: det=2.031250  min_eig=0.9606  max_eig=1.2568  PD=True
  Bulk_L: det=2.031250  min_eig=0.9702  max_eig=1.1973  PD=True
  Bulk_R: det=2.031250  min_eig=0.9758  max_eig=1.1975  PD=True


## 6. Loss Functions (Extended for Atlas)

In [7]:
def det_loss(model, x, target=DET_G_TARGET):
    """L_det = ⟨(det(g) - 65/32)²⟩."""
    return torch.mean((model.det_g(x) - target) ** 2)


def torsion_loss_fast(model, x, n_components=14):
    """Torsion proxy: mean squared coordinate gradient (via autograd) of
    n_components evenly spaced 3-form components (default 14 of 35)."""
    x = x.clone().requires_grad_(True)
    phi = model(x)
    torsion = torch.tensor(0.0, device=x.device, dtype=x.dtype)
    indices = torch.linspace(0, 34, n_components, dtype=torch.long)
    for i in indices:
        grad = torch.autograd.grad(
            phi[:, i].sum(), x, create_graph=True, retain_graph=True
        )[0]
        torsion = torsion + torch.mean(grad ** 2)
    return torsion / n_components


def torsion_loss_full(model, x):
    """Full torsion proxy: mean squared coordinate gradient of all 35 3-form components."""
    x = x.clone().requires_grad_(True)
    phi = model(x)
    torsion = torch.tensor(0.0, device=x.device, dtype=x.dtype)
    for i in range(35):
        grad = torch.autograd.grad(
            phi[:, i].sum(), x, create_graph=True, retain_graph=True
        )[0]
        torsion = torsion + torch.mean(grad ** 2)
    return torsion / 35.0


def positive_definite_loss(model, x):
    """Penalize negative eigenvalues of g."""
    g = model.metric(x)
    eigs = torch.linalg.eigvalsh(g)
    return torch.mean(torch.relu(-eigs) ** 2)


def smoothness_loss(model, x, n_pairs=7):
    """Penalize large gradients of the first n_pairs diagonal metric components g_ii."""
    x = x.clone().requires_grad_(True)
    g = model.metric(x)
    loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
    for i in range(n_pairs):
        grad = torch.autograd.grad(
            g[:, i, i].sum(), x, create_graph=True, retain_graph=True
        )[0]
        loss = loss + torch.mean(grad ** 2)
    return loss / n_pairs


def ricci_proxy_loss(model, x, eps=0.01):
    """Proxy for Ricci-flatness: mean squared central second difference of the
    metric along each of the 7 coordinates."""
    g0 = model.metric(x)
    loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
    for d in range(7):
        x_p = x.clone(); x_p[:, d] += eps
        x_m = x.clone(); x_m[:, d] -= eps
        g_p = model.metric(x_p)
        g_m = model.metric(x_m)
        d2g = (g_p - 2 * g0 + g_m) / (eps ** 2)
        loss = loss + torch.mean(d2g ** 2)
    return loss / 7.0


def supervised_loss_neck(model, x):
    """MSE between neck PINN metric and analytical TCS target."""
    g_pinn = model.metric(x)
    with torch.no_grad():
        g_target = analytical_neck_metric(x)
    return torch.mean((g_pinn - g_target) ** 2)


def supervised_loss_bulk(model, x, side='left'):
    """MSE between bulk PINN metric and analytical CY target."""
    g_pinn = model.metric(x)
    with torch.no_grad():
        g_target = analytical_bulk_metric(x, side=side)
    return torch.mean((g_pinn - g_target) ** 2)


def interface_loss_cholesky(model_a, model_b, x_overlap, transform_fn=None):
    """
    C⁰ matching of the two chart metrics on their overlap (MSE on g).

    If transform_fn is provided, it maps x_overlap into model_b's chart, and
    model_b's metric is pulled back through the Kovalev Jacobian before the
    comparison. This handles the Kovalev twist at the right interface.
    """
    g_a = model_a.metric(x_overlap)
    if transform_fn is not None:
        x_transformed = transform_fn(x_overlap)
        g_b_raw = model_b.metric(x_transformed)
        g_b = transform_metric_kovalev(g_b_raw, device_ref=x_overlap.device)
    else:
        g_b = model_b.metric(x_overlap)
    return torch.mean((g_a - g_b) ** 2)


def asymptotic_decay_loss(model, x, side='left'):
    """
    Enforce exponential approach to product metric far from interface.

    The bulk metric should approach dt² + dθ² + g_K3 + dψ² exponentially
    as we move away from the interface region.
    """
    g = model.metric(x)
    t = x[:, 0]

    # Reference: analytical bulk target, which approaches the product metric away from the interface
    with torch.no_grad():
        g_product = analytical_bulk_metric(x, side=side)

    # Weight: stronger penalty far from interface
    if side == 'left':
        weight = torch.exp(-5 * (t - DOMAIN_BULK_L[0]))
    else:
        weight = torch.exp(-5 * (DOMAIN_BULK_R[1] - t))

    diff = (g - g_product) ** 2
    return torch.mean(weight.unsqueeze(1).unsqueeze(2) * diff)


def hermite_blend(t, t_lo, t_hi):
    """C¹ Hermite blend: 0 at t_lo, 1 at t_hi."""
    s = torch.clamp((t - t_lo) / (t_hi - t_lo), 0, 1)
    return 3 * s ** 2 - 2 * s ** 3


print("Loss functions defined.")

Loss functions defined.
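
The finite-difference stencil behind `ricci_proxy_loss` is the standard central second difference. As a hedged one-dimensional illustration, the sketch below applies it to a function with the same form as the neck warping term and compares against the closed-form second derivative. The names `second_difference`, `warp`, and `warp_second_derivative` are introduced here and are not part of the notebook.

```python
import math

def second_difference(f, x, eps=0.01):
    """Central stencil: (f(x + eps) - 2 f(x) + f(x - eps)) / eps²."""
    return (f(x + eps) - 2 * f(x) + f(x - eps)) / eps**2

def warp(t):
    """Same functional form as the neck warping of g_tt: 1 + 0.15 exp(-t² / 0.18)."""
    return 1.0 + 0.15 * math.exp(-t**2 / 0.18)

def warp_second_derivative(t, c=0.18):
    """Closed form: 0.15 * exp(-t²/c) * (4 t²/c² - 2/c)."""
    return 0.15 * math.exp(-t**2 / c) * (4 * t**2 / c**2 - 2 / c)

t0 = 0.3
fd_value = second_difference(warp, t0)
exact_value = warp_second_derivative(t0)
print(f"finite difference: {fd_value:.6f}")
print(f"closed form:       {exact_value:.6f}")
print(f"abs error:         {abs(fd_value - exact_value):.2e}")
```

The truncation error of this stencil scales as `eps²`, so `eps = 0.01` trades a small bias for robustness against round-off in double precision.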


## 7. Sampling Utilities

In [8]:
def sample_domain(n, domain, dim=7):
    """Sample points uniformly, restricting coord[0] to the given domain."""
    x = torch.rand(n, dim, device=device, dtype=dtype)
    x[:, 0] = x[:, 0] * (domain[1] - domain[0]) + domain[0]
    return x


def sample_overlap(n, overlap, dim=7):
    """Sample points in an overlap region."""
    return sample_domain(n, overlap, dim)
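
The sampling rule is an affine rescaling of the first coordinate, and the overlap intervals are the intersections of adjacent chart domains. The following pure-Python sketch mirrors both ideas without torch. `intersect` and `sample_points` are hypothetical stand-ins, and the seeded `random.Random` is an assumption made only so the example is reproducible.

```python
import random

DOMAIN_BULK_L = (0.00, 0.40)
DOMAIN_NECK = (0.25, 0.75)
DOMAIN_BULK_R = (0.60, 1.00)

def intersect(a, b):
    """Overlap of two closed intervals; raises if they do not overlap."""
    lo, hi = max(a[0], b[0]), min(a[1], b[1])
    if lo >= hi:
        raise ValueError(f"intervals {a} and {b} do not overlap")
    return (lo, hi)

def sample_points(n, domain, dim=7, rng=None):
    """Uniform points in the unit cube with coord 0 rescaled to the given domain."""
    rng = rng or random.Random(0)
    lo, hi = domain
    points = []
    for _ in range(n):
        p = [rng.random() for _ in range(dim)]
        p[0] = lo + (hi - lo) * p[0]
        points.append(p)
    return points

overlap_left = intersect(DOMAIN_BULK_L, DOMAIN_NECK)
overlap_right = intersect(DOMAIN_NECK, DOMAIN_BULK_R)
points = sample_points(1000, overlap_left)

print("Bulk_L ∩ Neck:", overlap_left)
print("Neck ∩ Bulk_R:", overlap_right)
print("t range of samples:", min(p[0] for p in points), max(p[0] for p in points))
```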

## 8. Phase 1 — Independent Chart Training

Train each chart's PINN independently:
- Neck: supervised warm-start → torsion-free physics
- Bulk_L, Bulk_R: supervised warm-start → Ricci-flat physics

No interface matching yet (that comes in Phase 2).
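
Each chart's training budget is split across the curriculum phases by fixed fractions. A small sketch of that arithmetic for the default `n_epochs=3000`, using the fractions that appear in `train_chart`, is shown below. The dictionary and function names are illustrative assumptions.

```python
# Phase fractions as used in train_chart (neck vs. bulk charts).
NECK_FRACTIONS = {"warm-start": 0.20, "blend": 0.30, "physics": 0.50}
BULK_FRACTIONS = {"warm-start": 0.25, "blend": 0.35, "physics": 0.40}

def epochs_per_phase(fractions, n_epochs):
    """Truncate each fraction of the budget to an integer epoch count."""
    return {name: int(frac * n_epochs) for name, frac in fractions.items()}

neck_schedule = epochs_per_phase(NECK_FRACTIONS, 3000)
bulk_schedule = epochs_per_phase(BULK_FRACTIONS, 3000)

for label, schedule in (("neck", neck_schedule), ("bulk", bulk_schedule)):
    print(label, schedule, "total:", sum(schedule.values()))
```

Because `int()` truncates, budgets that are not multiples of the fractions can sum to slightly fewer than `n_epochs`; for 3000 the split is exact.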

In [9]:
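def train_chart(model, chart_name, analytical_target_fn, domain,
                is_neck=False, n_epochs=3000, batch_size=1024, verbose=True):
    """
    Train a single chart PINN with a 3-phase curriculum:
    Phase A: Supervised warm-start from analytical target
    Phase B: Blended (supervised + physics)
    Phase C: Physics-dominated (det + torsion/Ricci + smoothness)

    analytical_target_fn is a supervised loss callable with signature
    (model, x) -> scalar tensor, e.g. supervised_loss_neck.
    domain is the (t_min, t_max) range sampled for coord[0].
    """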
    model.to(device)
    model.train()

    if is_neck:
        phases = [
            {'name': 'warm-start', 'epochs': int(0.2 * n_epochs), 'lr': 1e-3,
             'w_sup': 100.0, 'w_det': 0, 'w_phys': 0, 'w_smooth': 0},
            {'name': 'blend', 'epochs': int(0.3 * n_epochs), 'lr': 5e-4,
             'w_sup': 20.0, 'w_det': 50, 'w_phys': 0.5, 'w_smooth': 0.05},
            {'name': 'physics', 'epochs': int(0.5 * n_epochs), 'lr': 2e-4,
             'w_sup': 1.0, 'w_det': 100, 'w_phys': 1.0, 'w_smooth': 0.1},
        ]
    else:
        phases = [
            {'name': 'warm-start', 'epochs': int(0.25 * n_epochs), 'lr': 1e-3,
             'w_sup': 100.0, 'w_det': 0, 'w_phys': 0, 'w_smooth': 0},
            {'name': 'blend', 'epochs': int(0.35 * n_epochs), 'lr': 5e-4,
             'w_sup': 15.0, 'w_det': 50, 'w_phys': 0.5, 'w_smooth': 0.05},
            {'name': 'physics', 'epochs': int(0.40 * n_epochs), 'lr': 1e-4,
             'w_sup': 0.5, 'w_det': 100, 'w_phys': 1.0, 'w_smooth': 0.1},
        ]

    history = {'loss': [], 'det_err': [], 'phys': [], 'sup': [], 'phase': []}
    best_loss = float('inf')
    best_state = None
    total_ep = 0
    t_start = time.time()

    for phase in phases:
        pname = phase['name']
        n_ep = phase['epochs']
        if verbose:
            print(f"\n  [{chart_name}] Phase: {pname} | {n_ep} epochs | lr={phase['lr']}")

        optimizer = torch.optim.AdamW(model.parameters(), lr=phase['lr'], weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_ep)

        for ep in range(n_ep):
            x = sample_domain(batch_size, domain)
            optimizer.zero_grad()

            # Supervised loss
            L_sup = analytical_target_fn(model, x) if phase['w_sup'] > 0 else torch.tensor(0.0, device=device)

            # Physics losses
            L_det = det_loss(model, x) if phase['w_det'] > 0 else torch.tensor(0.0, device=device)
            if phase['w_phys'] > 0 and is_neck:
                L_phys = torsion_loss_fast(model, x)
            elif phase['w_phys'] > 0:
                L_phys = ricci_proxy_loss(model, x)
            else:
                L_phys = torch.tensor(0.0, device=device)
            L_smooth = smoothness_loss(model, x) if phase['w_smooth'] > 0 else torch.tensor(0.0, device=device)

            loss = (phase['w_sup'] * L_sup +
                    phase['w_det'] * L_det +
                    phase['w_phys'] * L_phys +
                    phase['w_smooth'] * L_smooth)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            history['loss'].append(loss.item())
            history['det_err'].append(L_det.item())
            history['phys'].append(L_phys.item())
            history['sup'].append(L_sup.item())
            history['phase'].append(pname)

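            # Track the best weights by total loss. The loss weights change between
            # phases, so this comparison is only like-for-like within a phase.
            if loss.item() < best_loss:
                best_loss = loss.item()
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}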

            if verbose and (ep % 500 == 0 or ep == n_ep - 1):
                elapsed = time.time() - t_start
                print(f"    [{total_ep:5d}] loss={loss.item():.2e}  "
                      f"det={L_det.item():.2e}  phys={L_phys.item():.2e}  "
                      f"sup={L_sup.item():.2e}  ({elapsed:.0f}s)")
            total_ep += 1

    elapsed = time.time() - t_start
    if verbose:
        print(f"  [{chart_name}] Done: {total_ep} epochs in {elapsed:.1f}s, best={best_loss:.2e}")

    if best_state is not None:
        model.load_state_dict(best_state)
        model.to(device)

    return history

In [10]:
# === Create the 3 chart PINNs ===
print("\n" + "=" * 70)
print("  PHASE 1: Independent Chart Training")
print("=" * 70)

neck_pinn = NeckPINN(perturbation_scale=0.2).to(device)
bulk_L_pinn = BulkCYPINN(perturbation_scale=0.2).to(device)
bulk_R_pinn = BulkCYPINN(perturbation_scale=0.2).to(device)

print(f"Neck PINN: {neck_pinn.param_count():,} params")
print(f"Bulk L PINN: {bulk_L_pinn.param_count():,} params")
print(f"Bulk R PINN: {bulk_R_pinn.param_count():,} params")
total_params = neck_pinn.param_count() + bulk_L_pinn.param_count() + bulk_R_pinn.param_count()
print(f"Total: {total_params:,} params")

# Optional: load P2 checkpoint for neck warm-start
P2_CHECKPOINT = 'outputs/neck_pinn_seed42.pt'
if os.path.exists(P2_CHECKPOINT):
    try:
        ckpt = torch.load(P2_CHECKPOINT, map_location=device, weights_only=True)
        neck_pinn.load_state_dict(ckpt, strict=False)
        print(f"\n  Loaded P2 checkpoint: {P2_CHECKPOINT}")
        NECK_PRETRAINED = True
    except Exception as e:
        print(f"  Could not load P2 checkpoint: {e}")
        NECK_PRETRAINED = False
else:
    print(f"\n  No P2 checkpoint found at {P2_CHECKPOINT}, training from scratch")
    NECK_PRETRAINED = False


  PHASE 1: Independent Chart Training
Neck PINN: 187,022 params
Bulk L PINN: 188,828 params
Bulk R PINN: 188,828 params
Total: 564,678 params

  No P2 checkpoint found at outputs/neck_pinn_seed42.pt, training from scratch


In [11]:
# Train Neck
NECK_EPOCHS = 1500 if NECK_PRETRAINED else 3000
neck_history = train_chart(
    neck_pinn, "Neck", supervised_loss_neck, DOMAIN_NECK,
    is_neck=True, n_epochs=NECK_EPOCHS
)


  [Neck] Phase: warm-start | 600 epochs | lr=0.001
    [    0] loss=3.34e-02  det=0.00e+00  phys=0.00e+00  sup=3.34e-04  (5s)
    [  500] loss=3.20e-02  det=0.00e+00  phys=0.00e+00  sup=3.20e-04  (20s)
    [  599] loss=3.10e-02  det=0.00e+00  phys=0.00e+00  sup=3.10e-04  (24s)

  [Neck] Phase: blend | 900 epochs | lr=0.0005
    [  600] loss=6.61e-01  det=1.31e-02  phys=3.05e-03  sup=3.09e-04  (24s)
    [ 1100] loss=6.70e-03  det=1.51e-07  phys=8.42e-05  sup=3.33e-04  (282s)
    [ 1499] loss=6.46e-03  det=1.23e-07  phys=7.27e-05  sup=3.21e-04  (491s)

  [Neck] Phase: physics | 1500 epochs | lr=0.0002
    [ 1500] loss=4.01e-04  det=1.22e-07  phys=7.21e-05  sup=3.17e-04  (491s)
    [ 2000] loss=3.25e-04  det=2.65e-10  phys=7.60e-07  sup=3.24e-04  (746s)
    [ 2500] loss=3.29e-04  det=1.46e-10  phys=3.77e-07  sup=3.29e-04  (1003s)
    [ 2999] loss=3.30e-04  det=1.23e-10  phys=3.45e-07  sup=3.29e-04  (1259s)
  [Neck] Done: 3000 epochs in 1258.8s, best=3.10e-04


In [12]:
# Train Bulk Left
bulk_L_history = train_chart(
    bulk_L_pinn, "Bulk_L",
    lambda model, x: supervised_loss_bulk(model, x, side='left'),
    DOMAIN_BULK_L, is_neck=False, n_epochs=2500
)


  [Bulk_L] Phase: warm-start | 625 epochs | lr=0.001
    [    0] loss=6.57e+00  det=0.00e+00  phys=0.00e+00  sup=6.57e-02  (0s)
    [  500] loss=3.20e-03  det=0.00e+00  phys=0.00e+00  sup=3.20e-05  (7s)
    [  624] loss=3.38e-03  det=0.00e+00  phys=0.00e+00  sup=3.38e-05  (9s)

  [Bulk_L] Phase: blend | 875 epochs | lr=0.0005
    [  625] loss=1.64e-03  det=2.25e-05  phys=8.93e-05  sup=3.15e-05  (9s)
    [ 1125] loss=3.04e-02  det=5.91e-04  phys=4.61e-04  sup=4.00e-05  (141s)
    [ 1499] loss=5.05e-04  det=3.64e-09  phys=7.56e-07  sup=3.36e-05  (240s)

  [Bulk_L] Phase: physics | 1000 epochs | lr=0.0001
    [ 1500] loss=2.11e-05  det=3.49e-09  phys=7.43e-07  sup=4.00e-05  (240s)
    [ 2000] loss=1.42e-03  det=1.37e-05  phys=2.73e-05  sup=3.21e-05  (374s)
    [ 2499] loss=1.61e-05  det=2.90e-09  phys=5.83e-07  sup=3.05e-05  (508s)
  [Bulk_L] Done: 2500 epochs in 507.5s, best=1.55e-05


In [13]:
# Train Bulk Right
bulk_R_history = train_chart(
    bulk_R_pinn, "Bulk_R",
    lambda model, x: supervised_loss_bulk(model, x, side='right'),
    DOMAIN_BULK_R, is_neck=False, n_epochs=2500
)


  [Bulk_R] Phase: warm-start | 625 epochs | lr=0.001
    [    0] loss=6.57e+00  det=0.00e+00  phys=0.00e+00  sup=6.57e-02  (0s)
    [  500] loss=3.10e-03  det=0.00e+00  phys=0.00e+00  sup=3.10e-05  (7s)
    [  624] loss=3.30e-03  det=0.00e+00  phys=0.00e+00  sup=3.30e-05  (9s)

  [Bulk_R] Phase: blend | 875 epochs | lr=0.0005
    [  625] loss=1.39e-03  det=1.40e-05  phys=4.87e-05  sup=4.46e-05  (9s)
    [ 1125] loss=3.30e-02  det=6.37e-04  phys=1.16e-03  sup=3.77e-05  (143s)
    [ 1499] loss=5.85e-04  det=8.68e-09  phys=1.46e-06  sup=3.89e-05  (242s)

  [Bulk_R] Phase: physics | 1000 epochs | lr=0.0001
    [ 1500] loss=1.98e-05  det=7.45e-09  phys=1.44e-06  sup=3.52e-05  (242s)
    [ 2000] loss=1.84e-03  det=1.78e-05  phys=4.39e-05  sup=3.48e-05  (374s)
    [ 2499] loss=1.69e-05  det=7.30e-09  phys=1.27e-06  sup=2.98e-05  (506s)
  [Bulk_R] Done: 2500 epochs in 506.3s, best=1.64e-05


## 9. Phase 1 Validation (Per-Chart)

In [14]:
def validate_chart(model, domain, label, n_samples=20000, chunk_size=5000):
    """Validate a single chart PINN."""
    model.eval()
    all_det = []
    all_min_eig = []
    all_cond = []

    with torch.no_grad():
        for start in range(0, n_samples, chunk_size):
            end = min(start + chunk_size, n_samples)
            x = sample_domain(end - start, domain)
            g = model.metric(x)
            det_g = torch.linalg.det(g)
            eigs = torch.linalg.eigvalsh(g)
            all_det.append(det_g)
            all_min_eig.append(eigs.min(dim=1).values)
            all_cond.append(eigs.max(dim=1).values / eigs.min(dim=1).values.clamp(min=1e-10))

    det_vals = torch.cat(all_det)
    min_eigs = torch.cat(all_min_eig)
    cond_vals = torch.cat(all_cond)

    det_err = abs(det_vals.mean().item() - DET_G_TARGET) / DET_G_TARGET * 100

    # Torsion (for neck only)
    if hasattr(model, 'phi_tensor'):
        model.train()
        x_tor = sample_domain(2000, domain)
        tor = torsion_loss_fast(model, x_tor).item()
        model.eval()
    else:
        tor = float('nan')

    result = {
        'det_mean': float(det_vals.mean().item()),
        'det_std': float(det_vals.std().item()),
        'det_err_pct': float(det_err),
        'min_eig': float(min_eigs.min().item()),
        'condition': float(cond_vals.mean().item()),
        'pos_def': bool((min_eigs > 0).all().item()),
        'torsion': float(tor),
    }

    print(f"\n  {label}:")
    print(f"    det(g) = {result['det_mean']:.8f} (err={det_err:.4f}%)")
    print(f"    min_eig = {result['min_eig']:.6f}, cond = {result['condition']:.4f}")
    print(f"    PD = {result['pos_def']}, torsion = {tor:.2e}")

    return result


print("\n" + "=" * 70)
print("  Phase 1 Validation")
print("=" * 70)

val_neck = validate_chart(neck_pinn, DOMAIN_NECK, "Neck")
val_bulk_L = validate_chart(bulk_L_pinn, DOMAIN_BULK_L, "Bulk_L")
val_bulk_R = validate_chart(bulk_R_pinn, DOMAIN_BULK_R, "Bulk_R")


  Phase 1 Validation

  Neck:
    det(g) = 2.03125649 (err=0.0003%)
    min_eig = 1.106538, cond = 1.0000
    PD = True, torsion = 3.52e-07

  Bulk_L:
    det(g) = 2.03125072 (err=0.0000%)
    min_eig = 1.093401, cond = 1.0258
    PD = True, torsion = nan

  Bulk_R:
    det(g) = 2.03124743 (err=0.0001%)
    min_eig = 1.093640, cond = 1.0261
    PD = True, torsion = nan
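
The `err` percentages in this output are the relative deviation of the mean determinant from the target. A minimal sketch that recomputes them from the rounded means in the log, assuming `DET_G_TARGET` equals the 65/32 stated in the notebook header (`det_err_pct` is a hypothetical helper mirroring the expression in `validate_chart`):

```python
# Sketch: recompute the det(g) error percentages from the logged means.
DET_G_TARGET = 65 / 32  # assumed value of the notebook constant (= 2.03125)


def det_err_pct(det_mean, target=DET_G_TARGET):
    """Relative error of the mean determinant, in percent."""
    return abs(det_mean - target) / target * 100


# Rounded means copied from the Phase 1 validation output.
phase1_means = {"Neck": 2.03125649, "Bulk_L": 2.03125072, "Bulk_R": 2.03124743}

for chart, mean in phase1_means.items():
    print(f"{chart}: err={det_err_pct(mean):.4f}%")
```

With four decimals the Bulk_L error rounds to zero, which is why the log shows `0.0000%` even though the mean is not exactly 65/32.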


## 10. Phase 2 — Schwarz Alternating Method

Iteratively match charts at their overlapping interfaces.

**Left interface** (Bulk_L ∩ Neck): direct metric matching (no twist).
**Right interface** (Neck ∩ Bulk_R): matching through Kovalev twist.

Protocol:
1. Freeze Neck → optimize Bulk_L, Bulk_R with interface matching
2. Freeze Bulk_L, Bulk_R → optimize Neck with interface matching
3. Log the interface errors → repeat until both fall below 1e-4 or `n_schwarz` iterations have run
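
The freeze-and-alternate pattern can be illustrated on a toy problem with two scalar "charts". Everything in this sketch is an assumption made for illustration: the quadratic objectives, the closed-form updates (used in place of Adam steps), and all names. It is not the notebook's PINN loss.

```python
# Toy sketch of the alternating (Schwarz-style) loop on two scalar "charts".

W_MATCH = 100.0   # matching weight, mirroring the 100 * L_match term
TOL = 1e-4        # early-stopping threshold on the interface error


def minimize_with_frozen_neighbor(own_target, neighbor_value, w=W_MATCH):
    """Closed-form argmin over a of (a - own_target)^2 + w * (a - neighbor_value)^2."""
    return (own_target + w * neighbor_value) / (1.0 + w)


def alternate(bulk, neck, bulk_target, neck_target, n_schwarz=8):
    history = []
    for _ in range(n_schwarz):
        # Step A: neck frozen, update the bulk parameter.
        bulk = minimize_with_frozen_neighbor(bulk_target, neck)
        # Step B: bulk frozen, update the neck parameter.
        neck = minimize_with_frozen_neighbor(neck_target, bulk)
        # Step C: log the interface error and stop early once it is small.
        mismatch = (bulk - neck) ** 2
        history.append(mismatch)
        if mismatch < TOL:
            break
    return bulk, neck, history


bulk, neck, history = alternate(bulk=1.0, neck=1.2, bulk_target=1.0, neck_target=1.2)
print(f"iterations={len(history)}  final mismatch={history[-1]:.2e}")
```

In this toy, the large matching weight pulls the two values together in a single sweep, so the early-stopping test passes on the first iteration. The interface error alone says nothing about how close each value is to its own target, which is why the per-chart physics losses are logged separately.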

In [15]:
def schwarz_iteration(neck_model, bulk_L_model, bulk_R_model,
                      n_schwarz=8, epochs_per_step=300,
                      batch_size=1024, verbose=True):
    """
    Schwarz alternating method for 3-chart atlas.

    Returns: history dict with interface matching errors per iteration.
    """
    history = {
        'iteration': [], 'match_left': [], 'match_right': [],
        'det_err_neck': [], 'det_err_left': [], 'det_err_right': [],
    }
    t_start = time.time()

    if verbose:
        print("\n" + "=" * 70)
        print(f"  PHASE 2: Schwarz Alternating ({n_schwarz} iterations)")
        print("=" * 70)

    for schwarz_iter in range(n_schwarz):
        if verbose:
            print(f"\n  --- Schwarz iteration {schwarz_iter + 1}/{n_schwarz} ---")

        # === Step A: Freeze Neck, optimize Bulks with matching ===
        for p in neck_model.parameters():
            p.requires_grad_(False)

        for name, bulk_model, domain, overlap, side in [
            ("Bulk_L", bulk_L_model, DOMAIN_BULK_L, OVERLAP_LEFT, 'left'),
            ("Bulk_R", bulk_R_model, DOMAIN_BULK_R, OVERLAP_RIGHT, 'right'),
        ]:
            optimizer = torch.optim.Adam(bulk_model.parameters(), lr=1e-4)
            for ep in range(epochs_per_step):
                x_bulk = sample_domain(batch_size, domain)
                x_ov = sample_overlap(batch_size // 4, overlap)

                optimizer.zero_grad()
                L_det = det_loss(bulk_model, x_bulk)
                L_ricci = ricci_proxy_loss(bulk_model, x_bulk)
                L_decay = asymptotic_decay_loss(bulk_model, x_bulk, side=side)

                # Interface matching
                if side == 'left':
                    L_match = interface_loss_cholesky(
                        neck_model, bulk_model, x_ov, transform_fn=None)
                else:
                    # Right interface: match through Kovalev twist
                    L_match = interface_loss_cholesky(
                        neck_model, bulk_model, x_ov, transform_fn=kovalev_twist_coords)

                loss = 50 * L_det + 0.5 * L_ricci + 5 * L_decay + 100 * L_match
                loss.backward()
                torch.nn.utils.clip_grad_norm_(bulk_model.parameters(), 1.0)
                optimizer.step()

            if verbose:
                print(f"    {name}: det={L_det.item():.2e}  match={L_match.item():.2e}")

        for p in neck_model.parameters():
            p.requires_grad_(True)

        # === Step B: Freeze Bulks, optimize Neck with matching ===
        for p in bulk_L_model.parameters():
            p.requires_grad_(False)
        for p in bulk_R_model.parameters():
            p.requires_grad_(False)

        optimizer = torch.optim.Adam(neck_model.parameters(), lr=5e-5)
        for ep in range(epochs_per_step):
            x_neck = sample_domain(batch_size, DOMAIN_NECK)
            x_ov_L = sample_overlap(batch_size // 4, OVERLAP_LEFT)
            x_ov_R = sample_overlap(batch_size // 4, OVERLAP_RIGHT)

            optimizer.zero_grad()
            L_det = det_loss(neck_model, x_neck)
            L_tor = torsion_loss_fast(neck_model, x_neck)

            # Left interface: direct
            L_match_L = interface_loss_cholesky(
                neck_model, bulk_L_model, x_ov_L, transform_fn=None)
            # Right interface: Kovalev twist
            L_match_R = interface_loss_cholesky(
                neck_model, bulk_R_model, x_ov_R, transform_fn=kovalev_twist_coords)

            loss = 50 * L_det + 1.0 * L_tor + 100 * (L_match_L + L_match_R)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(neck_model.parameters(), 1.0)
            optimizer.step()

        for p in bulk_L_model.parameters():
            p.requires_grad_(True)
        for p in bulk_R_model.parameters():
            p.requires_grad_(True)

        # === Log interface errors ===
        with torch.no_grad():
            x_ov_L = sample_overlap(2000, OVERLAP_LEFT)
            x_ov_R = sample_overlap(2000, OVERLAP_RIGHT)

            ml = interface_loss_cholesky(
                neck_model, bulk_L_model, x_ov_L).item()
            mr = interface_loss_cholesky(
                neck_model, bulk_R_model, x_ov_R,
                transform_fn=kovalev_twist_coords).item()

            x_n = sample_domain(2000, DOMAIN_NECK)
            x_l = sample_domain(2000, DOMAIN_BULK_L)
            x_r = sample_domain(2000, DOMAIN_BULK_R)
            de_n = abs(neck_model.det_g(x_n).mean().item() - DET_G_TARGET)
            de_l = abs(bulk_L_model.det_g(x_l).mean().item() - DET_G_TARGET)
            de_r = abs(bulk_R_model.det_g(x_r).mean().item() - DET_G_TARGET)

        history['iteration'].append(schwarz_iter)
        history['match_left'].append(float(ml))
        history['match_right'].append(float(mr))
        history['det_err_neck'].append(float(de_n))
        history['det_err_left'].append(float(de_l))
        history['det_err_right'].append(float(de_r))

        if verbose:
            elapsed = time.time() - t_start
            print(f"    Interface: left={ml:.6f}  right={mr:.6f}  ({elapsed:.0f}s)")

        # Early stopping
        if ml < 1e-4 and mr < 1e-4:
            if verbose:
                print("    Interfaces converged!")
            break

    elapsed = time.time() - t_start
    if verbose:
        print(f"\n  Schwarz complete in {elapsed:.1f}s "
              f"({schwarz_iter + 1} iterations)")

    return history

In [16]:
schwarz_history = schwarz_iteration(
    neck_pinn, bulk_L_pinn, bulk_R_pinn,
    n_schwarz=8, epochs_per_step=300
)


  PHASE 2: Schwarz Alternating (8 iterations)

  --- Schwarz iteration 1/8 ---
    Bulk_L: det=2.08e-05  match=3.49e-08
    Bulk_R: det=1.33e-04  match=3.58e-07
    Interface: left=0.000000  right=0.000000  (147s)
    Interfaces converged!

  Schwarz complete in 147.0s (1 iterations)


## 11. Phase 3 — Joint Fine-Tuning

All three PINNs are trained jointly on a single global loss:
det (×50) + torsion (×1.0) + Ricci (×0.3) + interface matching (×80)
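
As a rough illustration of how the terms are combined, the weighted sum in `joint_fine_tune` can be written as a plain function. The weights are the ones used in the notebook code (which has no asymptotic-decay term); `global_loss`, `JOINT_WEIGHTS`, and the example inputs are hypothetical.

```python
# Sketch: the weighted sum from joint_fine_tune as a standalone function.
# Weights are taken from the notebook; the inputs below are made-up values.

JOINT_WEIGHTS = {"det": 50.0, "torsion": 1.0, "ricci": 0.3, "match": 80.0}


def global_loss(det_terms, torsion, ricci_terms, match_terms, w=JOINT_WEIGHTS):
    """Combine per-chart loss terms into one scalar."""
    return (w["det"] * sum(det_terms)          # neck + both bulks
            + w["torsion"] * torsion           # neck only
            + w["ricci"] * sum(ricci_terms)    # both bulks
            + w["match"] * sum(match_terms))   # left + right interface


example = global_loss(
    det_terms=(1e-10, 2e-10, 3e-10),
    torsion=1e-7,
    ricci_terms=(1e-6, 1e-6),
    match_terms=(1e-8, 1e-7),
)
print(f"example global loss: {example:.3e}")
```

With these weights an interface mismatch of 1e-7 contributes as much to the total as a torsion loss of 8e-6, so small matching errors are penalized comparatively heavily.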

In [17]:
def joint_fine_tune(neck_model, bulk_L_model, bulk_R_model,
                    n_epochs=2000, batch_size=1024, verbose=True):
    """
    Joint fine-tuning of all three chart PINNs.

    All parameters unfrozen, trained together with global loss.
    """
    all_params = (list(neck_model.parameters()) +
                  list(bulk_L_model.parameters()) +
                  list(bulk_R_model.parameters()))
    optimizer = torch.optim.AdamW(all_params, lr=1e-5, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)

    history = {'loss': [], 'det': [], 'torsion': [], 'match_L': [], 'match_R': []}
    best_loss = float('inf')
    best_states = None
    t_start = time.time()

    if verbose:
        print("\n" + "=" * 70)
        print(f"  PHASE 3: Joint Fine-Tuning ({n_epochs} epochs)")
        print("=" * 70)

    for ep in range(n_epochs):
        optimizer.zero_grad()

        # Sample all domains
        x_neck = sample_domain(batch_size, DOMAIN_NECK)
        x_bulk_L = sample_domain(batch_size // 2, DOMAIN_BULK_L)
        x_bulk_R = sample_domain(batch_size // 2, DOMAIN_BULK_R)
        x_ov_L = sample_overlap(batch_size // 4, OVERLAP_LEFT)
        x_ov_R = sample_overlap(batch_size // 4, OVERLAP_RIGHT)

        # Neck losses
        L_det_N = det_loss(neck_model, x_neck)
        L_tor = torsion_loss_fast(neck_model, x_neck)

        # Bulk losses
        L_det_L = det_loss(bulk_L_model, x_bulk_L)
        L_det_R = det_loss(bulk_R_model, x_bulk_R)
        L_ricci_L = ricci_proxy_loss(bulk_L_model, x_bulk_L)
        L_ricci_R = ricci_proxy_loss(bulk_R_model, x_bulk_R)

        # Interface matching
        L_match_L = interface_loss_cholesky(
            neck_model, bulk_L_model, x_ov_L)
        L_match_R = interface_loss_cholesky(
            neck_model, bulk_R_model, x_ov_R,
            transform_fn=kovalev_twist_coords)

        # Global loss
        loss = (50 * (L_det_N + L_det_L + L_det_R) +
                1.0 * L_tor +
                0.3 * (L_ricci_L + L_ricci_R) +
                80 * (L_match_L + L_match_R))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(all_params, 1.0)
        optimizer.step()
        scheduler.step()

        history['loss'].append(loss.item())
        history['det'].append((L_det_N + L_det_L + L_det_R).item())
        history['torsion'].append(L_tor.item())
        history['match_L'].append(L_match_L.item())
        history['match_R'].append(L_match_R.item())

        if loss.item() < best_loss:
            best_loss = loss.item()
            best_states = {
                'neck': {k: v.cpu().clone() for k, v in neck_model.state_dict().items()},
                'bulk_L': {k: v.cpu().clone() for k, v in bulk_L_model.state_dict().items()},
                'bulk_R': {k: v.cpu().clone() for k, v in bulk_R_model.state_dict().items()},
            }

        if verbose and (ep % 250 == 0 or ep == n_epochs - 1):
            elapsed = time.time() - t_start
            print(f"  [{ep:5d}] loss={loss.item():.2e}  det={L_det_N.item():.2e}  "
                  f"tor={L_tor.item():.2e}  match_L={L_match_L.item():.2e}  "
                  f"match_R={L_match_R.item():.2e}  ({elapsed:.0f}s)")

    # Restore best
    if best_states is not None:
        neck_model.load_state_dict(best_states['neck'])
        bulk_L_model.load_state_dict(best_states['bulk_L'])
        bulk_R_model.load_state_dict(best_states['bulk_R'])
        neck_model.to(device)
        bulk_L_model.to(device)
        bulk_R_model.to(device)

    elapsed = time.time() - t_start
    if verbose:
        print(f"\n  Joint fine-tuning complete in {elapsed:.1f}s, best={best_loss:.2e}")

    return history

In [18]:
joint_history = joint_fine_tune(neck_pinn, bulk_L_pinn, bulk_R_pinn, n_epochs=2000)


  PHASE 3: Joint Fine-Tuning (2000 epochs)
  [    0] loss=2.85e-03  det=5.75e-11  tor=1.90e-07  match_L=3.04e-08  match_R=2.77e-07  (0s)
  [  250] loss=4.01e-06  det=5.31e-11  tor=1.73e-07  match_L=6.89e-11  match_R=5.41e-09  (113s)
  [  500] loss=9.55e-06  det=4.38e-11  tor=1.57e-07  match_L=1.09e-10  match_R=5.41e-11  (226s)
  [  750] loss=2.35e-06  det=5.27e-11  tor=1.48e-07  match_L=3.97e-12  match_R=2.02e-11  (339s)
  [ 1000] loss=1.74e-06  det=4.59e-11  tor=1.41e-07  match_L=7.45e-12  match_R=1.06e-11  (452s)
  [ 1250] loss=1.34e-06  det=4.86e-11  tor=1.41e-07  match_L=3.09e-12  match_R=8.86e-12  (565s)
  [ 1500] loss=1.12e-06  det=4.91e-11  tor=1.37e-07  match_L=2.75e-12  match_R=5.64e-12  (678s)
  [ 1750] loss=1.01e-06  det=4.22e-11  tor=1.32e-07  match_L=1.95e-12  match_R=5.34e-12  (792s)
  [ 1999] loss=1.05e-06  det=4.88e-11  tor=1.33e-07  match_L=2.14e-12  match_R=5.07e-12  (905s)

  Joint fine-tuning complete in 904.6s, best=9.22e-07


## 12. Global Validation

Evaluate the assembled atlas across all three charts.

In [19]:
def global_validation(neck_model, bulk_L_model, bulk_R_model,
                      n_per_chart=20000, verbose=True):
    """
    Comprehensive validation of the 3-chart atlas.

    Checks:
    - Per-chart det(g), positive-definiteness, condition number
    - Interface matching error (both sides)
    - Torsion in neck region
    - Eigenvalue statistics
    """
    results = {}

    if verbose:
        print("\n" + "=" * 70)
        print("  GLOBAL VALIDATION")
        print("=" * 70)

    # Per-chart validation
    for name, model, domain in [
        ("Neck", neck_model, DOMAIN_NECK),
        ("Bulk_L", bulk_L_model, DOMAIN_BULK_L),
        ("Bulk_R", bulk_R_model, DOMAIN_BULK_R),
    ]:
        results[name] = validate_chart(model, domain, name, n_samples=n_per_chart)

    # Interface errors
    with torch.no_grad():
        x_ov_L = sample_overlap(5000, OVERLAP_LEFT)
        x_ov_R = sample_overlap(5000, OVERLAP_RIGHT)

        ml = interface_loss_cholesky(neck_model, bulk_L_model, x_ov_L).item()
        mr = interface_loss_cholesky(
            neck_model, bulk_R_model, x_ov_R,
            transform_fn=kovalev_twist_coords).item()

    results['interface_left'] = float(ml)
    results['interface_right'] = float(mr)

    if verbose:
        print(f"\n  Interface matching:")
        print(f"    Left  (Bulk_L ∩ Neck): {ml:.6f}")
        print(f"    Right (Neck ∩ Bulk_R): {mr:.6f} (through Kovalev twist)")

    # Torsion in neck (the critical region)
    neck_model.train()
    with torch.no_grad():
        x_tor = sample_domain(5000, DOMAIN_NECK)
    tor = torsion_loss_full(neck_model, x_tor).item()
    neck_model.eval()
    safety = TORSION_THRESHOLD / max(tor, 1e-10)

    results['global_torsion'] = float(tor)
    results['joyce_margin'] = float(safety)

    if verbose:
        print(f"\n  Global torsion (neck): {tor:.2e}")
        print(f"  Joyce margin: {safety:,.0f}×")
        print(f"  Joyce threshold: {TORSION_THRESHOLD}")

    # Overall assessment
    all_pd = all(results[c]['pos_def'] for c in ['Neck', 'Bulk_L', 'Bulk_R'])
    max_det_err = max(results[c]['det_err_pct'] for c in ['Neck', 'Bulk_L', 'Bulk_R'])
    interface_ok = ml < 1e-2 and mr < 1e-2

    results['all_positive_definite'] = all_pd
    results['max_det_error_pct'] = float(max_det_err)
    results['interfaces_converged'] = interface_ok
    results['level_A_pass'] = all_pd and max_det_err < 0.1 and interface_ok and safety > 10

    if verbose:
        print(f"\n  LEVEL A ASSESSMENT:")
        print(f"    All PD: {all_pd}")
        print(f"    Max det error: {max_det_err:.4f}%")
        print(f"    Interfaces converged: {interface_ok}")
        print(f"    Joyce margin > 10: {safety > 10}")
        print(f"    >>> LEVEL A: {'PASS' if results['level_A_pass'] else 'FAIL'} <<<")

    return results

In [20]:
global_results = global_validation(neck_pinn, bulk_L_pinn, bulk_R_pinn)


  GLOBAL VALIDATION

  Neck:
    det(g) = 2.03125400 (err=0.0002%)
    min_eig = 1.106538, cond = 1.0000
    PD = True, torsion = 1.35e-07

  Bulk_L:
    det(g) = 2.03124997 (err=0.0000%)
    min_eig = 1.106510, cond = 1.0000
    PD = True, torsion = nan

  Bulk_R:
    det(g) = 2.03124845 (err=0.0001%)
    min_eig = 1.106508, cond = 1.0000
    PD = True, torsion = nan

  Interface matching:
    Left  (Bulk_L ∩ Neck): 0.000000
    Right (Neck ∩ Bulk_R): 0.000000 (through Kovalev twist)

  Global torsion (neck): 6.88e-06
  Joyce margin: 4,188×
  Joyce threshold: 0.0288

  LEVEL A ASSESSMENT:
    All PD: True
    Max det error: 0.0002%
    Interfaces converged: True
    Joyce margin > 10: True
    >>> LEVEL A: PASS <<<
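
The Level A decision above is a conjunction of four threshold checks. The sketch below restates it outside the notebook; `joyce_margin` and `level_a_pass` are hypothetical helpers, the thresholds are those in `global_validation`, and the interface errors passed in are illustrative placeholders (the log prints them only as `0.000000`).

```python
# Sketch: the Level A criteria as standalone functions.
TORSION_THRESHOLD = 0.0288  # Joyce threshold quoted in the notebook


def joyce_margin(torsion, threshold=TORSION_THRESHOLD):
    """Ratio of the Joyce threshold to the measured torsion (floored at 1e-10)."""
    return threshold / max(torsion, 1e-10)


def level_a_pass(all_pd, max_det_err_pct, match_left, match_right, torsion):
    interfaces_ok = match_left < 1e-2 and match_right < 1e-2
    return bool(all_pd
                and max_det_err_pct < 0.1
                and interfaces_ok
                and joyce_margin(torsion) > 10)


# Torsion and det error are the rounded values from the log;
# the interface errors (1e-6) are placeholders.
margin = joyce_margin(6.88e-06)
passed = level_a_pass(True, 0.0002, 1e-6, 1e-6, 6.88e-06)
print(f"margin ~ {margin:,.0f}x, Level A pass: {passed}")
```

Using the rounded torsion 6.88e-06 gives a margin of about 4,186, slightly below the 4,188× in the log, which was computed from the unrounded value.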


## 13. Phase 4 — Spectral Bridge

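Compute Laplacian eigenvalues on the assembled K₇ metric.
A Monte Carlo Galerkin method estimates the mass matrix S and stiffness matrix K over a trigonometric basis in the chart coordinates, then solves the generalized eigenproblem K v = λ S v.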

**Key test**: λ₁ × H* = 14  (GIFT master equation)
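
To make the Galerkin quantities concrete, here is a one-function sketch of the ratio K₁₁ / S₁₁ (a Rayleigh quotient) for a single mode, estimated by Monte Carlo. It rests on a strong assumption that is not the notebook's learned metric: a constant isotropic metric g = c·I with det(g) = 65/32, so c = (65/32)^(1/7). All names are hypothetical.

```python
import math
import random

# Sketch under an assumed constant metric g = c * I with det(g) = 65/32.
C = (65 / 32) ** (1 / 7)
SQRT_DET = math.sqrt(65 / 32)


def rayleigh_quotient_mc(k=1, n_mc=50_000, seed=0):
    """MC estimate of K_11 / S_11 for psi(t) = cos(k*pi*t), t uniform on [0, 1]."""
    rng = random.Random(seed)
    s_sum = 0.0
    k_sum = 0.0
    for _ in range(n_mc):
        t = rng.random()
        psi = math.cos(k * math.pi * t)
        dpsi = -k * math.pi * math.sin(k * math.pi * t)
        s_sum += psi * psi * SQRT_DET           # mass integrand: psi^2 sqrt(det g)
        k_sum += (dpsi * dpsi / C) * SQRT_DET   # stiffness integrand: g^{tt} (dpsi)^2 sqrt(det g)
    return k_sum / s_sum


estimate = rayleigh_quotient_mc()
exact = math.pi ** 2 / C
print(f"MC estimate: {estimate:.3f}   exact (k*pi)^2 / c: {exact:.3f}")
```

Under this assumption the exact value for k = 1 is π²/c ≈ 8.92, and the Monte Carlo estimate fluctuates around it at the percent level for this sample size. The full computation differs in that it uses the learned metric and solves the coupled problem over all basis functions.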

In [21]:
def compute_spectral_bridge(neck_model, bulk_L_model, bulk_R_model,
                            n_basis=60, n_mc=30000, verbose=True):
    """
    Monte Carlo Galerkin computation of Laplacian eigenvalues.

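    Uses a separable trigonometric basis in chart coordinates:
    - t (x[:, 0]): cos(k·π·t), k = 1..5
    - circle coordinates θ (x[:, 1]) and ψ (x[:, 6]): cos and sin Fourier modes
    - K3 coordinates u₁..u₄ (x[:, 2:6]): single-coordinate cos modes
    No exponentially decaying bulk modes are included.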

    Returns: eigenvalues and λ₁ × H* comparison.
    """
    if verbose:
        print("\n" + "=" * 70)
        print("  PHASE 4: Spectral Bridge")
        print("=" * 70)

    def global_metric_at(x):
        """Evaluate the assembled global metric at point x."""
        t = x[:, 0]
        g = torch.zeros(x.shape[0], 7, 7, device=x.device, dtype=x.dtype)
        N = x.shape[0]

        # Determine chart membership
        in_bulk_L = t < OVERLAP_LEFT[0]
        in_bulk_R = t > OVERLAP_RIGHT[1]
        in_overlap_L = (t >= OVERLAP_LEFT[0]) & (t <= OVERLAP_LEFT[1])
        in_overlap_R = (t >= OVERLAP_RIGHT[0]) & (t <= OVERLAP_RIGHT[1])
        in_neck_only = (t > OVERLAP_LEFT[1]) & (t < OVERLAP_RIGHT[0])

        # Pure regions
        if in_bulk_L.any():
            g[in_bulk_L] = bulk_L_model.metric(x[in_bulk_L])
        if in_bulk_R.any():
            g[in_bulk_R] = bulk_R_model.metric(x[in_bulk_R])
        if in_neck_only.any():
            g[in_neck_only] = neck_model.metric(x[in_neck_only])

        # Left overlap: blend Bulk_L and Neck
        if in_overlap_L.any():
            x_ov = x[in_overlap_L]
            alpha = hermite_blend(x_ov[:, 0], OVERLAP_LEFT[0], OVERLAP_LEFT[1])
            g_bulk = bulk_L_model.metric(x_ov)
            g_neck = neck_model.metric(x_ov)
            a = alpha.unsqueeze(1).unsqueeze(2)
            g[in_overlap_L] = (1 - a) * g_bulk + a * g_neck

        # Right overlap: blend Neck and Bulk_R (with Kovalev twist)
        if in_overlap_R.any():
            x_ov = x[in_overlap_R]
            alpha = hermite_blend(x_ov[:, 0], OVERLAP_RIGHT[0], OVERLAP_RIGHT[1])
            g_neck = neck_model.metric(x_ov)
            x_twisted = kovalev_twist_coords(x_ov)
            g_bulk_raw = bulk_R_model.metric(x_twisted)
            g_bulk = transform_metric_kovalev(g_bulk_raw, device_ref=x_ov.device)
            a = alpha.unsqueeze(1).unsqueeze(2)
            g[in_overlap_R] = (1 - a) * g_neck + a * g_bulk

        return g

    # MC sample points uniformly across full [0,1]^7
    x_mc = torch.rand(n_mc, 7, device=device, dtype=dtype)

    if verbose:
        print(f"  Computing global metric at {n_mc} MC points...")

    with torch.no_grad():
        g_mc = global_metric_at(x_mc)
        det_g_mc = torch.linalg.det(g_mc)
        g_inv_mc = torch.linalg.inv(g_mc)

    if verbose:
        print(f"  Global det(g): mean={det_g_mc.mean():.6f}, std={det_g_mc.std():.6f}")

    # Build basis functions: low-frequency single-coordinate trigonometric modes,
    # grouped by the TCS coordinate split (t), (θ, ψ), (u₁..u₄).
    # No product modes or exponential weights are used; see basis_fn below.
    if verbose:
        print(f"  Building {n_basis} TCS-adapted basis functions...")

    # Galerkin matrices: S_ij = <ψ_i, ψ_j>_g  and  K_ij = <∇ψ_i, ∇ψ_j>_g
    S = torch.zeros(n_basis, n_basis, device=device, dtype=dtype)
    K = torch.zeros(n_basis, n_basis, device=device, dtype=dtype)

    # Basis functions: cos/sin modes on different coordinate combinations
    # TCS-adapted: separate modes for (t), (θ,ψ), (u₁..u₄)
    n_t = 5        # radial modes
    n_circle = 10  # S¹ modes (θ, ψ)
    n_k3 = n_basis - n_t - n_circle  # K3 modes

    def basis_fn(x, idx):
        """Evaluate basis function idx at points x. Returns (N,) values."""
        if idx < n_t:
            # Radial modes: cos(k·π·t); zero t-derivative at t = 0 and t = 1
            k = idx + 1
            return torch.cos(k * math.pi * x[:, 0])
        elif idx < n_t + n_circle:
            # Circle modes: cos/sin on θ or ψ
            j = idx - n_t
            if j < 5:
                return torch.cos((j + 1) * 2 * math.pi * x[:, 1])
            else:
                return torch.sin((j - 4) * 2 * math.pi * x[:, 6])
        else:
            # K3 modes: single-coordinate cos modes, cycling through u₁..u₄
            j = idx - n_t - n_circle
            coord = j % 4
            freq = j // 4 + 1
            return torch.cos(freq * 2 * math.pi * x[:, 2 + coord])

    def basis_grad(x, idx):
        """Gradient of basis function idx. Returns (N, 7)."""
        x_req = x.clone().requires_grad_(True)
        psi = basis_fn(x_req, idx)
        grad = torch.autograd.grad(psi.sum(), x_req, create_graph=False)[0]
        return grad

    # Assemble Galerkin matrices
    if verbose:
        print(f"  Assembling Galerkin matrices ({n_basis}×{n_basis})...")
    t_start = time.time()

    # Pre-compute basis values and gradients
    basis_vals = []
    basis_grads = []
    for i in range(n_basis):
        basis_vals.append(basis_fn(x_mc, i))
        basis_grads.append(basis_grad(x_mc, i))
    basis_vals = torch.stack(basis_vals, dim=1)   # (n_mc, n_basis)
    basis_grads = torch.stack(basis_grads, dim=1) # (n_mc, n_basis, 7)

    # sqrt(det(g)) for volume element
    sqrt_det = torch.sqrt(det_g_mc.abs().clamp(min=1e-10))

    # S_ij = (1/n_mc) Σ_k ψ_i(x_k) ψ_j(x_k) sqrt(det(g_k))
    weighted_basis = basis_vals * sqrt_det.unsqueeze(1)  # (n_mc, n_basis)
    S = (weighted_basis.T @ basis_vals) / n_mc

    # K_ij = (1/n_mc) Σ_k g^{ab}(x_k) ∂_a ψ_i(x_k) ∂_b ψ_j(x_k) sqrt(det(g_k))
    # = (1/n_mc) Σ_k [∇ψ_i · g⁻¹ · ∇ψ_j] sqrt(det)
    for i in range(n_basis):
        # g^{-1} @ grad_i: (n_mc, 7)
        ginv_gradi = torch.einsum('nab,nb->na', g_inv_mc, basis_grads[:, i, :])
        for j in range(i, n_basis):
            # <∇ψ_i, ∇ψ_j>_g = grad_i · g^{-1} · grad_j
            integrand = (ginv_gradi * basis_grads[:, j, :]).sum(dim=1) * sqrt_det
            val = integrand.mean()
            K[i, j] = val
            K[j, i] = val

    elapsed = time.time() - t_start
    if verbose:
        print(f"  Galerkin matrices assembled in {elapsed:.1f}s")
        print(f"  S condition: {torch.linalg.cond(S).item():.1f}")

    # Solve generalized eigenvalue problem: K v = λ S v
    # Regularize S slightly for numerical stability
    S_reg = S + 1e-8 * torch.eye(n_basis, device=device, dtype=dtype)

    try:
        L_chol = torch.linalg.cholesky(S_reg)
        # Transform to standard EVP: L^{-1} K L^{-T} w = λ w
        K_transformed = torch.linalg.solve_triangular(
            L_chol, torch.linalg.solve_triangular(L_chol, K, upper=False).T, upper=False).T
        eigenvalues = torch.linalg.eigvalsh(K_transformed)
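    except Exception as e:
        if verbose:
            print(f"  Cholesky failed ({e}), using eigvals on S^-1 K")
        # S^-1 K is not symmetric in general, so eigvalsh does not apply;
        # use the general solver and keep the real parts.
        S_inv = torch.linalg.inv(S_reg)
        eigenvalues = torch.linalg.eigvals(S_inv @ K).real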

    # Filter positive eigenvalues and sort
    pos_eigs = eigenvalues[eigenvalues > 1e-6].sort().values

    if verbose:
        print(f"\n  Eigenvalue spectrum (first 20):")
        for i, ev in enumerate(pos_eigs[:20]):
            lxh = ev.item() * H_STAR
            marker = " ← λ₁" if i == 0 else ""
            print(f"    λ_{i + 1} = {ev.item():.6f}  (λ×H* = {lxh:.2f}){marker}")

        if len(pos_eigs) > 0:
            lambda_1 = pos_eigs[0].item()
            lxh = lambda_1 * H_STAR
            print(f"\n  === SPECTRAL BRIDGE RESULT ===")
            print(f"  λ₁ = {lambda_1:.6f}")
            print(f"  λ₁ × H* = {lxh:.4f}")
            print(f"  GIFT prediction: {LAMBDA_1_PREDICTED:.6f} × {H_STAR} = {DIM_G2}")
            print(f"  Deviation: {abs(lxh - DIM_G2) / DIM_G2 * 100:.2f}%")

    spectral_results = {
        'eigenvalues': [float(e) for e in pos_eigs[:50].tolist()],
        'lambda_1': float(pos_eigs[0].item()) if len(pos_eigs) > 0 else None,
        'lambda_1_x_Hstar': float(pos_eigs[0].item() * H_STAR) if len(pos_eigs) > 0 else None,
        'n_basis': n_basis,
        'n_mc': n_mc,
        'S_condition': float(torch.linalg.cond(S).item()),
    }

    return spectral_results
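
The Galerkin procedure in `compute_spectral_bridge` can be sanity-checked on a manifold whose spectrum is known. The sketch below is an illustration rather than part of the notebook's pipeline. It applies the same steps (Monte Carlo mass matrix `S`, stiffness matrix `K`, Cholesky reduction of `K v = λ S v`) to the flat unit circle, where the nonzero Laplacian eigenvalues are (2πk)² with multiplicity 2. It uses NumPy instead of PyTorch, writes the stiffness assembly as a single `einsum` contraction equivalent to the double loop, and the name `mc_galerkin_spectrum` and all of its parameters are assumptions made for this example.

```python
import numpy as np

def mc_galerkin_spectrum(n_modes=2, n_mc=20000, reg=1e-8, seed=0):
    """Monte Carlo Galerkin Laplacian spectrum on the flat unit circle (toy check)."""
    rng = np.random.default_rng(seed)
    x = rng.random(n_mc)                  # uniform samples on [0, 1)
    sqrt_det = np.ones(n_mc)              # flat metric: sqrt(det g) = 1
    g_inv = np.ones((n_mc, 1, 1))         # g^{ab} stored as (n_mc, d, d) with d = 1

    # Basis: 1, cos(2πkx), sin(2πkx) for k = 1..n_modes, with analytic derivatives
    vals, grads = [np.ones(n_mc)], [np.zeros(n_mc)]
    for k in range(1, n_modes + 1):
        w = 2.0 * np.pi * k
        vals += [np.cos(w * x), np.sin(w * x)]
        grads += [-w * np.sin(w * x), w * np.cos(w * x)]
    basis_vals = np.stack(vals, axis=1)                 # (n_mc, n_basis)
    basis_grads = np.stack(grads, axis=1)[:, :, None]   # (n_mc, n_basis, d)

    # Mass matrix S_ij = mean_k psi_i psi_j sqrt(det g)
    S = (basis_vals * sqrt_det[:, None]).T @ basis_vals / n_mc
    # Stiffness matrix K_ij = mean_k g^{ab} d_a psi_i d_b psi_j sqrt(det g),
    # written as one contraction over sample and coordinate indices
    K = np.einsum('nia,nab,njb,n->ij',
                  basis_grads, g_inv, basis_grads, sqrt_det) / n_mc

    # Reduce K v = lambda S v to a standard symmetric problem via S = L L^T
    L = np.linalg.cholesky(S + reg * np.eye(S.shape[0]))
    A = np.linalg.solve(L, K)         # L^{-1} K
    A = np.linalg.solve(L, A.T).T     # L^{-1} K L^{-T}
    eigs = np.linalg.eigvalsh(0.5 * (A + A.T))
    return np.sort(eigs[eigs > 1e-6])  # drop the constant (zero) mode

spectrum = mc_galerkin_spectrum()
print("computed:", spectrum)
print("expected near:", (2 * np.pi) ** 2, "and", (4 * np.pi) ** 2)
```

Because the quadrature is random, the computed values agree with (2πk)² only up to Monte Carlo error, which shrinks roughly like 1/√n_mc. The same kind of check on a known geometry can help separate sampling noise from basis truncation error before interpreting λ₁ on K₇.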

In [22]:
spectral_results = compute_spectral_bridge(
    neck_pinn, bulk_L_pinn, bulk_R_pinn,
    n_basis=60, n_mc=30000
)


  PHASE 4: Spectral Bridge
  Computing global metric at 30000 MC points...
  Global det(g): mean=2.031251, std=0.000053
  Building 60 TCS-adapted basis functions...
  Assembling Galerkin matrices (60×60)...
  Galerkin matrices assembled in 0.2s
  S condition: 1.2

  Eigenvalue spectrum (first 20):
    λ_1 = 9.072689  (λ×H* = 898.20) ← λ₁
    λ_2 = 34.418045  (λ×H* = 3407.39)
    λ_3 = 34.969051  (λ×H* = 3461.94)
    λ_4 = 34.996797  (λ×H* = 3464.68)
    λ_5 = 35.350577  (λ×H* = 3499.71)
    λ_6 = 35.584982  (λ×H* = 3522.91)
    λ_7 = 35.875500  (λ×H* = 3551.67)
    λ_8 = 36.052364  (λ×H* = 3569.18)
    λ_9 = 80.713593  (λ×H* = 7990.65)
    λ_10 = 139.570237  (λ×H* = 13817.45)
    λ_11 = 142.029795  (λ×H* = 14060.95)
    λ_12 = 142.611103  (λ×H* = 14118.50)
    λ_13 = 142.942996  (λ×H* = 14151.36)
    λ_14 = 143.681905  (λ×H* = 14224.51)
    λ_15 = 144.759877  (λ×H* = 14331.23)
    λ_16 = 145.824855  (λ×H* = 14436.66)
    λ_17 = 222.547518  (λ×H* = 22032.20)
    λ_18 = 313.249203  (λ×H
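
The deviation that the function reports for λ₁ can be reproduced by hand from the printed values. The snippet below is a minimal sketch of that arithmetic. `H_STAR = 99` is an assumption inferred from the printed ratio 898.20 / 9.072689, `DIM_G2 = 14` is taken from the notebook's stated prediction, and the helper names are hypothetical.

```python
# Values copied from the printed spectrum above.
# H_STAR = 99 is inferred from 898.20 / 9.072689 (an assumption);
# DIM_G2 = 14 is the notebook's stated target for lambda_1 * H*.
H_STAR = 99
DIM_G2 = 14
lambda_1 = 9.072689

def deviation_pct(lam, h_star=H_STAR, target=DIM_G2):
    """Relative deviation of lam * h_star from the target, in percent."""
    return abs(lam * h_star - target) / target * 100.0

def lambda_required(h_star=H_STAR, target=DIM_G2):
    """Value of lambda_1 that would satisfy lambda_1 * h_star = target."""
    return target / h_star

print(f"lambda_1 x H* = {lambda_1 * H_STAR:.2f}")
print(f"deviation     = {deviation_pct(lambda_1):.2f}%")
print(f"lambda_1 needed for the target: {lambda_required():.6f}")
```

The comparison depends on how the coordinates are normalized: the Laplacian eigenvalues scale as the inverse square of any overall length rescaling of the metric, so the units of λ₁ need to be fixed before reading the deviation as a test of the prediction.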

## 14. Visualization

In [25]:
try:
    import os
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('GIFT Atlas G₂ Metric — 3-Chart TCS Construction', fontsize=14)

    # Panel 1: Schwarz convergence
    ax = axes[0, 0]
    ax.semilogy(schwarz_history['iteration'], schwarz_history['match_left'], 'b-o', label='Left')
    ax.semilogy(schwarz_history['iteration'], schwarz_history['match_right'], 'r-s', label='Right')
    ax.axhline(1e-3, color='g', ls='--', alpha=0.5, label='Target')
    ax.set_xlabel('Schwarz Iteration')
    ax.set_ylabel('Interface Error')
    ax.set_title('Schwarz Convergence')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Panel 2: Joint fine-tuning loss
    ax = axes[0, 1]
    ax.semilogy(joint_history['loss'], alpha=0.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Total Loss')
    ax.set_title('Joint Fine-Tuning')
    ax.grid(True, alpha=0.3)

    # Panel 3: Torsion during joint training
    ax = axes[0, 2]
    ax.semilogy(joint_history['torsion'], alpha=0.5, color='red')
    ax.axhline(TORSION_THRESHOLD, color='k', ls='--', label=f'Joyce ε₀={TORSION_THRESHOLD}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Torsion')
    ax.set_title('Torsion Convergence')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Panel 4: det(g) across full manifold
    ax = axes[1, 0]
    with torch.no_grad():
        t_range = torch.linspace(0, 1, 200, device=device, dtype=dtype)
        det_profile = []
        for t_val in t_range:
            x_sample = torch.rand(100, 7, device=device, dtype=dtype)
            x_sample[:, 0] = t_val
            if t_val < OVERLAP_LEFT[0]:
                d = bulk_L_pinn.det_g(x_sample).mean().item()
            elif t_val > OVERLAP_RIGHT[1]:
                d = bulk_R_pinn.det_g(x_sample).mean().item()
            else:
                d = neck_pinn.det_g(x_sample).mean().item()
            det_profile.append(d)
    ax.plot(t_range.cpu().numpy(), det_profile, 'b-')
    ax.axhline(DET_G_TARGET, color='r', ls='--', label=f'65/32 = {DET_G_TARGET}')
    ax.axvspan(*OVERLAP_LEFT, alpha=0.1, color='green', label='Overlap L')
    ax.axvspan(*OVERLAP_RIGHT, alpha=0.1, color='orange', label='Overlap R')
    ax.set_xlabel('t (radial coordinate)')
    ax.set_ylabel('det(g)')
    ax.set_title('det(g) Profile Across K₇')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # Panel 5: Eigenvalue spectrum
    ax = axes[1, 1]
    if spectral_results['eigenvalues']:
        eigs = spectral_results['eigenvalues'][:30]
        eigs_x_H = [e * H_STAR for e in eigs]
        ax.bar(range(1, len(eigs_x_H) + 1), eigs_x_H, alpha=0.7)
        ax.axhline(DIM_G2, color='r', ls='--', label=f'dim(G₂) = {DIM_G2}')
        ax.set_xlabel('Eigenvalue index n')
        ax.set_ylabel('λₙ × H*')
        ax.set_title('Spectral Bridge: λₙ × H*')
        ax.legend()
    ax.grid(True, alpha=0.3)

    # Panel 6: Summary
    ax = axes[1, 2]
    ax.axis('off')
    summary_text = (
        f"GIFT Atlas G₂ Metric Results\n"
        f"{'=' * 35}\n\n"
        f"Topology: K₇ = TCS(Quintic, CI(2,2,2))\n"
        f"b₂ = {B2}, b₃ = {B3}, H* = {H_STAR}\n\n"
        f"Neck:   det err = {global_results['Neck']['det_err_pct']:.4f}%\n"
        f"Bulk L: det err = {global_results['Bulk_L']['det_err_pct']:.4f}%\n"
        f"Bulk R: det err = {global_results['Bulk_R']['det_err_pct']:.4f}%\n\n"
        f"Interface L: {global_results['interface_left']:.6f}\n"
        f"Interface R: {global_results['interface_right']:.6f}\n\n"
        f"Torsion: {global_results['global_torsion']:.2e}\n"
        f"Joyce margin: {global_results['joyce_margin']:,.0f}×\n\n"
    )
    if spectral_results['lambda_1'] is not None:
        summary_text += (
            f"λ₁ = {spectral_results['lambda_1']:.6f}\n"
            f"λ₁ × H* = {spectral_results['lambda_1_x_Hstar']:.4f}\n"
            f"GIFT: dim(G₂) = {DIM_G2}\n"
            f"Level A: {'PASS' if global_results['level_A_pass'] else 'FAIL'}\n"
        )
    ax.text(0.05, 0.95, summary_text, transform=ax.transAxes,
            fontsize=10, verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.tight_layout()
    os.makedirs('outputs', exist_ok=True)
    plt.savefig('outputs/gift_atlas_results.png', dpi=150, bbox_inches='tight')
    print("Plot saved: outputs/gift_atlas_results.png")
    plt.close()

except ImportError:
    print("matplotlib not available, skipping visualization")

Plot saved: outputs/gift_atlas_results.png
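
Panel 4 picks one chart per value of t rather than blending charts inside the overlaps. The selection rule can be isolated as a small function, sketched below. The overlap intervals here are hypothetical placeholders, not the notebook's actual `OVERLAP_LEFT` and `OVERLAP_RIGHT` values, and `chart_for_t` is a name introduced for this example.

```python
def chart_for_t(t, overlap_left, overlap_right):
    """Return the chart name used for the det(g) profile at coordinate t.

    Mirrors the branch logic of Panel 4: Bulk_L strictly left of the left
    overlap, Bulk_R strictly right of the right overlap, Neck otherwise
    (which includes both overlap intervals).
    """
    if t < overlap_left[0]:
        return "Bulk_L"
    if t > overlap_right[1]:
        return "Bulk_R"
    return "Neck"

# Hypothetical overlap intervals, chosen only for illustration
OVERLAP_LEFT = (0.30, 0.40)
OVERLAP_RIGHT = (0.60, 0.70)

for t in (0.10, 0.35, 0.50, 0.65, 0.90):
    print(f"t = {t:.2f} -> {chart_for_t(t, OVERLAP_LEFT, OVERLAP_RIGHT)}")
```

With this rule the plotted profile shows the neck chart's determinant throughout both overlap regions, so any mismatch between charts appears as a jump at `overlap_left[0]` and `overlap_right[1]` rather than inside the shaded bands.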


## 15. Save Results

In [26]:
# Save checkpoints
os.makedirs('outputs', exist_ok=True)

torch.save(neck_pinn.state_dict(), 'outputs/atlas_neck_pinn.pt')
torch.save(bulk_L_pinn.state_dict(), 'outputs/atlas_bulk_L_pinn.pt')
torch.save(bulk_R_pinn.state_dict(), 'outputs/atlas_bulk_R_pinn.pt')
print("Checkpoints saved: atlas_neck_pinn.pt, atlas_bulk_L_pinn.pt, atlas_bulk_R_pinn.pt")

# Save results JSON
all_results = {
    'version': 'A1',
    'date': time.strftime('%Y-%m-%d %H:%M:%S'),
    'architecture': {
        'neck_params': neck_pinn.param_count(),
        'bulk_L_params': bulk_L_pinn.param_count(),
        'bulk_R_params': bulk_R_pinn.param_count(),
        'total_params': total_params,
    },
    'topology': {
        'b2': B2, 'b3': B3, 'H_star': H_STAR,
        'b2_M1': B2_M1, 'b3_M1': B3_M1,
        'b2_M2': B2_M2, 'b3_M2': B3_M2,
    },
    'domains': {
        'bulk_L': list(DOMAIN_BULK_L),
        'neck': list(DOMAIN_NECK),
        'bulk_R': list(DOMAIN_BULK_R),
        'overlap_L': list(OVERLAP_LEFT),
        'overlap_R': list(OVERLAP_RIGHT),
    },
    'phase1_validation': {
        'neck': val_neck,
        'bulk_L': val_bulk_L,
        'bulk_R': val_bulk_R,
    },
    'schwarz': {
        'n_iterations': len(schwarz_history['iteration']),
        'final_match_left': schwarz_history['match_left'][-1] if schwarz_history['match_left'] else None,
        'final_match_right': schwarz_history['match_right'][-1] if schwarz_history['match_right'] else None,
    },
    'global_validation': global_results,
    'spectral_bridge': spectral_results,
    'level_A_pass': global_results.get('level_A_pass', False),
}

with open('outputs/gift_atlas_results.json', 'w') as f:
    json.dump(all_results, f, indent=2, default=float)
print("Results saved: outputs/gift_atlas_results.json")

Checkpoints saved: atlas_neck_pinn.pt, atlas_bulk_L_pinn.pt, atlas_bulk_R_pinn.pt
Results saved: outputs/gift_atlas_results.json
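
The `default=float` argument in the `json.dump` call is what lets the results dictionary contain numeric objects that `json` cannot serialize natively but `float()` can convert. The sketch below shows the round trip on a small stand-in dictionary. The keys and the use of `fractions.Fraction` are illustrative assumptions, and the file is written to a temporary directory rather than `outputs/`.

```python
import json
import os
import tempfile
from fractions import Fraction

# Stand-in results: Fraction is not JSON-serializable on its own,
# so json.dump falls back to default=float for it.
results = {
    "det_target": Fraction(65, 32),
    "lambda_1": None,                      # None is written as JSON null
    "eigenvalues": [9.072689, 34.418045],  # plain floats pass through unchanged
}

path = os.path.join(tempfile.mkdtemp(), "results.json")
with open(path, "w") as f:
    json.dump(results, f, indent=2, default=float)

with open(path) as f:
    loaded = json.load(f)

print(loaded["det_target"], loaded["lambda_1"], loaded["eigenvalues"])
```

One consequence of this fallback is that any object `float()` cannot convert (a list-valued tensor, for example) still raises an error at save time, so converting such values to lists before building the dictionary remains necessary.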


## 16. Summary

### What We Built
- 3-chart atlas on K₇ = TCS(Quintic, CI(2,2,2)) with Kovalev twist
- ~600k total parameters across 3 PINNs
- Schwarz alternating method for interface matching
- Full spectral bridge via MC Galerkin

### Key Results
- det(g) ≈ 65/32 = 2.03125 across all charts (global mean 2.031251, std 0.000053 over 30000 MC points)
- Interface matching through Kovalev twist
- Torsion below Joyce threshold
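- λ₁ × H* = 898.20 in the recorded run (λ₁ = 9.072689, n_basis=60, n_mc=30000), against the predicted dim(G₂) = 14 (the decisive test for GIFT)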

### Next Steps (if Level A PASS)
- **A2**: Multi-seed convergence on atlas (Level B)
- **A3**: Spectral bridge refinement (more basis functions, adaptive MC)
- **A4**: Compare λₙ with Riemann zeros (Level A.5)

In [None]:
print("\n" + "=" * 70)
print("  GIFT Atlas G₂ Metric — COMPLETE")
print("=" * 70)
print(f"\n  Level A: {'PASS' if all_results['level_A_pass'] else 'FAIL'}")
if spectral_results['lambda_1'] is not None:
    print(f"  λ₁ × H* = {spectral_results['lambda_1_x_Hstar']:.4f} (target: {DIM_G2})")
print(f"\n  Total parameters: {total_params:,}")
print(f"  Files saved in outputs/")
print("=" * 70)