# GIFT Yukawa Tensor from Atlas G₂ Metric

**Phase 3**: Atlas metric → Harmonic forms → Yukawa couplings → Mass ratios

**Hardware**: A100 GPU (Colab)
**Depends on**: `colab_atlas_g2_metric.ipynb` (Phase 2 — atlas A1 weights)

## Pipeline
1. **Load atlas**: Neck + Bulk_L + Bulk_R PINNs from A1 checkpoints
2. **Harmonic 2-forms**: 21 basis elements of H²(K₇) via metric inner product
3. **Harmonic 3-forms**: 77 basis elements of H³(K₇) (35 local + 42 global)
4. **Yukawa tensor**: Y_{IJA} = ∫_{K₇} ω_I ∧ ω_J ∧ ρ_A (MC integration)
5. **Physics**: Mass matrices, mixing angles, Koide formula

## Key Formula
From the 11D Chern-Simons coupling C₃ ∧ G₄ ∧ G₄, dimensionally reduced on K₇:

$$Y_{IJA} = \int_{K_7} \omega_I \wedge \omega_J \wedge \rho_A$$

where ω_I ∈ H²(K₇) (b₂=21) and ρ_A ∈ H³(K₇) (b₃=77).
This is a (2+2+3=7)-form, integrated over the 7-manifold.

**References**: Acharya-Witten (2001), Atiyah-Witten (2002)

In [None]:
# ============================================================
# Cell 1: Setup & Dependencies
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import json
import os
from itertools import combinations

# Every tensor created in this notebook uses this dtype.
dtype = torch.float64

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'PyTorch {torch.__version__}')
print(f'Device: {device}')
if device.type == 'cuda':
    props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {props.total_memory / 1e9:.1f} GB')

# Seed both RNGs so the Monte Carlo samples drawn later are repeatable.
np.random.seed(42)
torch.manual_seed(42)

# Wall-clock reference for the whole run.
T_START = time.time()

In [None]:
# ============================================================
# Cell 2: GIFT Constants & TCS Domain Layout
# ============================================================
# Topological data of K_7.
DIM = 7
B2, B3 = 21, 77                  # Betti numbers b_2(K_7), b_3(K_7)
H_STAR = B2 + B3 + 1             # = 99
DIM_G2 = 14

# Metric normalisation targets.
DET_G_TARGET = 65 / 32           # = 2.03125
SCALE_FACTOR = DET_G_TARGET ** (1 / 14)  # ~ 1.0543
KAPPA_T = 1.0 / 61

# TCS building blocks: (b2, b3) of each side.
B2_M1, B3_M1 = 11, 40            # Quintic CY3
B2_M2, B3_M2 = 10, 37            # CI(2,2,2) CY3

# Atlas domain layout (t = coord[0]); neighbouring charts overlap.
DOMAIN_BULK_L = (0.00, 0.40)
DOMAIN_NECK = (0.25, 0.75)
DOMAIN_BULK_R = (0.60, 1.00)

# Strictly increasing index tuples labelling the independent components of
# 2-, 3- and 4-forms in 7 dimensions, in lexicographic order.
IDX_2, IDX_3, IDX_4 = (list(combinations(range(DIM), k)) for k in (2, 3, 4))

# Number of independent components: C(7,2), C(7,3), C(7,4).
N_2FORM, N_3FORM, N_4FORM = len(IDX_2), len(IDX_3), len(IDX_4)

# Reverse lookup: index tuple -> position in the corresponding list.
IDX_2_MAP = dict(zip(IDX_2, range(N_2FORM)))
IDX_3_MAP = dict(zip(IDX_3, range(N_3FORM)))
IDX_4_MAP = dict(zip(IDX_4, range(N_4FORM)))

print('GIFT Yukawa Pipeline')
print(f'  b_2={B2}, b_3={B3}, H*={H_STAR}, dim(G_2)={DIM_G2}')
print(f'  det(g) target = {DET_G_TARGET}')
print(f'  Yukawa tensor: {B2} x {B2} x {B3} = {B2*B2*B3:,} components')
print(f'  M1 (Quintic): b2={B2_M1}, b3={B3_M1}')
print(f'  M2 (CI(2,2,2)): b2={B2_M2}, b3={B3_M2}')

In [None]:
# ============================================================
# Cell 3: G2 Algebraic Infrastructure
# ============================================================

# Standard G2 form: the 7 Fano triples (i, i+1, i+3) mod 7, each with sign +1
FANO_LINES = [(i, (i + 1) % 7, (i + 3) % 7) for i in range(7)]
STANDARD_G2_FORM = [(line, +1) for line in FANO_LINES]


def build_phi0_tensor():
    """Expand STANDARD_G2_FORM into a dense, totally antisymmetric 7x7x7 array.

    Each signed triple (i, j, k) fills all six index orderings: the three
    cyclic ones carry the triple's sign, the three transposed ones the
    opposite sign. Every other entry stays zero.
    """
    tensor = np.zeros((7, 7, 7), dtype=np.float64)
    for (i, j, k), sign in STANDARD_G2_FORM:
        for even in ((i, j, k), (j, k, i), (k, i, j)):
            tensor[even] = sign
        for odd in ((j, i, k), (i, k, j), (k, j, i)):
            tensor[odd] = -sign
    return tensor


PHI0_TENSOR = build_phi0_tensor()


def phi0_components(normalize=True):
    """Return the 35 components phi0[i, j, k] with i < j < k, in lexicographic order.

    When `normalize` is true every component is multiplied by SCALE_FACTOR.
    """
    comps = np.array(
        [PHI0_TENSOR[i, j, k]
         for i in range(7)
         for j in range(i + 1, 7)
         for k in range(j + 1, 7)],
        dtype=np.float64,
    )
    return comps * SCALE_FACTOR if normalize else comps


def g2_generators():
    """Return 14 antisymmetric 7x7 matrices, each scaled to unit Frobenius norm.

    Matrices 0..6 are elementary rotations in the plane of the first two
    indices of each Fano line. Matrices 7..13 combine an (i, k) rotation of
    weight 1 with a (j, k) rotation of weight 0.5, where
    (i, j, k) = (n, n+1, n+3) mod 7.

    NOTE(review): elements of the Lie algebra g2 annihilate phi0, whereas
    these matrices are presumably meant to produce non-trivial variations
    of phi0 -- confirm that 'G2 generators' is the intended description.
    """
    gens = np.zeros((14, 7, 7), dtype=np.float64)

    # First family: only the first two indices of each line are used.
    for n, (i, j, _) in enumerate(FANO_LINES):
        gens[n, i, j], gens[n, j, i] = 1, -1

    # Second family: two coupled rotations sharing the index k.
    for n in range(7):
        i, j, k = n, (n + 1) % 7, (n + 3) % 7
        mat = gens[7 + n]
        mat[i, k], mat[k, i] = 1, -1
        mat[j, k], mat[k, j] = 0.5, -0.5

    # Normalise in place (each `mat` is a view into `gens`).
    for mat in gens:
        norm = np.linalg.norm(mat)
        if norm > 1e-10:
            mat /= norm
    return gens


def compute_lie_derivatives():
    """Infinitesimal action of each generator X on the scaled phi0. Shape: (14, 35).

    For every matrix X returned by g2_generators() and every index triple
    i < j < k (lexicographic order, matching phi0_components) this evaluates

        sum_m  X[i,m] phi0[m,j,k] + X[j,m] phi0[i,m,k] + X[k,m] phi0[i,j,m]

    with phi0 = SCALE_FACTOR * PHI0_TENSOR.

    The previous version also called phi0_components() and discarded the
    result; that dead call has been removed. The arithmetic is unchanged.
    """
    gens = g2_generators()
    triples = [(i, j, k)
               for i in range(7)
               for j in range(i + 1, 7)
               for k in range(j + 1, 7)]
    lie = np.zeros((14, 35), dtype=np.float64)
    for g_idx, X in enumerate(gens):
        for c_idx, (i, j, k) in enumerate(triples):
            val = 0.0
            for m in range(7):
                # One term per slot of the 3-form that X can act on.
                val += X[i, m] * PHI0_TENSOR[m, j, k] * SCALE_FACTOR
                val += X[j, m] * PHI0_TENSOR[i, m, k] * SCALE_FACTOR
                val += X[k, m] * PHI0_TENSOR[i, j, m] * SCALE_FACTOR
            lie[g_idx, c_idx] = val
    return lie


LIE_DERIVATIVES = compute_lie_derivatives()
print(f'G2 generators: 14 matrices in so(7)')
print(f'Lie derivative matrix: {LIE_DERIVATIVES.shape}')

In [None]:
# ============================================================
# Cell 4: PINN Architecture (reproduced from atlas notebook)
# ============================================================

class FourierFeatures(nn.Module):
    """Input encoding: random Fourier features plus integer harmonics.

    Columns 0, 2, 3, 4, 5 of the input are projected onto fixed random
    frequencies (buffer `B_rand`); columns 1 and 6 are treated as periodic
    and expanded in harmonics k = 1..num_periodic_freq (buffer `k_periodic`).
    The output has 2*num_random_freq + 4*num_periodic_freq columns.
    """
    def __init__(self, input_dim=7, num_random_freq=24, num_periodic_freq=8, scale=1.0):
        super().__init__()
        self.num_periodic_freq = num_periodic_freq
        self.output_dim = 2 * num_random_freq + 4 * num_periodic_freq
        # Two of the input columns are periodic and bypass the random projection.
        self.register_buffer(
            'B_rand',
            torch.randn(num_random_freq, input_dim - 2, dtype=dtype) * scale,
        )
        self.register_buffer(
            'k_periodic',
            torch.arange(1, num_periodic_freq + 1, dtype=dtype).unsqueeze(1),
        )

    def forward(self, x):
        two_pi = 2 * math.pi
        aperiodic = x[:, [0, 2, 3, 4, 5]]
        proj = two_pi * aperiodic @ self.B_rand.T
        theta = two_pi * x[:, 1:2]
        psi = two_pi * x[:, 6:7]
        harmonics = self.k_periodic.T  # row vector, broadcasts over the batch
        return torch.cat([
            torch.cos(proj), torch.sin(proj),
            torch.cos(theta * harmonics), torch.sin(theta * harmonics),
            torch.cos(psi * harmonics), torch.sin(psi * harmonics),
        ], dim=-1)


class NeckPINN(nn.Module):
    """Neck-region network that outputs a 3-form phi as 35 components.

    An MLP on Fourier features produces 14 coefficients; they are contracted
    with the fixed `lie_derivs` buffer (14 x 35) and added, scaled by
    `perturbation_scale`, to the constant `phi0` buffer. The metric is then
    read off from phi via g_ij = phi_ikl phi_jkl / 6.
    """
    def __init__(self, num_random_freq=24, num_periodic_freq=8,
                 hidden_dims=None, perturbation_scale=0.2):
        super().__init__()
        widths = [256, 256, 256, 128] if hidden_dims is None else hidden_dims
        self.perturbation_scale = perturbation_scale
        self.register_buffer('phi0', torch.from_numpy(phi0_components(True)).to(dtype))
        self.register_buffer('lie_derivs', torch.from_numpy(LIE_DERIVATIVES).to(dtype))
        self.fourier = FourierFeatures(7, num_random_freq, num_periodic_freq)
        # Linear + SiLU per hidden width, then a linear head with 14 outputs.
        dims = [self.fourier.output_dim, *widths]
        modules = []
        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            modules.append(nn.Linear(fan_in, fan_out, dtype=dtype))
            modules.append(nn.SiLU())
        modules.append(nn.Linear(dims[-1], 14, dtype=dtype))
        self.mlp = nn.Sequential(*modules)

    def forward(self, x):
        """Return the 35 independent components of phi at each row of x."""
        coeffs = self.mlp(self.fourier(x))
        return self.phi0.unsqueeze(0) + self.perturbation_scale * (coeffs @ self.lie_derivs)

    def phi_tensor(self, x):
        """Expand the 35 components into a fully antisymmetric (N, 7, 7, 7) tensor."""
        comp = self.forward(x)
        phi = torch.zeros(comp.shape[0], 7, 7, 7, device=x.device, dtype=x.dtype)
        triples = ((i, j, k)
                   for i in range(7)
                   for j in range(i + 1, 7)
                   for k in range(j + 1, 7))
        for idx, (i, j, k) in enumerate(triples):
            v = comp[:, idx]
            for a, b, c in ((i, j, k), (j, k, i), (k, i, j)):
                phi[:, a, b, c] = v
            for a, b, c in ((j, i, k), (i, k, j), (k, j, i)):
                phi[:, a, b, c] = -v
        return phi

    def metric(self, x):
        """Return g_ij = phi_ikl phi_jkl / 6 at each row of x, shape (N, 7, 7)."""
        phi = self.phi_tensor(x)
        return torch.einsum('nikl,njkl->nij', phi, phi) / 6.0

    def det_g(self, x):
        """Return det(g) at each row of x, shape (N,)."""
        return torch.linalg.det(self.metric(x))


class BulkCYPINN(nn.Module):
    """Cholesky-parameterized PINN for ACyl CY3 x S1 bulk regions.

    The network outputs 28 numbers, one per lower-triangular entry of a 7x7
    factor L. Each is scaled by `perturbation_scale` and added to the fixed
    base factor `L0_flat`; diagonal entries then pass through softplus so
    they are strictly positive. The metric is g = L L^T.

    Buffers and submodule names are unchanged, so existing checkpoints load
    as before.
    """
    def __init__(self, num_random_freq=24, num_periodic_freq=8,
                 hidden_dims=None, perturbation_scale=0.2, base_diag=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [256, 256, 256, 128]
        self.perturbation_scale = perturbation_scale
        if base_diag is None:
            base_diag = [1.0, 1.0, SCALE_FACTOR, SCALE_FACTOR,
                         SCALE_FACTOR, SCALE_FACTOR, 1.0]
        # Base factor: diagonal matrix with the requested entries.
        L0 = torch.zeros(7, 7, dtype=dtype)
        for i in range(7):
            L0[i, i] = base_diag[i]
        # Row/column of each of the 28 lower-triangular slots, row-major.
        tril_idx = torch.tril_indices(7, 7)
        self.register_buffer('tril_row', tril_idx[0])
        self.register_buffer('tril_col', tril_idx[1])
        self.register_buffer('L0_flat', L0[tril_idx[0], tril_idx[1]])
        self.fourier = FourierFeatures(7, num_random_freq, num_periodic_freq)
        layers = []
        in_dim = self.fourier.output_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h, dtype=dtype), nn.SiLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, 28, dtype=dtype))
        self.mlp = nn.Sequential(*layers)

    def cholesky_factor(self, x):
        """Return the lower-triangular factor L at each row of x, shape (N, 7, 7).

        All 28 slots are filled with a single indexed assignment. The earlier
        per-slot loop called `.item()` on index tensors 56 times per call,
        which forces a host/device synchronisation each time on CUDA.
        """
        delta_L = self.mlp(self.fourier(x))                      # (N, 28)
        vals = self.L0_flat + self.perturbation_scale * delta_L   # (N, 28)
        # softplus only on the diagonal slots; off-diagonal entries are unconstrained.
        on_diag = self.tril_row == self.tril_col                  # (28,)
        vals = torch.where(on_diag, F.softplus(vals), vals)
        L = torch.zeros(x.shape[0], 7, 7, device=x.device, dtype=x.dtype)
        L[:, self.tril_row, self.tril_col] = vals.to(L.dtype)
        return L

    def metric(self, x):
        """Return g = L L^T at each row of x, shape (N, 7, 7)."""
        L = self.cholesky_factor(x)
        return torch.bmm(L, L.transpose(1, 2))

    def det_g(self, x):
        """Return det(g) = prod(diag(L))^2, evaluated through logs, shape (N,)."""
        L = self.cholesky_factor(x)
        log_det = 2 * torch.sum(torch.log(torch.diagonal(L, dim1=1, dim2=2)), dim=1)
        return torch.exp(log_det)


print('PINN architectures defined: NeckPINN, BulkCYPINN')

## Phase 1: Load Atlas A1 Weights

Load the trained PINN models from the atlas A1 run.
Weight files: `atlas_neck_pinn.pt`, `atlas_bulk_L_pinn.pt`, `atlas_bulk_R_pinn.pt`

In [None]:
# ============================================================
# Cell 5: Load Atlas Weights & Verify
# ============================================================
print('\n' + '=' * 70)
print('  PHASE 1: Load Atlas A1 Weights')
print('=' * 70)

# One network per atlas chart; both bulk charts share an architecture.
neck_pinn = NeckPINN().to(device)
bulk_L_pinn = BulkCYPINN().to(device)
bulk_R_pinn = BulkCYPINN().to(device)

# Checkpoint file for each chart, relative to CKPT_DIR.
CKPT_DIR = 'outputs'
_atlas_checkpoints = (
    (neck_pinn, 'atlas_neck_pinn.pt'),
    (bulk_L_pinn, 'atlas_bulk_L_pinn.pt'),
    (bulk_R_pinn, 'atlas_bulk_R_pinn.pt'),
)

for _model, _fname in _atlas_checkpoints:
    _model.load_state_dict(torch.load(f'{CKPT_DIR}/{_fname}',
                                      map_location=device, weights_only=True))

# Inference only from here on.
for _model, _ in _atlas_checkpoints:
    _model.eval()

n_params = sum(p.numel()
               for m in [neck_pinn, bulk_L_pinn, bulk_R_pinn]
               for p in m.parameters())
print(f'  Total parameters: {n_params:,}')


def kovalev_twist_coords(x_neck):
    """Map neck-chart coordinates to Bulk_R-chart coordinates (Kovalev transition).

    Columns 1 and 6 (theta, psi) are exchanged, columns 3..5 are refilled
    from columns 5, 2, 4, and column 2 becomes (1 - x3) reduced with fmod 1.
    Column 0 (t) is kept. The input tensor is not modified.
    """
    x_bulk = x_neck.clone()
    # Pure relabelling: destination columns on the left, sources on the right.
    x_bulk[:, [1, 6, 3, 4, 5]] = x_neck[:, [6, 1, 5, 2, 4]]
    # The one non-trivial entry: reflection of x3.
    x_bulk[:, 2] = torch.fmod(-x_neck[:, 3] + 1.0, 1.0)
    return x_bulk


@torch.no_grad()
def global_metric_at(x):
    """Evaluate assembled global metric at points x. Returns (N,7,7).

    The neck, left-bulk and right-bulk chart metrics are blended with a
    smooth partition of unity in t = x[:, 0]. Each chart contributes only
    inside its own t-interval (DOMAIN_NECK, DOMAIN_BULK_L, DOMAIN_BULK_R);
    right-bulk points are first mapped through kovalev_twist_coords.

    The weights are normalised by their exact sum. The earlier version added
    1e-10 to the sum unconditionally; close to t = 0 and t = 1 the only
    active bump is itself comparable to or far smaller than 1e-10, so the
    normalised weight fell well below 1 and the returned metric was scaled
    towards zero there. The sum is now only floored at the smallest normal
    float, purely to avoid 0/0 for points outside every chart.

    Rows with t outside all three intervals are returned as zero matrices.
    """
    t = x[:, 0]
    g = torch.zeros(x.shape[0], 7, 7, device=x.device, dtype=x.dtype)

    def bump(t_val, lo, hi):
        # exp(-1 / (1 - u^2)) bump on [lo, hi]; u is clamped just inside
        # (-1, 1) so the exponent stays finite at the interval ends.
        mid = (lo + hi) / 2
        half = (hi - lo) / 2
        u = ((t_val - mid) / half).clamp(-0.999, 0.999)
        return torch.exp(-1.0 / (1.0 - u ** 2))

    # (t-interval, chart model, optional coordinate change into that chart)
    charts = (
        (DOMAIN_NECK, neck_pinn, None),
        (DOMAIN_BULK_L, bulk_L_pinn, None),
        (DOMAIN_BULK_R, bulk_R_pinn, kovalev_twist_coords),
    )

    masks = []
    weights = []
    for (lo, hi), _, _ in charts:
        inside = (t >= lo) & (t <= hi)
        masks.append(inside)
        # Cast the mask to x.dtype so the product stays in float64.
        weights.append(bump(t, lo, hi) * inside.to(x.dtype))

    w_total = (weights[0] + weights[1] + weights[2]).clamp_min(
        torch.finfo(x.dtype).tiny)

    for (_, model, to_chart), inside, w in zip(charts, masks, weights):
        if not inside.any():
            continue
        x_chart = x[inside]
        if to_chart is not None:
            x_chart = to_chart(x_chart)
        w_norm = (w[inside] / w_total[inside]).unsqueeze(-1).unsqueeze(-1)
        g[inside] += w_norm * model.metric(x_chart)

    return g


# Sanity check: evaluate the blended metric at 1000 uniform random points
# and report its determinant statistics and positive-definiteness.
x_test = torch.rand(1000, 7, device=device, dtype=dtype)
g_test = global_metric_at(x_test)
det_test = torch.linalg.det(g_test)
eig_test = torch.linalg.eigvalsh(g_test)

print('\n'.join([
    f'\n  Quick verification (1000 random points):',
    f'    det(g) mean: {det_test.mean().item():.6f} (target: {DET_G_TARGET:.6f})',
    f'    det(g) std:  {det_test.std().item():.6e}',
    f'    min eigenvalue: {eig_test.min().item():.6f}',
    f'    all positive definite: {(eig_test.amin(dim=1) > 0).all().item()}',
    f'\n  Atlas loaded and verified.',
]))

## Phase 2: Harmonic Form Basis

Construct approximate harmonic forms on K₇ using the PINN metric.

**Strategy**: On a nearly-flat metric, constant-coefficient forms are
approximately harmonic. We orthonormalize them using the L² inner
product induced by the metric:

$$\langle \omega, \eta \rangle_g = \frac{1}{2} \int_{K_7} g^{ik} g^{jl} \omega_{ij} \eta_{kl} \sqrt{\det g} \, d^7x$$

- **2-forms**: 21 = C(7,2) constant basis elements dxⁱ ∧ dxʲ
- **3-forms**: 35 local + 21 M₁-global + 21 M₂-global = 77 total

In [None]:
# ============================================================
# Cell 6: MC Sample Points & Volume Elements
# ============================================================
print('\n' + '=' * 70)
print('  PHASE 2: Harmonic Form Basis')
print('=' * 70)

N_MC = 50000  # Monte Carlo sample points
BATCH = 5000  # batch size for GPU evaluation

print(f'  Sampling {N_MC:,} points in [0,1]^7...')
x_mc = torch.rand(N_MC, 7, device=device, dtype=dtype)

# Evaluate the metric, its inverse, det(g) and sqrt(det g) batch by batch.
print(f'  Evaluating metric at {N_MC:,} points (batched)...')
g_mc_list = []
ginv_mc_list = []
sqrtdet_mc_list = []
det_mc_list = []

t_eval = time.time()
for i in range(0, N_MC, BATCH):
    x_batch = x_mc[i:i+BATCH]
    g_batch = global_metric_at(x_batch)
    det_batch = torch.linalg.det(g_batch)
    g_mc_list.append(g_batch)
    det_mc_list.append(det_batch)
    ginv_mc_list.append(torch.linalg.inv(g_batch))
    # Clamp before the root so a slightly negative det cannot produce NaN.
    sqrtdet_mc_list.append(torch.sqrt(det_batch.clamp(min=1e-10)))

g_mc = torch.cat(g_mc_list, dim=0)
ginv_mc = torch.cat(ginv_mc_list, dim=0)
sqrtdet_mc = torch.cat(sqrtdet_mc_list, dim=0)
det_mc = torch.cat(det_mc_list, dim=0)

print(f'  Metric evaluation: {time.time()-t_eval:.1f}s')
print(f'  det(g): mean={det_mc.mean().item():.6f}, std={det_mc.std().item():.2e}')
print(f'  det(g): min={det_mc.min().item():.6e}, max={det_mc.max().item():.6e}')

# --- Filter outlier MC points ---
# Points where det(g) is extreme or ginv has huge entries corrupt the overlap matrices.
# det(g) is screened with median +/- 10 * (MAD-based sigma).
det_median = det_mc.median().item()
det_mad = (det_mc - det_median).abs().median().item() * 1.4826  # MAD -> sigma
ginv_maxabs = ginv_mc.abs().amax(dim=(1,2))  # max |g^{ij}| at each point
# NOTE(review): this threshold uses the plain standard deviation, which is
# itself inflated by the outliers it is meant to catch -- confirm whether a
# MAD-based scale was intended here as well.
ginv_threshold = ginv_maxabs.median().item() + 10 * ginv_maxabs.std().item()

det_lo = max(1e-6, det_median - 10 * det_mad)
det_hi = det_median + 10 * det_mad
mask_good = (
    (det_mc > det_lo) &
    (det_mc < det_hi) &
    (ginv_maxabs < max(ginv_threshold, 100.0)) &
    (~torch.isnan(det_mc)) &
    (~torch.isinf(ginv_maxabs))
)

n_good = mask_good.sum().item()
n_bad = N_MC - n_good
print(f'\n  Outlier filter: {n_good:,} good / {N_MC:,} total ({n_bad} removed)')
if n_bad > 0:
    bad_det = det_mc[~mask_good]
    print(f'    Removed det(g) range: [{bad_det.min().item():.2e}, {bad_det.max().item():.2e}]')
    bad_ginv = ginv_maxabs[~mask_good]
    print(f'    Removed max|ginv| range: [{bad_ginv.min().item():.2e}, {bad_ginv.max().item():.2e}]')

# Fail loudly instead of letting every later average turn into NaN.
if n_good == 0:
    raise RuntimeError('Outlier filter removed every Monte Carlo point; '
                       'check the atlas metric before continuing.')

# Apply filter
x_mc = x_mc[mask_good]
g_mc = g_mc[mask_good]
ginv_mc = ginv_mc[mask_good]
sqrtdet_mc = sqrtdet_mc[mask_good]
# Reuse the determinants already computed rather than calling det twice more.
# det_mc itself is left unfiltered, as before.
det_good = det_mc[mask_good]
N_MC_EFF = n_good

vol_K7 = sqrtdet_mc.mean().item()
print(f'  Volume density <sqrt(det g)>: {vol_K7:.6f}')
print(f'  Filtered det(g): mean={det_good.mean().item():.6f}, '
      f'std={det_good.std().item():.2e}')
print(f'  max|ginv|: {ginv_mc.abs().max().item():.4f}')

In [None]:
# ============================================================
# Cell 7: Harmonic 2-Form Basis (b_2 = 21)
# ============================================================
# Builds the 21x21 Monte Carlo overlap matrix S2 of the constant 2-forms
# dx^a ^ dx^b under the metric inner product, then forms T2 = S2_reg^(-1/2)
# so that the transformed basis is (approximately) orthonormal.
print('\n  --- Harmonic 2-Forms ---')
print(f'  Building {B2} constant-coefficient 2-forms dx^i ^ dx^j...')

print('  Computing L^2 overlap matrix S_{IJ} = <omega_I, omega_J>_g...')
t_s = time.time()

S2 = torch.zeros(B2, B2, device=device, dtype=dtype)

# Only the upper triangle is computed; the lower one is mirrored.
for I in range(B2):
    a_I, b_I = IDX_2[I]
    for J in range(I, B2):
        a_J, b_J = IDX_2[J]
        # Pointwise inner product of dx^a^dx^b with dx^a'^dx^b': the 2x2
        # minor of g^{-1} with rows (a, b) and columns (a', b'), weighted by
        # sqrt(det g) and averaged over the filtered MC points.
        val = (ginv_mc[:, a_I, a_J] * ginv_mc[:, b_I, b_J]
             - ginv_mc[:, a_I, b_J] * ginv_mc[:, b_I, a_J]) * sqrtdet_mc
        S2[I, J] = val.mean()
        S2[J, I] = S2[I, J]

# --- Robust eigendecomposition on CPU with Tikhonov regularization ---
S2_np = S2.cpu().double().numpy()
S2_np = 0.5 * (S2_np + S2_np.T)  # enforce symmetry
# Non-finite entries are replaced by zero so that eigh can run.
S2_np = np.nan_to_num(S2_np, nan=0.0, posinf=0.0, neginf=0.0)

# Tikhonov: S2_reg = S2 + eta * I  (bounds condition number)
# eta is 1e-6 of the mean |diagonal|, floored at 1e-10.
eta2 = max(1e-10, 1e-6 * np.abs(np.diag(S2_np)).mean())
S2_reg = S2_np + eta2 * np.eye(B2)
print(f'  S2 diag range: [{np.diag(S2_np).min():.4e}, {np.diag(S2_np).max():.4e}]')
print(f'  Tikhonov eta = {eta2:.2e}')

try:
    eig2_np, vec2_np = np.linalg.eigh(S2_reg)
    print(f'  numpy eigh succeeded')
except np.linalg.LinAlgError:
    # Fallback: take singular values / left singular vectors and re-sort
    # them ascending to match the ordering eigh would have returned.
    print(f'  numpy eigh failed, using SVD...')
    U, s, Vt = np.linalg.svd(S2_reg, full_matrices=True)
    eig2_np = s
    vec2_np = U
    idx_sort = np.argsort(eig2_np)
    eig2_np = eig2_np[idx_sort]
    vec2_np = vec2_np[:, idx_sort]

eig2 = torch.from_numpy(eig2_np).to(device=device, dtype=dtype)
vec2 = torch.from_numpy(vec2_np).to(device=device, dtype=dtype)

# Eigenvalues below eta2 are raised to eta2 before inversion.
eig2_clamped = eig2.clamp(min=eta2)
cond2 = eig2_clamped.max().item() / eig2_clamped.min().item()
print(f'  S2 eigenvalues: min={eig2.min().item():.6e}, max={eig2.max().item():.6e}')
print(f'  S2 condition number: {cond2:.2e}')

# Orthonormalize: T2 is the symmetric inverse square root of S2_reg, so
# column I of T2 holds the coefficients of the I-th orthonormal 2-form.
S2_inv_sqrt = vec2 @ torch.diag(1.0 / torch.sqrt(eig2_clamped)) @ vec2.T
T2 = S2_inv_sqrt  # (21, 21)

# Verify against the unregularised S2, so the residual includes the effect
# of the Tikhonov shift and of any eigenvalue clamping.
S2_gpu = S2.to(device)
S2_check = T2.T @ S2_gpu @ T2
off_diag_err = (S2_check - torch.eye(B2, device=device, dtype=dtype)).abs().max().item()
print(f'  Orthonormality check: max|S_new - I| = {off_diag_err:.2e}')
print(f'  2-form basis ready: {B2} orthonormal forms ({time.time()-t_s:.1f}s)')

In [None]:
# ============================================================
# Cell 8: Harmonic 3-Form Basis (b_3 = 77)
# ============================================================
print('\n  --- Harmonic 3-Forms ---')
print(f'  Target: {B3} forms = 35 local + 21 M1-global + 21 M2-global')

# t-coordinate (column 0) of every filtered MC point, used by the cutoffs
t_mc = x_mc.select(1, 0)

def smooth_cutoff_left(t, t_max=0.40, width=0.10):
    """Cutoff equal to 1 for t <= t_max - width and 0 for t >= t_max.

    In between it follows half a cosine period, 0.5 * (1 + cos(pi * s)),
    with s running from 0 to 1 across the transition band.
    """
    # Clamping s to [0, 1] reproduces the two flat branches, because the
    # cosine ramp evaluates to exactly 1 at s = 0 and 0 at s = 1.
    s = ((t - (t_max - width)) / width).clamp(0, 1)
    return 0.5 * (1 + torch.cos(math.pi * s))

def smooth_cutoff_right(t, t_min=0.60, width=0.10):
    """Cutoff equal to 0 for t <= t_min and 1 for t >= t_min + width.

    In between it follows half a cosine period, 0.5 * (1 + cos(pi * s)),
    with s running from 1 down to 0 as t crosses the transition band.
    """
    # Clamping s to [0, 1] reproduces the two flat branches, because the
    # cosine ramp evaluates to exactly 1 at s = 0 and 0 at s = 1.
    s = (((t_min + width) - t) / width).clamp(0, 1)
    return 0.5 * (1 + torch.cos(math.pi * s))

# Evaluate both cutoffs at the t-coordinate of every MC point.
chi_L, chi_R = smooth_cutoff_left(t_mc), smooth_cutoff_right(t_mc)

print(f'  Cutoff weights: <chi_L>={chi_L.mean():.3f}, <chi_R>={chi_R.mean():.3f}')

N_GLOBAL_PER_SIDE = 21
K3_PAIRS = list(combinations([2, 3, 4, 5], 2))  # 6 pairs

def get_3form_index(i, j, k):
    """Position of the component with indices {i, j, k} in IDX_3, or -1.

    The indices are sorted before the lookup, so their order is irrelevant;
    any sign from the reordering is not returned.
    """
    key = tuple(sorted((i, j, k)))
    if key in IDX_3_MAP:
        return IDX_3_MAP[key]
    return -1

# Build 3-form coefficients: shape (77, N_MC_EFF, 35)
# form3_coeffs[a, x, c] is the c-th component (ordering of IDX_3) of trial
# 3-form number a at MC point x.
print(f'  Building 3-form basis coefficients ({N_MC_EFF} MC points)...')
form3_coeffs = torch.zeros(B3, N_MC_EFF, N_3FORM, device=device, dtype=dtype)

# Local forms (0..34): constant, one per coordinate 3-form
for a in range(N_3FORM):
    form3_coeffs[a, :, a] = 1.0

# Global forms: chi(t) * cos(n*pi*x_p) on the single component {leg, p, q}.
#   M1 side (35..55): leg = 1 (dtheta), cutoff chi_L
#   M2 side (56..76): leg = 6 (dpsi),   cutoff chi_R
# (p, q) cycles through K3_PAIRS and the frequency n increases by one after
# each full cycle. get_3form_index sorts its arguments and IDX_3_MAP holds
# every sorted triple of distinct indices, so the lookup cannot miss; the
# former fallback lookup was unreachable and has been dropped.
for side, (leg, chi) in enumerate(((1, chi_L), (6, chi_R))):
    for n in range(N_GLOBAL_PER_SIDE):
        form_id = N_3FORM + side * N_GLOBAL_PER_SIDE + n
        freq, pair_idx = divmod(n, len(K3_PAIRS))
        freq += 1
        p, q = K3_PAIRS[pair_idx]
        comp_idx = get_3form_index(leg, p, q)
        modulation = chi * torch.cos(freq * math.pi * x_mc[:, p])
        form3_coeffs[form_id, :, comp_idx] = modulation

print(f'  3-form basis shape: {form3_coeffs.shape}')
print(f'  Local forms:  0..34  (constant on K7)')
print(f'  M1 global:   35..55  (localized to left bulk)')
print(f'  M2 global:   56..76  (localized to right bulk)')

In [None]:
# ============================================================
# Cell 9: Orthonormalize 3-Forms with Metric Inner Product
# ============================================================
# Builds the 77x77 Monte Carlo overlap matrix S3 of the trial 3-forms under
# the metric inner product, then derives the basis change T3 whose columns
# are the orthonormalised combinations (eigenvectors scaled by 1/sqrt(eig)).
print('\n  Computing L^2 overlap matrix for 3-forms...')
t_s3 = time.time()

# For basis 3-forms e^alpha, e^beta with ordered triples:
# <e^alpha, e^beta>_g = det(G_sub) where G_sub is the 3x3 submatrix
# of g^{-1} with rows from alpha and cols from beta.

print('  Building 3-form metric tensor G3 (35x35) at MC points...')

S3 = torch.zeros(B3, B3, device=device, dtype=dtype)

for batch_start in range(0, N_MC_EFF, BATCH):
    batch_end = min(batch_start + BATCH, N_MC_EFF)
    bs = batch_end - batch_start
    ginv_b = ginv_mc[batch_start:batch_end]  # (bs, 7, 7)
    sqrtdet_b = sqrtdet_mc[batch_start:batch_end]  # (bs,)

    # Build G3 at these points: shape (bs, 35, 35)
    # Only alpha <= beta is computed; the result is mirrored.
    G3_b = torch.zeros(bs, N_3FORM, N_3FORM, device=device, dtype=dtype)
    for alpha in range(N_3FORM):
        i1, j1, k1 = IDX_3[alpha]
        for beta in range(alpha, N_3FORM):
            i2, j2, k2 = IDX_3[beta]
            # det(G_sub) via Leibniz formula: three even and three odd
            # permutations of the row indices against fixed column indices.
            val = torch.zeros(bs, device=device, dtype=dtype)
            for (a,b,c), sgn1 in [((i1,j1,k1),1),((j1,k1,i1),1),((k1,i1,j1),1),
                                   ((j1,i1,k1),-1),((i1,k1,j1),-1),((k1,j1,i1),-1)]:
                val += sgn1 * (ginv_b[:,a,i2] * ginv_b[:,b,j2] * ginv_b[:,c,k2])
            G3_b[:, alpha, beta] = val
            G3_b[:, beta, alpha] = val

    # Contract the pointwise 35x35 metric with both forms, weight by
    # sqrt(det g) and accumulate this batch's share of the MC average.
    forms_b = form3_coeffs[:, batch_start:batch_end, :]  # (77, bs, 35)
    G3_forms = torch.einsum('xab,Axb->Axa', G3_b, forms_b)  # (77, bs, 35)
    S3 += torch.einsum('Axc,Bxc->AB', forms_b * sqrtdet_b.unsqueeze(0).unsqueeze(-1),
                       G3_forms) / N_MC_EFF

# Symmetrize
S3 = 0.5 * (S3 + S3.T)

# --- Robust eigendecomposition with Tikhonov regularization ---
print(f'  S3 diagnostics:')
S3_np = S3.cpu().double().numpy()
print(f'    range: [{S3_np.min():.4e}, {S3_np.max():.4e}]')
print(f'    diag range: [{np.diag(S3_np).min():.4e}, {np.diag(S3_np).max():.4e}]')
print(f'    NaN: {np.isnan(S3_np).any()}, Inf: {np.isinf(S3_np).any()}')

# Non-finite entries are replaced by zero so that eigh can run.
S3_np = np.nan_to_num(S3_np, nan=0.0, posinf=0.0, neginf=0.0)

# Tikhonov: S3_reg = S3 + eta * I  (bounds condition number)
# eta is 1e-6 of the mean |diagonal|, floored at 1e-10.
eta3 = max(1e-10, 1e-6 * np.abs(np.diag(S3_np)).mean())
S3_reg = S3_np + eta3 * np.eye(B3)
print(f'  Tikhonov eta3 = {eta3:.2e}')

try:
    eig3_np, vec3_np = np.linalg.eigh(S3_reg)
    print(f'  numpy eigh succeeded')
except np.linalg.LinAlgError:
    # Fallback: take singular values / left singular vectors and re-sort
    # them ascending to match the ordering eigh would have returned.
    print(f'  numpy eigh failed, using SVD fallback...')
    U, s, Vt = np.linalg.svd(S3_reg, full_matrices=True)
    eig3_np = s
    vec3_np = U
    idx_sort = np.argsort(eig3_np)
    eig3_np = eig3_np[idx_sort]
    vec3_np = vec3_np[:, idx_sort]

eig3 = torch.from_numpy(eig3_np).to(device=device, dtype=dtype)
vec3 = torch.from_numpy(vec3_np).to(device=device, dtype=dtype)

n_positive = (eig3 > 1e-10).sum().item()
print(f'  S3 eigenvalues: {n_positive} positive out of {B3}')
if n_positive > 0:
    pos_eigs = eig3[eig3 > 1e-10]
    print(f'  S3 eig range: [{pos_eigs.min().item():.6e}, {pos_eigs.max().item():.6e}]')
    print(f'  S3 condition (positive modes): {pos_eigs.max().item()/pos_eigs.min().item():.2e}')

# Keep only well-conditioned positive eigenvalues
# (relative cut at 1e-6 of the largest eigenvalue, absolute floor 1e-8).
eig_max = eig3.max().item()
threshold = max(1e-8, 1e-6 * eig_max)
mask3 = eig3 > threshold
n_effective = mask3.sum().item()
eig3_pos = eig3[mask3]
vec3_pos = vec3[:, mask3]

print(f'  Threshold: {threshold:.2e}, keeping {n_effective} modes')

# Orthonormalize: column A of T3 holds the coefficients, in the trial
# basis, of the A-th orthonormal 3-form. Eigenvalues are in ascending
# order, so A = 0 corresponds to the smallest retained eigenvalue.
T3 = vec3_pos @ torch.diag(1.0 / torch.sqrt(eig3_pos))  # (77, n_eff)

# Pad to full size if needed: the extra columns are zero, i.e. the
# corresponding orthonormal forms are identically zero.
if n_effective < B3:
    print(f'  Warning: only {n_effective}/{B3} linearly independent 3-forms')
    print(f'  Padding with zeros (these modes have zero Yukawa coupling)')
    T3_full = torch.zeros(B3, B3, device=device, dtype=dtype)
    T3_full[:, :n_effective] = T3
    T3 = T3_full

print(f'  3-form basis ready: {n_effective} effective forms ({time.time()-t_s3:.1f}s)')

## Phase 3: Yukawa Tensor

Compute Y_{IJA} = ∫_{K₇} ω_I ∧ ω_J ∧ ρ_A via Monte Carlo integration.

The wedge product of two 2-forms and one 3-form gives a 7-form (top form),
which is a scalar density we integrate over K₇.

In [None]:
# ============================================================
# Cell 10: Wedge Product Infrastructure
# ============================================================
print('\n' + '=' * 70)
print('  PHASE 3: Yukawa Tensor Computation')
print('=' * 70)

def permutation_sign(perm):
    """Return +1 for an even permutation and -1 for an odd one.

    Parity is obtained by counting inversions, i.e. pairs of positions whose
    entries appear in decreasing order.
    """
    inversions = sum(left > right for left, right in combinations(perm, 2))
    return -1 if inversions % 2 else 1


# Precompute the wedge coefficient table for 2+2+3 = 7 form:
# For basis forms omega_I = dx^a ^ dx^b and omega_J = dx^c ^ dx^d
# and rho_A = dx^e ^ dx^f ^ dx^g:
#
# (omega_I ^ omega_J ^ rho_A)_{1234567} = sign(a,b,c,d,e,f,g)
#   if {a,b} U {c,d} U {e,f,g} = {0,1,2,3,4,5,6} (disjoint union)
#   else 0.
#
# We precompute W[I,J,alpha] for all I,J in IDX_2, alpha in IDX_3.
#
# For disjoint {a,b} and {c,d} exactly one alpha is non-zero: the sorted
# complement of {a,b,c,d}. It is looked up directly instead of scanning all 35.
#
# W is SYMMETRIC in (I, J): exchanging the two 2-forms moves the block (c,d)
# past (a,b), an even permutation, so omega_I ^ omega_J = omega_J ^ omega_I.
# The diagnostic below therefore tests symmetry; the earlier antisymmetry
# test could only ever report False.

print('  Precomputing wedge coefficient table W[I,J,alpha]...')
t_w = time.time()

W_table = np.zeros((N_2FORM, N_2FORM, N_3FORM), dtype=np.float64)

all_indices = set(range(DIM))
for I, (a, b) in enumerate(IDX_2):
    for J, (c, d) in enumerate(IDX_2):
        # A repeated index makes the wedge product vanish.
        if {a, b} & {c, d}:
            continue
        rest = tuple(sorted(all_indices - {a, b, c, d}))
        W_table[I, J, IDX_3_MAP[rest]] = permutation_sign([a, b, c, d, *rest])

W_table_t = torch.from_numpy(W_table).to(device=device, dtype=dtype)

n_nonzero = np.count_nonzero(W_table)
print(f'  W table: {N_2FORM}x{N_2FORM}x{N_3FORM}, nonzero entries: {n_nonzero}')
print(f'  Symmetry check W[I,J,:] = W[J,I,:]: ',
      np.allclose(W_table, W_table.transpose(1, 0, 2)))
print(f'  ({time.time()-t_w:.1f}s)')

In [None]:
# ============================================================
# Cell 11: Yukawa Tensor via MC Integration
# ============================================================
print('\n  Computing Yukawa tensor Y_{IJA}...')
print(f'  Shape: ({B2}, {B2}, {B3}) = {B2*B2*B3:,} components')
print(f'  MC points: {N_MC_EFF:,} (filtered)')

t_yuk = time.time()

# Y_{IJA} = integral (omega_I ^ omega_J ^ rho_A), estimated by Monte Carlo.
#
# In the orthonormal bases:
#   omega_I    = sum_alpha T2[alpha, I] * e^{IDX_2[alpha]}    (constant forms)
#   rho_A(x)   = sum_a     T3[a, A]     * form3_coeffs[a, x, :]  (may vary with x)
#
# so that
#   Y_{IJA} = (1/N) sum_x sum_gamma W_orth[I,J,gamma] * rho_A_gamma(x) * sqrtdet(x)
#   W_orth[I,J,gamma] = sum_{alpha,beta} T2[alpha,I] T2[beta,J] W[alpha,beta,gamma]
#
# NOTE(review): the wedge of coordinate-basis forms is already a density, so
# the extra sqrt(det g) weight looks like it may double-count the volume
# element -- confirm against the intended normalisation.

# Step 1: Transform 3-form coefficients to orthonormal basis
# rho_orth[A, x, gamma] = sum_a T3[a, A] * form3_coeffs[a, x, gamma]
#
# T3 must be passed as-is: its columns are the orthonormal combinations.
# Passing T3.T (as before) contracted over the wrong axis. T3 is not
# symmetric, so that mixed the modes, and it also left the zero-padded
# modes non-zero.
print('  Step 1: Transform 3-forms to orthonormal basis...')
rho_orth = torch.einsum('aA,axg->Axg', T3, form3_coeffs)  # (B3, N_MC_EFF, 35)

# Step 2: Contract W with T2 to get the wedge table in the orthonormal
# 2-form basis.
print('  Step 2: Contract wedge table with 2-form transform...')
T2_np = T2.cpu().numpy()
W_orth = np.einsum('aI,bJ,abg->IJg', T2_np, T2_np, W_table)  # (21, 21, 35)
W_orth_t = torch.from_numpy(W_orth).to(device=device, dtype=dtype)

# Step 3: MC integration, accumulated batch by batch.
print('  Step 3: MC integration...')

Y = torch.zeros(B2, B2, B3, device=device, dtype=dtype)

for batch_start in range(0, N_MC_EFF, BATCH):
    batch_end = min(batch_start + BATCH, N_MC_EFF)
    rho_b = rho_orth[:, batch_start:batch_end, :]  # (B3, bs, 35)
    sd_b = sqrtdet_mc[batch_start:batch_end]        # (bs,)

    # weighted_rho[A, x, gamma] = rho_b[A, x, gamma] * sqrtdet(x)
    weighted_rho = rho_b * sd_b.unsqueeze(0).unsqueeze(-1)

    # W_orth does not depend on x, so sum over the points first and
    # contract with the wedge table once per batch.
    sum_rho = weighted_rho.sum(dim=1)  # (B3, 35)
    Y += torch.einsum('IJg,Ag->IJA', W_orth_t, sum_rho)

Y = Y / N_MC_EFF

elapsed_yuk = time.time() - t_yuk
print(f'\n  Yukawa tensor computed in {elapsed_yuk:.1f}s')
print(f'  Shape: {tuple(Y.shape)}')
print(f'  ||Y||_F = {torch.norm(Y).item():.6e}')
print(f'  max|Y| = {Y.abs().max().item():.6e}')
# 2-forms commute under the wedge product, so Y is symmetric (not
# antisymmetric) in its first two indices.
print(f'  Symmetry Y[I,J,A]=Y[J,I,A]: ',
      torch.allclose(Y, Y.permute(1,0,2), atol=1e-10))

# Count effective non-zero couplings
threshold = 1e-8 * Y.abs().max().item() if Y.abs().max().item() > 0 else 1e-15
n_nonzero_Y = (Y.abs() > threshold).sum().item()
print(f'  Nonzero couplings (>{threshold:.2e}): {n_nonzero_Y:,} / {B2*B2*B3:,}')

## Phase 4: Physical Observables

Extract fermion mass matrices, mass ratios, and mixing angles from the
Yukawa tensor.

**Mass matrix**: $M_{IJ} = \sum_A Y_{IJA}\, v^A$ (VEV contraction over the 3-form index $A$)

**Mass eigenvalues**: SVD of $M_{IJ}$

**Mixing matrix**: CKM-like matrix from bi-unitary diagonalization

In [None]:
# ============================================================
# Cell 12: Mass Matrices & Eigenvalues
# ============================================================
print('\n' + '=' * 70)
print('  PHASE 4: Physical Observables')
print('=' * 70)

Y_np = Y.cpu().numpy()

# === Correlation matrix ===
# C_{AB} = sum_{I,J} Y_{IJA} Y_{IJB}: Gram matrix of the Yukawa tensor over
# the 3-form index. Its spectrum is listed in descending order.
C = np.einsum('IJA,IJB->AB', Y_np, Y_np)
C_eigs = np.flip(np.linalg.eigvalsh(C))
n_significant = (C_eigs > 1e-10 * C_eigs[0]).sum()

print(f'\n  Correlation matrix C = Y^T Y:')
print(f'    Shape: ({B3}, {B3})')
print(f'    Top 10 eigenvalues: {C_eigs[:10]}')
print(f'    Significant modes: {n_significant}')

# === Flatten to (B2*B2, B3) and SVD ===
# Rows are (I, J) pairs, columns the 3-form index A.
Y_flat = Y_np.reshape(-1, B3)
U, sigma, Vh = np.linalg.svd(Y_flat, full_matrices=False)

print(f'\n  SVD of Y_flat ({B2*B2}x{B3}):')
print(f'    Top 10 singular values: {sigma[:10]}')
print(f'    Effective rank (>1e-8*sigma_1): {np.sum(sigma > 1e-8*sigma[0])}')

# === Mass matrices from VEV contraction ===
# A mass matrix is obtained by contracting Y with a VEV direction v_A over
# the 3-form index. Several choices of v are compared after extract_masses.

print(f'\n  --- Mass Matrix Analysis ---')

def extract_masses(Y_tensor, vev_direction, label=''):
    """Extract the mass spectrum of a Yukawa tensor for a given VEV direction.

    The mass matrix is the contraction ``M[I, J] = sum_A Y[I, J, A] * v[A]``
    and the masses are its singular values, i.e. the square roots of the
    eigenvalues of ``M^dagger M``.

    The singular values are taken from an SVD of ``M`` directly rather than
    from ``eigvalsh(M.T @ M)``: forming the product squares the condition
    number, so masses below roughly ``sqrt(eps)`` of the largest one would be
    lost in round-off -- exactly the regime a mass hierarchy lives in. The SVD
    is also correct for complex ``M``, where ``M.T @ M`` is not ``M^dagger M``.

    NOTE(review): the original comment described M as antisymmetric in I,J
    "from the 2-form wedge", but the wedge product of two 2-forms is
    symmetric; confirm against how Y is assembled. The singular values used
    here are valid in either case.

    Args:
        Y_tensor: array indexed as [I, J, A].
        vev_direction: 1-D array contracted against the last axis of Y_tensor.
        label: if non-empty, a heading under which the top masses (and, when
            at least three exist, their ratios) are printed.

    Returns:
        (masses, M): all singular values of M in descending order, and the
        contracted mass matrix itself.
    """
    M = np.einsum('IJA,A->IJ', Y_tensor, vev_direction)
    # Singular values come back non-negative and sorted largest-first.
    masses = np.linalg.svd(M, compute_uv=False)

    # Top 3 generations (fewer if M is smaller than 3x3)
    m = masses[:3]
    if label:
        print(f'\n  {label}:')
        print(f'    Top 3 masses: {m}')
        # Ratios need three entries and a non-zero leading mass.
        if m.size == 3 and m[0] > 0:
            print(f'    Ratios: m2/m1={m[1]/m[0]:.6f}, m3/m1={m[2]/m[0]:.6f}')
            if m[2] > 0:
                print(f'    Hierarchy: m1/m3={m[0]/m[2]:.6e}')
    return masses, M

# VEV scenario 1: single modulus switched on (the A=0 direction, which the
# label below calls the volume mode).
v1 = np.zeros(B3)
v1[0] = 1.0
masses_v1, M_v1 = extract_masses(Y_np, v1, 'VEV along volume mode (A=0)')

# VEV scenario 2: democratic -- equal weight on every modulus, unit norm.
v2 = np.ones(B3) / np.sqrt(B3)
masses_v2, M_v2 = extract_masses(Y_np, v2, 'VEV democratic (uniform over all moduli)')

# VEV scenario 3: Fano-aligned -- unit weight on each 3-form index that
# IDX_3_MAP assigns to a (sorted) Fano line, then normalised.
FANO_TRIPLE_INDICES = [
    IDX_3_MAP[triple]
    for triple in (tuple(sorted(line)) for line in FANO_LINES)
    if triple in IDX_3_MAP
]
v3 = np.zeros(B3)
if FANO_TRIPLE_INDICES:
    v3[FANO_TRIPLE_INDICES] = 1.0
    v3 /= np.linalg.norm(v3)
else:
    # Without this guard the normalisation is 0/0 and every downstream mass
    # becomes NaN. A zero VEV gives an all-zero spectrum instead, which the
    # later `masses[0] > 0` checks already skip.
    print('  WARNING: no Fano triple found in IDX_3_MAP; Fano-aligned VEV is zero')
masses_v3, M_v3 = extract_masses(Y_np, v3, 'VEV along Fano-aligned modes')

In [None]:
# ============================================================
# Cell 13: CKM-like Mixing Matrix & Koide Formula
# ============================================================

print('\n  --- Mixing Matrix & Physical Ratios ---')

# Select the VEV scenario with the largest m1/m3 ratio. The democratic
# scenario is the starting candidate and, because the comparison is strict,
# it also wins ties. In this cell the selected spectrum feeds the Koide ratio;
# the mixing matrix below is built from its own up/down VEV directions.
best_masses = masses_v2
best_M = M_v2
best_label = 'democratic'  # the print below appends ' VEV' itself

for masses, M, label in [(masses_v1, M_v1, 'volume'), (masses_v2, M_v2, 'democratic'),
                          (masses_v3, M_v3, 'Fano')]:
    if masses[0] > 0 and masses[2] > 0:
        ratio = masses[0] / masses[2]
        if ratio > (best_masses[0] / best_masses[2] if best_masses[2] > 0 else 0):
            best_masses = masses
            best_M = M
            best_label = label

print(f'  Using {best_label} VEV for mixing analysis')

# Extract mixing matrix via bi-unitary decomposition
# M = U_L @ diag(m) @ U_R^dag
# CKM = U_L(up)^dag @ U_L(down)
#
# For our single Yukawa tensor, we split into "up-type" and "down-type"
# by contracting with different VEV directions:

# Up-type: VEV along first local mode
v_up = np.zeros(B3)
v_up[0] = 1.0
M_up = np.einsum('IJA,A->IJ', Y_np, v_up)[:6, :6]  # 6x6 block

# Down-type: VEV along second local mode
v_down = np.zeros(B3)
v_down[1] = 1.0
M_down = np.einsum('IJA,A->IJ', Y_np, v_down)[:6, :6]

# SVD: columns of U_* are the left singular vectors, ordered by decreasing
# singular value.
U_up, s_up, Vh_up = np.linalg.svd(M_up)
U_down, s_down, Vh_down = np.linalg.svd(M_down)

# CKM-like mixing matrix: overlaps <u_a | d_b> between the three leading left
# singular vectors of each sector. The inner product must run over ALL
# flavour-basis components (rows); slicing rows as well, as in
# U[:3, :3], would truncate the vectors before taking the overlap.
V_ckm = U_up[:, :3].T @ U_down[:, :3]

print(f'\n  CKM-like mixing matrix (3x3):')
for i in range(3):
    row = ' '.join(f'{abs(V_ckm[i,j]):.4f}' for j in range(3))
    print(f'    |V| = [{row}]')

# Mixing angles (standard parameterization):
#   |V_13| = s13,  |V_12| / |V_11| = tan(theta_12),  |V_23| / |V_33| = tan(theta_23)
# arctan2 stays finite for any entries; arcsin is clamped so round-off just
# above 1 cannot produce NaN.
V_abs = np.abs(V_ckm)
theta_13 = np.arcsin(min(V_abs[0, 2], 1.0))
theta_12 = np.arctan2(V_abs[0, 1], V_abs[0, 0])
theta_23 = np.arctan2(V_abs[1, 2], V_abs[2, 2])

print(f'\n  Mixing angles:')
print(f'    theta_12 = {np.degrees(theta_12):.2f} deg')
print(f'    theta_23 = {np.degrees(theta_23):.2f} deg')
print(f'    theta_13 = {np.degrees(theta_13):.2f} deg')

# Koide formula: Q = (m1 + m2 + m3)^2 / (3 * (m1^2 + m2^2 + m3^2))
# Experimental: Q ~ 2/3 for charged leptons
m_top3 = best_masses[:3]
if np.sum(m_top3**2) > 0:
    Q_koide = (np.sum(m_top3))**2 / (3 * np.sum(m_top3**2))
else:
    # All-zero spectrum: report 0 rather than dividing by zero.
    Q_koide = 0.0

print(f'\n  Koide formula:')
print(f'    Q = (sum m_i)^2 / (3 * sum m_i^2) = {Q_koide:.6f}')
print(f'    Experimental (leptons): Q ~ 0.6667')
print(f'    Deviation: {abs(Q_koide - 2/3):.4f}')

In [None]:
# ============================================================
# Cell 14: Spectral Analysis of Yukawa Tensor
# ============================================================
print('\n  --- Spectral Decomposition ---')

# A rank-r decomposition would read Y_{IJA} ~ sum_r lambda_r u_r(I) v_r(J) w_r(A).
# Here the rank is probed through the singular values of two matricizations.

# Mode-1 unfolding: rows indexed by I, columns by the flattened (J, A) pair.
Y1 = Y_np.reshape(B2, B2 * B3)
s1 = np.linalg.svd(Y1, full_matrices=False)[1]

# Mode-3 unfolding: rows indexed by the flattened (I, J) pair, columns by A.
Y3 = Y_np.reshape(B2 * B2, B3)
s3 = np.linalg.svd(Y3, full_matrices=False)[1]

print(f'\n  Mode-1 singular values (top 10): {s1[:10]}')
print(f'  Mode-3 singular values (top 10): {s3[:10]}')

# Effective tensor rank
def effective_rank(sigmas, rel_tol=1e-8):
    """Count singular values that are significant relative to the leading one.

    Args:
        sigmas: 1-D sequence of singular values, expected in descending order
            (as returned by ``np.linalg.svd``); the first entry is used as the
            reference scale.
        rel_tol: a value counts when it is strictly greater than
            ``rel_tol * sigmas[0]``. Defaults to the 1e-8 used throughout
            this notebook.

    Returns:
        The count as a plain ``int``; 0 for empty input or when the leading
        value is not positive.
    """
    sigmas = np.asarray(sigmas)
    # Empty spectrum: nothing to index, rank is zero by definition.
    if sigmas.size == 0:
        return 0
    lead = sigmas[0]
    if not lead > 0:
        return 0
    return int(np.sum(sigmas > rel_tol * lead))

rank_1, rank_3 = effective_rank(s1), effective_rank(s3)
print(f'\n  Effective rank:')
print(f'    Mode-1 (2-form space): {rank_1}')
print(f'    Mode-3 (3-form space): {rank_3}')

# --- Distribution of coupling magnitudes ---
Y_abs = np.absolute(Y_np)
Y_sorted = np.sort(Y_abs, axis=None)[::-1]  # all magnitudes, largest first
# Number of entries strictly above each absolute cutoff.
n_above = [(Y_abs > cutoff).sum() for cutoff in (1e-2, 1e-4, 1e-6, 1e-8)]
above_1e2, above_1e4, above_1e6, above_1e8 = n_above

print(f'\n  Coupling distribution:')
print(f'    |Y| > 1e-2: {above_1e2:,}')
print(f'    |Y| > 1e-4: {above_1e4:,}')
print(f'    |Y| > 1e-6: {above_1e6:,}')
print(f'    |Y| > 1e-8: {above_1e8:,}')
print(f'    Total: {B2*B2*B3:,}')

# --- Generation structure ---
# The 2-form index is split at B2_M1 into an M1 group (first B2_M1 entries)
# and an M2 group (the rest), following the TCS building blocks. Comparing the
# Frobenius norms of the sub-tensors shows whether Y is block-structured.
m1_part = slice(None, B2_M1)
m2_part = slice(B2_M1, None)

Y_M1M1 = np.linalg.norm(Y_np[m1_part, m1_part, :])
Y_M2M2 = np.linalg.norm(Y_np[m2_part, m2_part, :])
Y_M1M2 = np.linalg.norm(Y_np[m1_part, m2_part, :])

print(f'\n  Block structure (M1={B2_M1}, M2={B2_M2}):')
print(f'    ||Y[M1,M1,:]|| = {Y_M1M1:.6e}')
print(f'    ||Y[M2,M2,:]|| = {Y_M2M2:.6e}')
print(f'    ||Y[M1,M2,:]|| = {Y_M1M2:.6e}  (cross-block)')

In [None]:
# ============================================================
# Cell 15: Results Summary & Export
# ============================================================
total_time = time.time() - T_START

# Quantities reported both on screen and in the JSON are evaluated once so the
# two outputs cannot disagree.
det_g_mean = torch.linalg.det(g_mc).mean().item()
# Positive-definiteness is checked rather than asserted: a Cholesky
# factorisation succeeds (info == 0) exactly for positive-definite input.
# NOTE(review): this reads one triangle of each matrix, i.e. it assumes g_mc
# holds symmetric matrices -- confirm against where g_mc is built.
metric_is_pd = bool(torch.linalg.cholesky_ex(g_mc).info.eq(0).all().item())
S2_condition = eig2.max().item() / eig2.min().item()
Y_norm_F = torch.norm(Y).item()
Y_is_antisym = bool(torch.allclose(Y, -Y.permute(1, 0, 2), atol=1e-10))
# Name of the VEV scenario that actually produced best_masses (tolerates a
# label that already carries a trailing ' VEV').
vev_name = best_label[:-len(' VEV')] if best_label.endswith(' VEV') else best_label

print('\n' + '=' * 70)
print('  YUKAWA TENSOR COMPUTATION — SUMMARY')
print('=' * 70)

print(f'''
  ATLAS:
    Parameters: {n_params:,}
    det(g) mean: {det_g_mean:.6f} (target: {DET_G_TARGET})
    Metric positive definite: {metric_is_pd}

  HARMONIC FORMS:
    2-forms: {B2} (orthonormalized, condition {S2_condition:.2f})
    3-forms: {n_effective} effective / {B3} target
      35 local + 21 M1-global + 21 M2-global

  YUKAWA TENSOR:
    Shape: {B2} x {B2} x {B3}
    ||Y||_F = {Y_norm_F:.6e}
    Antisymmetric Y[I,J,A] = -Y[J,I,A]: {Y_is_antisym}
    Effective rank (mode-1): {rank_1}
    Effective rank (mode-3): {rank_3}

  MASS SPECTRUM ({vev_name} VEV):
    Top 3 masses: {best_masses[:3]}
    Koide Q = {Q_koide:.4f} (experimental: 0.6667)

  MIXING ANGLES:
    theta_12 = {np.degrees(theta_12):.2f} deg
    theta_23 = {np.degrees(theta_23):.2f} deg
    theta_13 = {np.degrees(theta_13):.2f} deg

  MC SAMPLING:
    Original: {N_MC:,} points
    After outlier filter: {N_MC_EFF:,} points

  TIMING:
    Total: {total_time:.0f}s ({total_time/60:.1f} min)
''')

# === Save results ===
results = {
    'version': 'Y1',
    'date': time.strftime('%Y-%m-%d %H:%M:%S'),
    'atlas': {
        'total_params': int(n_params),
        'det_g_mean': float(det_g_mean),
        'det_g_target': float(DET_G_TARGET),
        'metric_positive_definite': metric_is_pd,
    },
    'topology': {
        'b2': int(B2), 'b3': int(B3), 'H_star': int(H_STAR),
        'b2_M1': int(B2_M1), 'b3_M1': int(B3_M1),
        'b2_M2': int(B2_M2), 'b3_M2': int(B3_M2),
    },
    'harmonic_forms': {
        'n_2forms': int(B2),
        'n_3forms_effective': int(n_effective),
        'S2_condition': float(S2_condition),
    },
    'yukawa': {
        'shape': [int(B2), int(B2), int(B3)],
        'norm_F': float(Y_norm_F),
        'max_abs': float(Y.abs().max().item()),
        'antisymmetric': Y_is_antisym,
        'rank_mode1': int(rank_1),
        'rank_mode3': int(rank_3),
        'svd_top10': s3[:10].tolist(),
    },
    'physics': {
        # Key name kept for existing consumers; 'vev_scenario' records which
        # scenario the masses really come from.
        'masses_top3_democratic': best_masses[:3].tolist(),
        'vev_scenario': vev_name,
        'koide_Q': float(Q_koide),
        'theta_12_deg': float(np.degrees(theta_12)),
        'theta_23_deg': float(np.degrees(theta_23)),
        'theta_13_deg': float(np.degrees(theta_13)),
        'V_ckm_abs': np.abs(V_ckm).tolist(),
    },
    'block_structure': {
        'Y_M1M1_norm': float(Y_M1M1),
        'Y_M2M2_norm': float(Y_M2M2),
        'Y_M1M2_norm': float(Y_M1M2),
    },
    'mc_params': {
        'n_mc': int(N_MC),
        'n_mc_effective': int(N_MC_EFF),
        'batch_size': int(BATCH),
    },
    'timing_seconds': float(total_time),
}

# Make sure the target directory exists so the run cannot fail at the very
# last step with FileNotFoundError.
os.makedirs('outputs', exist_ok=True)

with open('outputs/gift_yukawa_results.json', 'w') as f:
    json.dump(results, f, indent=2)
print('Saved: outputs/gift_yukawa_results.json')

# Save tensor
torch.save(Y.cpu(), 'outputs/yukawa_tensor_Y1.pt')
print('Saved: outputs/yukawa_tensor_Y1.pt')

print('\n' + '=' * 70)
print('  DONE — Phase 3 Yukawa Pipeline Complete')
print('=' * 70)