# K7 TCS Hodge v2.0 - Proper TCS/Joyce Global Modes

**Major upgrade from v1.9b:**
- Replace polynomial/trigonometric global modes with proper TCS structure
- 77 H3 modes = 35 local (fiber) + 42 global (TCS gluing)
- 42 global = 14 left-weighted + 14 right-weighted + 14 neck-coupled
- Profile functions: smooth plateaus and Gaussian bumps
- Goal: achieve 43/77 visible/hidden split with tau = 3472/891

## TCS (Twisted Connected Sum) Construction

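Twisted-connected-sum G2 manifolds (Kovalev's gluing construction, in the spirit of Joyce's compact G2 examples) are built by gluing two asymptotically cylindrical Calabi-Yau 3-folds, each crossed with a circle: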
- Left building block: M_L with asymptotic cylinder
- Right building block: M_R with asymptotic cylinder  
- Neck region: where the gluing happens with twist

The 42 global modes arise from this gluing structure.

In [None]:
# @title Setup and Imports
# @markdown Run this cell first. Works on Colab, local Jupyter, or any Python environment.

import os
import sys
import json
import time
import csv
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, Dict, List
from itertools import combinations
import warnings
warnings.filterwarnings('ignore')

import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Runtime detection: Colab is recognised by its module being loaded, and the
# compute device falls back to CPU when CUDA is not available.
IN_COLAB = 'google.colab' in sys.modules
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Setup banner followed by the environment details.
print("=" * 60)
print("K7 TCS HODGE v2.0 - Setup")
print("=" * 60)
print(f"Environment: {'Google Colab' if IN_COLAB else 'Local'}")
print(f"Device: {device}")
if device.type == 'cuda':
    # GPU name and total memory are reported for device index 0 only.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"PyTorch: {torch.__version__}")

In [None]:
# @title Configuration
# @markdown Adjust parameters here. TCS-specific settings control the profile functions.

@dataclass
class Config:
    """v2.0 configuration with TCS parameters.

    Groups geometry constants, target values, TCS profile parameters, network
    size, training hyper-parameters, loss weights and checkpoint settings.

    Raises:
        ValueError: if the mode counts are inconsistent, i.e. if
            ``n_left + n_right + n_neck != b3_global`` or
            ``b3_local + b3_global != b3_K7``.
    """

    # === Geometry ===
    dim: int = 7
    b2_K7: int = 21          # Harmonic 2-forms
    b3_K7: int = 77          # Harmonic 3-forms
    b3_local: int = 35       # Local fiber modes
    b3_global: int = 42      # Global TCS modes

    # TCS global mode breakdown: 42 = 14 + 14 + 14
    n_left: int = 14         # Left-weighted modes
    n_right: int = 14        # Right-weighted modes
    n_neck: int = 14         # Neck-coupled modes

    # === Targets ===
    target_det_g: float = 2.03125        # 65/32
    target_kappa_T: float = 0.01639344   # 1/61
    tau_target: float = 3.8967452        # 3472/891

    # === TCS Profile Parameters ===
    lambda_L: float = -1.0   # Left boundary
    lambda_R: float = +1.0   # Right boundary
    lambda_neck: float = 0.0 # Neck center
    sigma_transition: float = 0.15  # Transition width for plateaus
    sigma_neck: float = 0.2         # Neck bump width

    # === Network Architecture ===
    hidden_dim: int = 256
    n_layers: int = 4

    # === Training ===
    n_epochs_h2: int = 3000
    n_epochs_h3: int = 8000  # More epochs for TCS learning
    lr_h2: float = 1e-3
    lr_h3: float = 3e-4      # Lower LR for stability
    batch_size: int = 2048
    weight_decay: float = 1e-5
    max_grad_norm: float = 1.0
    scheduler_patience: int = 400
    scheduler_factor: float = 0.5

    # === Loss Weights ===
    w_closed: float = 1.0
    w_coclosed: float = 1.0
    w_orthonormal: float = 0.1
    w_g2_compat: float = 0.5
    w_tcs_profile: float = 0.3   # NEW: TCS profile regularization

    # === Checkpointing ===
    checkpoint_every: int = 500
    checkpoint_dir: str = "checkpoints_v2_0"
    log_every: int = 100
    output_dir: str = "outputs_v2_0"

    def __post_init__(self) -> None:
        # The H3 network concatenates local + left + right + neck heads, so
        # the counts must add up; fail early instead of silently building a
        # network with the wrong number of modes.
        n_global = self.n_left + self.n_right + self.n_neck
        if n_global != self.b3_global:
            raise ValueError(
                f"n_left + n_right + n_neck = {n_global} does not match "
                f"b3_global = {self.b3_global}"
            )
        if self.b3_local + self.b3_global != self.b3_K7:
            raise ValueError(
                f"b3_local + b3_global = {self.b3_local + self.b3_global} "
                f"does not match b3_K7 = {self.b3_K7}"
            )

# Module-level configuration instance; later cells read it as a global
# (e.g. the profile demo, load_data's synthetic branch, save_checkpoint).
config = Config()

# Echo the key settings so a notebook run records which configuration was used.
print("Configuration loaded:")
print(f"  Geometry: b2={config.b2_K7}, b3={config.b3_K7} ({config.b3_local} local + {config.b3_global} global)")
print(f"  TCS global: {config.n_left} left + {config.n_right} right + {config.n_neck} neck")
print(f"  Training: H2={config.n_epochs_h2} epochs, H3={config.n_epochs_h3} epochs")
print(f"  Targets: det(g)={config.target_det_g}, tau={config.tau_target:.4f}")

In [None]:
# @title TCS Profile Functions
# @markdown These profile functions encode the TCS (Twisted Connected Sum) geometry.
# @markdown - left_plateau: ~1 on left building block, ~0 elsewhere
# @markdown - right_plateau: ~1 on right building block, ~0 elsewhere  
# @markdown - neck_bump: peaked at neck, decays away

def smooth_step(x: torch.Tensor, x0: float = 0.0, width: float = 0.1) -> torch.Tensor:
    """Return a sigmoid ramp in ``x`` that crosses 0.5 at ``x0``.

    The offset ``x - x0`` is divided by ``width`` (floored at 1e-8 so a zero
    width cannot divide by zero) and multiplied by a fixed steepness of 5
    before the sigmoid is applied.
    """
    scaled_offset = (x - x0) / max(width, 1e-8)
    return torch.sigmoid(scaled_offset * 5.0)

def left_plateau(lam: torch.Tensor, config: Config) -> torch.Tensor:
    """
    Profile applied to the left-weighted modes.

    Complement of the sigmoid step centered at ``config.lambda_neck`` with
    width ``config.sigma_transition``: close to 1 for lambda below the neck
    center and close to 0 above it.
    """
    rising_step = smooth_step(
        lam,
        x0=config.lambda_neck,
        width=config.sigma_transition,
    )
    return 1.0 - rising_step

def right_plateau(lam: torch.Tensor, config: Config) -> torch.Tensor:
    """
    Profile applied to the right-weighted modes.

    Sigmoid step centered at ``config.lambda_neck`` with width
    ``config.sigma_transition``: close to 0 for lambda below the neck center
    and close to 1 above it.
    """
    neck_center = config.lambda_neck
    transition_width = config.sigma_transition
    return smooth_step(lam, x0=neck_center, width=transition_width)

def neck_bump(lam: torch.Tensor, config: Config) -> torch.Tensor:
    """
    Profile applied to the neck-coupled modes.

    Evaluates ``exp(-((lam - lambda_neck) / sigma_neck) ** 2)``: equal to 1 at
    the neck center and decaying on both sides. ``sigma_neck`` is floored at
    1e-8 to avoid division by zero.
    """
    bump_width = max(config.sigma_neck, 1e-8)
    offset = (lam - config.lambda_neck) / bump_width
    return torch.exp(-(offset * offset))

def get_tcs_profiles(lam: torch.Tensor, config: Config) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Evaluate the three TCS profiles at ``lam``; returns ``(left, right, neck)``."""
    left = left_plateau(lam, config)
    right = right_plateau(lam, config)
    neck = neck_bump(lam, config)
    return left, right, neck

# Sanity check: tabulate the three profiles on 11 evenly spaced lambda values
# in [-1, 1] using the module-level config.
print("TCS Profile Functions defined.")
print("\nTesting on lambda in [-1, 1]:")
test_lam = torch.linspace(-1, 1, 11)
L, R, N = get_tcs_profiles(test_lam, config)
print(f"{'lambda':>8} | {'left':>6} | {'right':>6} | {'neck':>6}")
print("-" * 40)
# One table row per sample point: lambda, left, right and neck profile values.
for i, l in enumerate(test_lam):
    print(f"{l.item():8.2f} | {L[i].item():6.3f} | {R[i].item():6.3f} | {N[i].item():6.3f}")

In [None]:
# @title Load Data
# @markdown Loads calibrated K7 metric samples from v1.8/v1.9 or generates synthetic data.
# @markdown For TCS, the first coordinate x[0] is interpreted as the neck coordinate lambda.

def load_data(path: Optional[str] = None) -> Dict[str, torch.Tensor]:
    """Load K7 metric samples from the first existing ``.npz`` candidate.

    Candidates are tried in order: the explicit ``path`` (if given), then a
    list of local and Colab/Drive locations. If none exists, synthetic data
    is generated instead.

    Args:
        path: Optional explicit path to a ``samples.npz`` file.

    Returns:
        Dict of float32 tensors with keys ``'coords'`` and ``'metric'``, and
        ``'phi'`` when available. When loading from a file, ``'phi'`` is only
        present if the archive contains it; a 3-D ``phi`` array is reduced by
        a mean over its last axis. The synthetic branch always provides
        ``'phi'`` and reads the module-level ``config.target_det_g``.
    """
    paths_to_try = [
        path,
        "samples.npz",
        "../1_9b/outputs_v1_9b/samples.npz",
        "../1_8/samples.npz",
        "/content/samples.npz",
        "/content/drive/MyDrive/GIFT/G2_ML/1_8/samples.npz",
        "/content/drive/MyDrive/GIFT/G2_ML/1_9b/outputs_v1_9b/samples.npz",
    ]

    for p in paths_to_try:
        if not p or not os.path.exists(p):
            continue
        print(f"Loading from: {p}")
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it once the arrays have been extracted.
        with np.load(p) as data:
            result = {
                'coords': torch.from_numpy(data['coords']).float(),
                'metric': torch.from_numpy(data['metric']).float(),
            }
            # phi may be in different formats
            if 'phi' in data:
                phi_tensor = torch.from_numpy(data['phi']).float()
                if phi_tensor.ndim == 3:
                    # Collapse the trailing axis by averaging.
                    phi_tensor = phi_tensor.mean(dim=-1)
                result['phi'] = phi_tensor
        return result

    # Generate synthetic data with TCS-aware coordinates
    print("Data not found, generating synthetic with TCS structure...")
    n = 8000

    # First coordinate is lambda (neck coordinate) in [-1, 1]
    # Others are angular coordinates in [0, 2pi]
    coords = torch.zeros(n, 7)
    coords[:, 0] = torch.rand(n) * 2 - 1  # lambda in [-1, 1]
    coords[:, 1:] = torch.rand(n, 6) * 2 * np.pi  # xi in [0, 2pi]

    # Metric with det(g) ~ 65/32: scaled identity plus small symmetric noise
    scale = config.target_det_g ** (1/7)
    metric = torch.eye(7).unsqueeze(0).expand(n, -1, -1).clone() * scale
    metric = metric + 0.02 * torch.randn(n, 7, 7)
    metric = 0.5 * (metric + metric.transpose(-1, -2))  # Symmetrize

    # Phi: G2 3-form components (35 basis elements)
    phi = torch.randn(n, 35) * 0.3
    # Normalize to ||phi||^2 ~ 7
    phi = phi * np.sqrt(7.0) / (torch.norm(phi, dim=1, keepdim=True) + 1e-8)

    return {'coords': coords, 'metric': metric, 'phi': phi}

# Load the samples (file if found, synthetic otherwise) into module-level
# `data`, which the training cells consume.
data = load_data()
n_samples = data['coords'].shape[0]

# Determinant statistics over all samples, for comparison with the target.
det_g = torch.det(data['metric'])
det_g_mean = det_g.mean().item()
det_g_std = det_g.std().item()

# Range of the first coordinate, which this notebook treats as the neck
# coordinate lambda.
lam = data['coords'][:, 0]
lam_min, lam_max = lam.min().item(), lam.max().item()

# NOTE(review): data['phi'] is only present when the loaded file provides it;
# a file without 'phi' would raise KeyError on the last print below.
print(f"\nData loaded:")
print(f"  Samples: {n_samples}")
print(f"  det(g): {det_g_mean:.6f} +/- {det_g_std:.4f} (target: {config.target_det_g})")
print(f"  Lambda range: [{lam_min:.2f}, {lam_max:.2f}]")
print(f"  Phi shape: {data['phi'].shape}")

In [None]:
# @title H2 Network (21 harmonic 2-forms)
# @markdown Standard architecture for 2-forms, unchanged from v1.9b.

class H2Network(nn.Module):
    """
    MLP producing the 21 candidate harmonic 2-forms on K7.

    A shared Linear+SiLU trunk feeds ``config.b2_K7`` independent linear
    heads; each head emits the 21 = C(7,2) components of one 2-form.
    Output shape: (batch, 21, 21).
    """

    def __init__(self, config: Config):
        super().__init__()

        # Trunk: n_layers blocks of Linear -> SiLU, input width config.dim.
        widths = [config.dim] + [config.hidden_dim] * config.n_layers
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Linear(fan_in, fan_out))
            trunk.append(nn.SiLU())
        self.features = nn.Sequential(*trunk)

        # One linear head per 2-form, each with 21 output components.
        head_list = []
        for _ in range(config.b2_K7):
            head_list.append(nn.Linear(config.hidden_dim, 21))
        self.heads = nn.ModuleList(head_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 7) coordinates
        Returns:
            omega: (batch, 21, 21) - axis 1 indexes the form, axis 2 its components
        """
        shared = self.features(x)
        per_form = []
        for head in self.heads:
            per_form.append(head(shared))
        return torch.stack(per_form, dim=1)

# Smoke test: build a throwaway H2Network, run a random batch of 4 points
# through it, and report the parameter count and output shape.
_test_h2 = H2Network(config)
_test_out = _test_h2(torch.randn(4, 7))
print(f"H2Network defined.")
print(f"  Parameters: {sum(p.numel() for p in _test_h2.parameters()):,}")
print(f"  Output shape: {_test_out.shape} (batch, 21 forms, 21 components)")
# Drop the temporaries so they do not linger in the notebook namespace.
del _test_h2, _test_out

In [None]:
# @title H3 TCS Network (77 harmonic 3-forms with TCS structure)
# @markdown **KEY v2.0 CHANGE**: Global modes use proper TCS profiles.
# @markdown - 35 local modes: fiber modes, independent of neck coordinate
# @markdown - 14 left-weighted: multiplied by left_plateau(lambda)
# @markdown - 14 right-weighted: multiplied by right_plateau(lambda)
# @markdown - 14 neck-coupled: multiplied by neck_bump(lambda)

class H3TCSNetwork(nn.Module):
    """
    MLP producing 77 candidate harmonic 3-forms with a TCS mode layout.

    A shared Linear+SiLU trunk feeds four groups of linear heads, each head
    emitting 35 = C(7,3) components:
    - ``b3_local`` local heads, used as-is (no profile factor);
    - ``n_left`` heads multiplied by left_plateau(λ);
    - ``n_right`` heads multiplied by right_plateau(λ);
    - ``n_neck`` heads multiplied by neck_bump(λ).

    λ is the first input coordinate. All heads, including the local ones,
    read trunk features computed from the full input ``x``.

    Output: (batch, 77, 35) with the default config, ordered
    [local, left, right, neck] along axis 1.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        # Trunk: n_layers blocks of Linear -> SiLU, input width config.dim.
        widths = [config.dim] + [config.hidden_dim] * config.n_layers
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Linear(fan_in, fan_out))
            trunk.append(nn.SiLU())
        self.features = nn.Sequential(*trunk)

        def make_heads(count: int) -> nn.ModuleList:
            # `count` independent linear heads, 35 components each.
            return nn.ModuleList(
                [nn.Linear(config.hidden_dim, 35) for _ in range(count)]
            )

        # Head groups are created in the order local, left, right, neck.
        self.local_heads = make_heads(config.b3_local)
        self.left_heads = make_heads(config.n_left)
        self.right_heads = make_heads(config.n_right)
        self.neck_heads = make_heads(config.n_neck)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 7) coordinates, x[:, 0] is the neck coordinate lambda
        Returns:
            Phi: (batch, 77, 35) - 77 3-forms with 35 components each
        """
        neck_coord = x[:, 0]

        # Per-sample scalar profiles, one value per batch element.
        left_weight = left_plateau(neck_coord, self.config)
        right_weight = right_plateau(neck_coord, self.config)
        neck_weight = neck_bump(neck_coord, self.config)

        shared = self.features(x)

        def run_heads(heads: nn.ModuleList) -> torch.Tensor:
            # Stack head outputs along a new "mode" axis: (batch, n_heads, 35).
            return torch.stack([head(shared) for head in heads], dim=1)

        def broadcastable(profile: torch.Tensor) -> torch.Tensor:
            # Append two singleton axes so the profile scales every mode and
            # component of its sample.
            return profile[..., None, None]

        groups = [
            run_heads(self.local_heads),
            run_heads(self.left_heads) * broadcastable(left_weight),
            run_heads(self.right_heads) * broadcastable(right_weight),
            run_heads(self.neck_heads) * broadcastable(neck_weight),
        ]

        # Mode order along axis 1: local, left, right, neck.
        return torch.cat(groups, dim=1)

# Smoke test: build a throwaway H3TCSNetwork and evaluate it on 4 random
# points whose first coordinate (lambda) is pinned to known values.
_test_h3 = H3TCSNetwork(config)
_test_x = torch.randn(4, 7)
_test_x[:, 0] = torch.tensor([-0.8, -0.2, 0.2, 0.8])  # Various lambda values
_test_out = _test_h3(_test_x)

print(f"H3TCSNetwork defined.")
print(f"  Parameters: {sum(p.numel() for p in _test_h3.parameters()):,}")
print(f"  Output shape: {_test_out.shape} (batch, 77 forms, 35 components)")
# Index layout of the 77 modes along axis 1 (matches the concatenation order
# in H3TCSNetwork.forward for the default config).
print(f"\nTCS mode structure:")
print(f"  Modes 0-34: Local fiber modes")
print(f"  Modes 35-48: Left-weighted (left_plateau)")
print(f"  Modes 49-62: Right-weighted (right_plateau)")
print(f"  Modes 63-76: Neck-coupled (neck_bump)")

# Check profile effect: mean per-mode norm of each group at each lambda.
# The lambda list below must match the values written into _test_x above.
print(f"\nProfile effect at different lambda:")
for i, l in enumerate([-0.8, -0.2, 0.2, 0.8]):
    norms = _test_out[i].norm(dim=1)
    print(f"  lambda={l:+.1f}: local={norms[:35].mean():.3f}, left={norms[35:49].mean():.3f}, right={norms[49:63].mean():.3f}, neck={norms[63:].mean():.3f}")

# Drop the temporaries so they do not linger in the notebook namespace.
del _test_h3, _test_x, _test_out

In [None]:
# @title Loss Functions
# @markdown Losses enforce: harmonicity (d=0, d*=0), orthonormality, G2 compatibility.

def gram_matrix(forms: torch.Tensor, metric: torch.Tensor) -> torch.Tensor:
    """
    Batch-averaged, volume-weighted Gram matrix of a set of forms.

    Each sample contributes sqrt(|det(metric)|) times the Euclidean dot
    product of the component vectors of form i and form j; contributions are
    summed over the batch and divided by the batch size. The metric enters
    only through this volume factor.

    Args:
        forms: (batch, n_forms, n_components)
        metric: (batch, 7, 7)
    Returns:
        G: (n_forms, n_forms) averaged Gram matrix
    """
    n_batch = forms.shape[0]
    volume = torch.det(metric).abs().sqrt()      # (batch,)
    volume = volume.unsqueeze(-1).unsqueeze(-1)  # (batch, 1, 1) for broadcasting
    volume_weighted = forms * volume
    summed = torch.einsum('bic,bjc->ij', volume_weighted, forms)
    return summed / n_batch

def orthonormality_loss(G: torch.Tensor) -> torch.Tensor:
    """Mean squared deviation of the Gram matrix ``G`` from the identity."""
    identity = torch.eye(G.shape[0], device=G.device)
    deviation = G - identity
    return (deviation ** 2).mean()

def closedness_loss_fd(x: torch.Tensor, model: nn.Module, eps: float = 1e-4) -> torch.Tensor:
    """
    Finite-difference surrogate for the closedness condition d(omega) = 0.

    For each of the first 7 input coordinates, the model output is
    differentiated with a central difference of step ``eps`` and the mean
    squared derivative is accumulated. This penalizes every partial
    derivative of every form component, not the antisymmetrized exterior
    derivative itself.

    Args:
        x: (batch, 7) coordinates.
        model: callable mapping coordinates to form components.
        eps: central-difference step size.
    Returns:
        Scalar tensor: mean squared derivative averaged over the 7 coordinates.
    """
    # Only the perturbed evaluations are needed, so no forward pass at x
    # itself is made: the model runs exactly 14 times per call.
    total = torch.tensor(0.0, device=x.device)

    for c in range(7):
        x_plus = x.clone()
        x_minus = x.clone()
        x_plus[:, c] += eps
        x_minus[:, c] -= eps

        grad = (model(x_plus) - model(x_minus)) / (2 * eps)
        total = total + torch.mean(grad ** 2)

    return total / 7

def g2_compatibility_loss(Phi: torch.Tensor, phi_ref: torch.Tensor) -> torch.Tensor:
    """
    Mean squared error between the "diagonal" of the local modes and the
    reference G2 3-form.

    For the first 35 modes, component k of mode k is compared with
    ``phi_ref[:, k]``.

    Args:
        Phi: (batch, 77, 35) - 77 3-forms
        phi_ref: (batch, 35) - reference G2 form components
    """
    local_block = Phi[:, :35, :]                                   # (batch, 35, 35)
    diagonal_components = torch.diagonal(local_block, dim1=1, dim2=2)  # (batch, 35)
    residual = diagonal_components - phi_ref
    return (residual ** 2).mean()

def tcs_profile_regularization(Phi: torch.Tensor, lam: torch.Tensor, config: Config) -> torch.Tensor:
    """
    NEW v2.0: Regularize global modes to follow their expected profiles.

    For each of the left, right and neck groups, the per-sample mean mode
    norm (divided by a common scale) is compared with the matching profile
    evaluated at ``lam``. The returned loss is the average of the three mean
    squared differences. This encourages:
    - Left modes to be large when lambda < lambda_neck
    - Right modes to be large when lambda > lambda_neck
    - Neck modes to be large when lambda ~ lambda_neck

    The group boundaries are derived from ``config`` (b3_local, n_left,
    n_right, n_neck), in the same order the H3 network concatenates its
    modes; with the default config they are 35:49, 49:63 and 63:77.

    Args:
        Phi: (batch, n_modes, n_components) stacked 3-forms.
        lam: (batch,) neck coordinate of each sample.
        config: provides the profile parameters and the mode counts.
    Returns:
        Scalar loss tensor.
    """
    # Get expected profiles
    prof_left = left_plateau(lam, config)
    prof_right = right_plateau(lam, config)
    prof_neck = neck_bump(lam, config)

    # Slice boundaries of the three global groups along the mode axis.
    left_start = config.b3_local
    right_start = left_start + config.n_left
    neck_start = right_start + config.n_right
    neck_end = neck_start + config.n_neck

    # Per-sample mean norm of each group: (batch,)
    left_norms = Phi[:, left_start:right_start, :].norm(dim=2).mean(dim=1)
    right_norms = Phi[:, right_start:neck_start, :].norm(dim=2).mean(dim=1)
    neck_norms = Phi[:, neck_start:neck_end, :].norm(dim=2).mean(dim=1)

    # Common scale so norms are comparable with the [0, 1] profiles. It is
    # not detached, so gradients also flow through it.
    scale = (left_norms.mean() + right_norms.mean() + neck_norms.mean()) / 3 + 1e-8

    # Loss: mode norms should follow their profiles
    loss_left = torch.mean((left_norms / scale - prof_left) ** 2)
    loss_right = torch.mean((right_norms / scale - prof_right) ** 2)
    loss_neck = torch.mean((neck_norms / scale - prof_neck) ** 2)

    return (loss_left + loss_right + loss_neck) / 3

# Cell summary: list the loss helpers defined above.
# NOTE(review): gram_matrix applies only a sqrt|det g| volume weight to a
# Euclidean component dot product; "Hodge inner product" is a loose label.
print("Loss functions defined:")
print("  - gram_matrix: Hodge inner product")
print("  - orthonormality_loss: G = I")
print("  - closedness_loss_fd: d(omega) = 0")
print("  - g2_compatibility_loss: match G2 form")
print("  - tcs_profile_regularization: TCS structure (NEW)")

In [None]:
# @title Checkpointing and Auto-Resume
# @markdown Saves checkpoints regularly. Training resumes automatically if interrupted.

def save_checkpoint(path: str, epoch: int, model: nn.Module, optimizer: optim.Optimizer,
                   scheduler, losses: Dict, best_loss: float, phase: str):
    """Save a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint stores the epoch, the model/optimizer/scheduler state
    dicts (scheduler entry is None when no scheduler is given), the loss
    history, the best loss, the phase tag, and a dict snapshot of the
    module-level ``config``.

    Args:
        path: Destination file. Its parent directory is created if missing.
        epoch: Zero-based index of the last completed epoch.
        model: Network whose weights are saved.
        optimizer: Optimizer whose state is saved.
        scheduler: LR scheduler, or a falsy value to skip it.
        losses: Loss history to persist.
        best_loss: Best total loss seen so far.
        phase: Phase tag stored in the checkpoint (e.g. 'h2' or 'h3').
    """
    # dirname is '' for a bare filename, and os.makedirs('') raises, so only
    # create the parent when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'losses': losses,
        'best_loss': best_loss,
        'phase': phase,
        'config': asdict(config),
    }, path)

def load_checkpoint(path: str, model: nn.Module, optimizer: optim.Optimizer,
                   scheduler=None) -> Dict:
    """Restore model, optimizer and (optionally) scheduler state from ``path``.

    The checkpoint is loaded onto the module-level ``device``. The scheduler
    is restored only when one is passed and the checkpoint holds a scheduler
    state. Returns the full checkpoint dict so callers can read the epoch,
    best loss and loss history.
    """
    state = torch.load(path, map_location=device)

    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])

    if scheduler:
        scheduler_state = state.get('scheduler_state_dict')
        if scheduler_state:
            scheduler.load_state_dict(scheduler_state)

    print(f"  Resumed from epoch {state['epoch']} (best loss: {state['best_loss']:.2e})")
    return state

def find_latest_checkpoint(checkpoint_dir: str, phase: str) -> Optional[str]:
    """Return the path of the highest-epoch checkpoint for ``phase``.

    Looks for files named ``{phase}_epoch_<N>.pt`` in ``checkpoint_dir`` and
    picks the one with the largest integer ``N``. Comparing the parsed epoch
    (rather than the file name) keeps the result correct when the epoch
    number has more digits than the zero padding used when saving. Files
    whose suffix is not an integer rank below every numbered checkpoint.

    Returns:
        The checkpoint path as a string, or None if the directory does not
        exist or holds no matching file.
    """
    if not os.path.exists(checkpoint_dir):
        return None

    prefix = f"{phase}_epoch_"

    def epoch_key(ckpt: Path) -> Tuple[int, str]:
        # Parse the epoch number after the prefix; the name is a tie-breaker.
        try:
            return int(ckpt.stem[len(prefix):]), ckpt.name
        except ValueError:
            return -1, ckpt.name

    candidates = list(Path(checkpoint_dir).glob(f"{prefix}*.pt"))
    if not candidates:
        return None
    return str(max(candidates, key=epoch_key))

def format_time(seconds: float) -> str:
    """Render a duration in seconds as zero-padded ``HH:MM:SS``."""
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{int(hours):02d}:{int(minutes):02d}:{int(secs):02d}"

# Ensure the checkpoint directory exists up front and echo the checkpoint
# and logging cadence taken from the module-level config.
os.makedirs(config.checkpoint_dir, exist_ok=True)
print(f"Checkpointing configured:")
print(f"  Directory: {config.checkpoint_dir}")
print(f"  Save every: {config.checkpoint_every} epochs")
print(f"  Log every: {config.log_every} epochs")

In [None]:
# @title Phase 1: Train H2 (21 harmonic 2-forms)
# @markdown Text-only monitoring. Auto-resumes from checkpoint if available.

def train_h2(config: Config, data: Dict[str, torch.Tensor], resume: bool = True) -> Tuple[nn.Module, Dict]:
    """Train the H2 network with periodic checkpoints and auto-resume.

    Each epoch draws one random mini-batch, minimizes the weighted sum of the
    orthonormality and finite-difference closedness losses, and tracks the
    weights that achieved the lowest total loss. Those weights are restored
    before the final checkpoint is written.

    Args:
        config: Hyper-parameters, loss weights and checkpoint settings.
        data: Dict with 'coords' and 'metric' tensors.
        resume: If True, continue from the latest 'h2' epoch checkpoint found
            in ``config.checkpoint_dir``.

    Returns:
        ``(model, all_losses)`` where ``all_losses`` maps 'total', 'ortho'
        and 'closed' to per-epoch lists of floats.
    """

    print("=" * 70)
    print("PHASE 1: Training H2 (21 harmonic 2-forms)")
    print("=" * 70)

    # Initialize
    model = H2Network(config).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config.lr_h2, weight_decay=config.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, patience=config.scheduler_patience,
                                   factor=config.scheduler_factor, min_lr=1e-6)

    def snapshot_weights() -> Dict[str, torch.Tensor]:
        # state_dict() hands back references to the live parameter tensors and
        # dict.copy() is shallow, so an un-cloned "best" snapshot would keep
        # changing with every optimizer step. Clone each tensor instead.
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    coords = data['coords'].to(device)
    metric = data['metric'].to(device)
    n = coords.shape[0]

    start_epoch = 0
    best_loss = float('inf')
    all_losses = {'total': [], 'ortho': [], 'closed': []}
    best_state = snapshot_weights()

    # Try to resume
    if resume:
        latest = find_latest_checkpoint(config.checkpoint_dir, 'h2')
        if latest:
            print(f"Found checkpoint: {latest}")
            ckpt = load_checkpoint(latest, model, optimizer, scheduler)
            start_epoch = ckpt['epoch'] + 1
            best_loss = ckpt['best_loss']
            all_losses = ckpt.get('losses', all_losses)
            # Checkpoints hold the current weights, not the best ones, so the
            # resumed weights become the starting "best" snapshot.
            best_state = snapshot_weights()

    if start_epoch == 0:
        print("Starting fresh training...")

    # Header
    print(f"\n{'Epoch':>6} | {'Loss':>10} | {'Ortho':>10} | {'Closed':>10} | {'Best':>10} | {'LR':>8} | {'Time':>8}")
    print("-" * 80)

    start_time = time.time()

    for epoch in range(start_epoch, config.n_epochs_h2):
        model.train()

        # Sample batch
        idx = torch.randperm(n)[:config.batch_size]
        x, g = coords[idx], metric[idx]

        # Forward
        omega = model(x)

        # Losses
        G = gram_matrix(omega, g)
        loss_ortho = orthonormality_loss(G)
        loss_closed = closedness_loss_fd(x, model)
        total = config.w_orthonormal * loss_ortho + config.w_closed * loss_closed
        total_value = total.item()

        # Track the best weights before stepping: `total` was evaluated with
        # the pre-step weights, so those are the ones that achieved it.
        if total_value < best_loss:
            best_loss = total_value
            best_state = snapshot_weights()

        # Backward
        optimizer.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
        scheduler.step(total_value)

        # Track
        all_losses['total'].append(total_value)
        all_losses['ortho'].append(loss_ortho.item())
        all_losses['closed'].append(loss_closed.item())

        # Log
        if (epoch + 1) % config.log_every == 0:
            elapsed = time.time() - start_time
            lr = optimizer.param_groups[0]['lr']
            print(f"{epoch+1:6d} | {total_value:10.2e} | {loss_ortho.item():10.2e} | "
                  f"{loss_closed.item():10.2e} | {best_loss:10.2e} | {lr:8.1e} | {format_time(elapsed)}")

        # Checkpoint
        if (epoch + 1) % config.checkpoint_every == 0:
            ckpt_path = f"{config.checkpoint_dir}/h2_epoch_{epoch+1:05d}.pt"
            save_checkpoint(ckpt_path, epoch, model, optimizer, scheduler, all_losses, best_loss, 'h2')

    # Restore best
    model.load_state_dict(best_state)

    # Final checkpoint
    final_path = f"{config.checkpoint_dir}/h2_final.pt"
    save_checkpoint(final_path, config.n_epochs_h2 - 1, model, optimizer, scheduler, all_losses, best_loss, 'h2')

    print(f"\nH2 training complete.")
    print(f"  Best loss: {best_loss:.2e}")
    print(f"  Saved: {final_path}")

    return model, all_losses

# Run phase 1 with the module-level config and data; `resume` is left at its
# default (True), so an existing 'h2' checkpoint is picked up automatically.
h2_model, h2_losses = train_h2(config, data)

In [None]:
# @title Phase 2: Train H3 with TCS structure (77 harmonic 3-forms)
# @markdown **KEY v2.0**: Includes TCS profile regularization for global modes.

def train_h3_tcs(config: Config, data: Dict[str, torch.Tensor], resume: bool = True) -> Tuple[nn.Module, Dict]:
    """Train the H3 TCS network with periodic checkpoints and auto-resume.

    Each epoch draws one random mini-batch and minimizes the weighted sum of
    the orthonormality, finite-difference closedness, G2-compatibility and
    TCS-profile losses, tracking the weights that achieved the lowest total
    loss. Those weights are restored before the final checkpoint is written.

    Args:
        config: Hyper-parameters, loss weights and checkpoint settings.
        data: Dict with 'coords', 'metric' and 'phi' tensors.
        resume: If True, continue from the latest 'h3' epoch checkpoint found
            in ``config.checkpoint_dir``.

    Returns:
        ``(model, all_losses)`` where ``all_losses`` maps 'total', 'ortho',
        'closed', 'g2' and 'tcs' to per-epoch lists of floats.
    """

    print("=" * 70)
    print("PHASE 2: Training H3 with TCS Structure (77 harmonic 3-forms)")
    print("=" * 70)
    print("  35 local modes + 42 global TCS modes (14 left + 14 right + 14 neck)")

    # Initialize
    model = H3TCSNetwork(config).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config.lr_h3, weight_decay=config.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, patience=config.scheduler_patience,
                                   factor=config.scheduler_factor, min_lr=1e-6)

    def snapshot_weights() -> Dict[str, torch.Tensor]:
        # state_dict() hands back references to the live parameter tensors and
        # dict.copy() is shallow, so an un-cloned "best" snapshot would keep
        # changing with every optimizer step. Clone each tensor instead.
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    coords = data['coords'].to(device)
    metric = data['metric'].to(device)
    phi = data['phi'].to(device)
    n = coords.shape[0]

    start_epoch = 0
    best_loss = float('inf')
    all_losses = {'total': [], 'ortho': [], 'closed': [], 'g2': [], 'tcs': []}
    best_state = snapshot_weights()

    # Try to resume
    if resume:
        latest = find_latest_checkpoint(config.checkpoint_dir, 'h3')
        if latest:
            print(f"Found checkpoint: {latest}")
            ckpt = load_checkpoint(latest, model, optimizer, scheduler)
            start_epoch = ckpt['epoch'] + 1
            best_loss = ckpt['best_loss']
            all_losses = ckpt.get('losses', all_losses)
            # Checkpoints hold the current weights, not the best ones, so the
            # resumed weights become the starting "best" snapshot.
            best_state = snapshot_weights()

    if start_epoch == 0:
        print("Starting fresh training...")

    # Header
    print(f"\n{'Epoch':>6} | {'Loss':>9} | {'Ortho':>9} | {'Closed':>9} | {'G2':>9} | {'TCS':>9} | {'Best':>9} | {'Time':>8}")
    print("-" * 95)

    start_time = time.time()

    for epoch in range(start_epoch, config.n_epochs_h3):
        model.train()

        # Sample batch
        idx = torch.randperm(n)[:config.batch_size]
        x, g, p = coords[idx], metric[idx], phi[idx]
        lam = x[:, 0]  # Neck coordinate

        # Forward
        Phi = model(x)  # (batch, 77, 35)

        # Losses
        G = gram_matrix(Phi, g)
        loss_ortho = orthonormality_loss(G)
        loss_closed = closedness_loss_fd(x, model)
        loss_g2 = g2_compatibility_loss(Phi, p)
        loss_tcs = tcs_profile_regularization(Phi, lam, config)

        total = (config.w_orthonormal * loss_ortho +
                config.w_closed * loss_closed +
                config.w_g2_compat * loss_g2 +
                config.w_tcs_profile * loss_tcs)
        total_value = total.item()

        # Track the best weights before stepping: `total` was evaluated with
        # the pre-step weights, so those are the ones that achieved it.
        if total_value < best_loss:
            best_loss = total_value
            best_state = snapshot_weights()

        # Backward
        optimizer.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
        scheduler.step(total_value)

        # Track
        all_losses['total'].append(total_value)
        all_losses['ortho'].append(loss_ortho.item())
        all_losses['closed'].append(loss_closed.item())
        all_losses['g2'].append(loss_g2.item())
        all_losses['tcs'].append(loss_tcs.item())

        # Log
        if (epoch + 1) % config.log_every == 0:
            elapsed = time.time() - start_time
            print(f"{epoch+1:6d} | {total_value:9.2e} | {loss_ortho.item():9.2e} | "
                  f"{loss_closed.item():9.2e} | {loss_g2.item():9.2e} | {loss_tcs.item():9.2e} | "
                  f"{best_loss:9.2e} | {format_time(elapsed)}")

        # Checkpoint
        if (epoch + 1) % config.checkpoint_every == 0:
            ckpt_path = f"{config.checkpoint_dir}/h3_epoch_{epoch+1:05d}.pt"
            save_checkpoint(ckpt_path, epoch, model, optimizer, scheduler, all_losses, best_loss, 'h3')

    # Restore best
    model.load_state_dict(best_state)

    # Final checkpoint
    final_path = f"{config.checkpoint_dir}/h3_final.pt"
    save_checkpoint(final_path, config.n_epochs_h3 - 1, model, optimizer, scheduler, all_losses, best_loss, 'h3')

    print(f"\nH3 TCS training complete.")
    print(f"  Best loss: {best_loss:.2e}")
    print(f"  Saved: {final_path}")

    return model, all_losses

# Run phase 2 with the module-level config and data; `resume` is left at its
# default (True), so an existing 'h3' checkpoint is picked up automatically.
h3_model, h3_losses = train_h3_tcs(config, data)

In [None]:
# @title Phase 3: Compute Yukawa Tensor
# @markdown Y_ijk = integral(omega_i wedge omega_j wedge Phi_k) using proper Levi-Civita.

def levi_civita_7(indices: Tuple[int, ...]) -> int:
    """Return the sign of a 7-index arrangement (Levi-Civita symbol).

    Args:
        indices: Sequence of 7 indices. Only their relative order is used, so
            callers are expected to pass values from 0..6.

    Returns:
        0 if ``indices`` does not have exactly 7 entries or contains a
        repeat; otherwise +1 for an even number of inversions and -1 for an
        odd number.
    """
    # Both checks are needed: a longer tuple with one repeat also has 7
    # distinct values but is not a permutation.
    if len(indices) != 7 or len(set(indices)) != 7:
        return 0
    inversions = sum(
        1 for i, j in combinations(range(7), 2) if indices[i] > indices[j]
    )
    return -1 if inversions % 2 else 1

def build_yukawa_coefficients() -> List[Tuple[int, int, int, int]]:
    """Enumerate the non-zero sign coefficients of the 2-form ∧ 2-form ∧ 3-form wedge.

    Basis 2-forms are indexed by the 21 index pairs and basis 3-forms by the
    35 index triples of range(7), both in ``itertools.combinations`` order.
    An entry ``(i1, i2, i3, sign)`` is produced when pair i1, pair i2 and
    triple i3 together use all 7 indices; ``sign`` is the Levi-Civita sign of
    the concatenated index tuple.

    NOTE(review): only entries with i1 <= i2 are produced. Swapping the two
    pairs is an even permutation, so the (i2, i1) entry would carry the same
    sign; a consumer that evaluates omega_a[i1] * omega_b[i2] for two
    different forms therefore sees only half of the cross terms. Confirm this
    is intended.
    """
    pair_basis = list(combinations(range(7), 2))    # 21 pairs for 2-forms
    triple_basis = list(combinations(range(7), 3))  # 35 triples for 3-forms

    table = []
    for first_pos, first_pair in enumerate(pair_basis):
        # Start at first_pos: entries with i2 < i1 are skipped.
        for second_pos in range(first_pos, len(pair_basis)):
            pair_indices = first_pair + pair_basis[second_pos]
            for triple_pos, triple in enumerate(triple_basis):
                # levi_civita_7 returns 0 when any index repeats, which
                # filters out combinations that are not a valid wedge.
                sign = levi_civita_7(pair_indices + triple)
                if sign:
                    table.append((first_pos, second_pos, triple_pos, sign))

    return table

# Build the wedge coefficient table once at module level so compute_yukawa
# can reuse it instead of re-enumerating index combinations.
YUKAWA_COEFFS = build_yukawa_coefficients()
print(f"Built {len(YUKAWA_COEFFS)} Yukawa wedge coefficients")

def compute_yukawa(h2_model: nn.Module, h3_model: nn.Module, 
                   data: Dict[str, torch.Tensor], 
                   coeffs: List[Tuple[int, int, int, int]],
                   n_pts: int = 5000) -> Dict[str, np.ndarray]:
    """Compute the Yukawa tensor Y_abc and its Gram matrix M = Y^T Y.

    Y[a, b, c] is the volume-weighted average over sampled points of
    ``sum_k sign_k * omega_a[i1_k] * omega_b[i2_k] * Phi_c[i3_k]`` where
    ``(i1_k, i2_k, i3_k, sign_k)`` runs over ``coeffs``.  The point weight is
    ``sqrt(|det g|)`` normalised by its sum over the sample.

    The tensor is evaluated for a <= b and mirrored with a PLUS sign:
    2-forms are even-degree, so omega_a ^ omega_b = omega_b ^ omega_a and
    Y is symmetric in its first two indices.  (M is quadratic in Y and does
    not depend on this sign.)  The mirrored entries equal the true b, a
    integral only when ``coeffs`` lists both orderings of every pair of
    2-form components; with a one-sided list the a < b entries contain that
    ordering only.

    Args:
        h2_model: Module mapping coordinates to 2-form modes; expected
            output shape (n, n_h2, n_pair_components), e.g. (n, 21, 21).
        h3_model: Module mapping coordinates to 3-form modes; expected
            output shape (n, n_h3, n_triple_components), e.g. (n, 77, 35).
        data: Dict with at least 'coords' and 'metric' tensors, aligned
            along the first (sample) axis.
        coeffs: ``(pair_idx_1, pair_idx_2, triple_idx, sign)`` wedge table.
        n_pts: Maximum number of randomly chosen integration points.

    Returns:
        Dict of NumPy arrays: 'Y' (n_h2, n_h2, n_h3), 'M' (n_h3, n_h3),
        'eigenvalues' (descending), 'eigenvectors' (columns, same order),
        and the evaluated 'omega' and 'Phi'.
    """
    
    print("=" * 70)
    print("PHASE 3: Computing Yukawa Tensor (proper wedge product)")
    print("=" * 70)
    
    h2_model.eval()
    h3_model.eval()
    
    coords = data['coords'].to(device)
    metric = data['metric'].to(device)
    n = min(n_pts, coords.shape[0])
    # Draw from the CPU generator (same random stream as a plain randperm),
    # then move the indices next to the tensors they index.
    idx = torch.randperm(coords.shape[0])[:n].to(coords.device)
    x, g = coords[idx], metric[idx]
    
    with torch.no_grad():
        # Volume element
        vol = torch.sqrt(torch.det(g).abs())
        total_vol = vol.sum()
        
        omega = h2_model(x)  # (n, 21, 21)
        Phi = h3_model(x)    # (n, 77, 35)
    
    print(f"Integration points: {n}")
    print(f"omega shape: {omega.shape}")
    print(f"Phi shape: {Phi.shape}")
    print("Computing Y_ijk...")
    
    # Mode/component counts come from the model outputs rather than being
    # hard-coded, so differently sized bases work too.
    n_h2, n_pair = omega.shape[1], omega.shape[2]
    n_h3, n_triple = Phi.shape[1], Phi.shape[2]
    work_dtype = torch.promote_types(omega.dtype, Phi.dtype)
    
    # Dense sign table: wedge[i1, i2, i3] = sum of signs listed for that
    # component combination (built on CPU, then moved once).
    wedge = torch.zeros(n_pair, n_pair, n_triple, dtype=work_dtype)
    for i1, i2, i3, sign in coeffs:
        wedge[i1, i2, i3] += sign
    wedge = wedge.to(omega.device)
    
    with torch.no_grad():
        om = omega.to(work_dtype)
        weights = (vol / total_vol).to(work_dtype)
        # Fold the normalised volume weight into Phi once.
        phi_w = Phi.to(work_dtype) * weights[:, None, None]
        
        Y = torch.zeros(n_h2, n_h2, n_h3, dtype=work_dtype, device=omega.device)
        
        for a in range(n_h2):
            if (a + 1) % 7 == 0:
                print(f"  H2 mode {a+1}/{n_h2}")
            
            # partial[n, j, k] = sum_i omega_a[n, i] * wedge[i, j, k]
            partial = torch.einsum('ni,ijk->njk', om[:, a, :], wedge)
            # density[n, b, k] = sum_j omega_b[n, j] * partial[n, j, k], b >= a
            density = torch.einsum('nbj,njk->nbk', om[:, a:, :], partial)
            # rows[b, c] = sum_{n, k} density[n, b, k] * w_n * Phi_c[n, k]
            rows = torch.einsum('nbk,nck->bc', density, phi_w)
            
            Y[a, a:, :] = rows
            Y[a:, a, :] = rows  # Symmetric in (a, b): 2-forms commute
    
    print("Computing Gram matrix M = Y^T Y...")
    # M_kl = sum_ij Y_ijk * Y_ijl
    M = torch.einsum('ijk,ijl->kl', Y, Y)
    
    print("Eigendecomposition...")
    eigenvalues, eigenvectors = torch.linalg.eigh(M)
    # eigh returns ascending order; report largest first.
    idx_sort = torch.argsort(eigenvalues, descending=True)
    eigenvalues = eigenvalues[idx_sort]
    eigenvectors = eigenvectors[:, idx_sort]
    
    print("Done.")
    
    return {
        'Y': Y.cpu().numpy(),
        'M': M.cpu().numpy(),
        'eigenvalues': eigenvalues.cpu().numpy(),
        'eigenvectors': eigenvectors.cpu().numpy(),
        'omega': omega.cpu().numpy(),
        'Phi': Phi.cpu().numpy(),
    }

# Compute Yukawa with the trained H2/H3 models and the prebuilt wedge table.
# The returned dict (Y, M, eigenvalues, eigenvectors, omega, Phi) is read by
# the spectral-analysis and output-saving cells below.
yukawa = compute_yukawa(h2_model, h3_model, data, YUKAWA_COEFFS)

In [None]:
# @title Spectral Analysis
# @markdown Analyze Yukawa spectrum for 43/77 split and tau = 3472/891.

def analyze_spectrum(eigs: np.ndarray, tau_target: float = 3472/891) -> Dict:
    """Print a report on the Yukawa Gram spectrum and return key figures.

    The eigenvalues are used in the order given (the report labels assume
    largest first).  The report covers: non-zero count, the largest
    consecutive gaps, gaps at fixed positions, cumulative sums, and a search
    over n in [20, 54] for the visible/hidden ratio
    ``sum(eigs[:n]) / sum(eigs[n:])`` closest to ``tau_target``.

    Args:
        eigs: 1-D array of eigenvalues.
        tau_target: Ratio the tau search tries to match.

    Returns:
        Dict with 'n_visible' (index of the largest gap + 1),
        'nonzero_count', 'largest_gap_idx', 'gap_35_ratio', 'gap_43_ratio'
        (gap / mean gap), and 'tau_best_n', 'tau_estimate', 'tau_error_pct'
        (-1, 0.0, -1.0 respectively when no valid ratio exists).
    """
    n_eigs = len(eigs)
    rule = "=" * 70

    print(rule)
    print("YUKAWA SPECTRAL ANALYSIS")
    print(rule)

    # ---- Eigenvalue overview -------------------------------------------
    nonzero = (np.abs(eigs) > 1e-10).sum()

    print(f"\n[EIGENVALUES]")
    print(f"  Total: {n_eigs}, Non-zero: {nonzero}")
    print(f"  Top 5: {eigs[:5].round(4)}")
    print(f"  Around 35: eigs[32:38] = {eigs[32:38].round(6)}")
    print(f"  Around 43: eigs[40:46] = {eigs[40:46]}")

    # ---- Consecutive gaps ----------------------------------------------
    gaps = np.abs(np.diff(eigs))
    n_gaps = len(gaps)
    mean_gap = gaps.mean() if n_gaps > 0 else 1.0

    def relative(gap):
        # Gap expressed in units of the mean gap (0 if the mean is unusable).
        if mean_gap > 0:
            return gap / mean_gap
        return 0

    print(f"\n[TOP 5 GAPS]")
    for rank, pos in enumerate(np.argsort(gaps)[::-1][:5], start=1):
        print(f"  #{rank}: gap {pos}->{pos+1}: {gaps[pos]:.6f} ({relative(gaps[pos]):.1f}x mean)")

    print(f"\n[KEY POSITIONS]")
    for pos in (20, 21, 34, 35, 42, 43):
        if pos >= n_gaps:
            continue
        print(f"  Gap {pos}->{pos+1}: {gaps[pos]:.6f} ({relative(gaps[pos]):.1f}x mean)")

    # ---- Cumulative weight ---------------------------------------------
    cumsum = np.cumsum(eigs)
    eig_sum = eigs.sum()
    total = eig_sum if eig_sum > 0 else 1.0

    print(f"\n[CUMULATIVE VARIANCE]")
    for count in (21, 35, 43, 77):
        if count > n_eigs:
            continue
        share = 100 * cumsum[count-1] / total
        print(f"  First {count}: {share:.1f}%")

    # ---- Tau = visible / hidden ----------------------------------------
    def tau_at(n_vis):
        # (ratio, percent error) for a split after n_vis eigenvalues, or
        # None when the hidden part is numerically empty.
        visible = cumsum[n_vis-1]
        hidden = total - visible
        if hidden > 1e-8:
            ratio = visible / hidden
            return ratio, 100 * abs(ratio - tau_target) / tau_target
        return None

    print(f"\n[TAU SEARCH] (target: {tau_target:.4f})")
    best_n, best_ratio, best_err = 0, 0, float('inf')

    for n_vis in range(20, min(55, n_eigs)):
        found = tau_at(n_vis)
        if found is not None and found[1] < best_err:
            best_n, best_ratio, best_err = n_vis, found[0], found[1]

    have_tau = best_err < float('inf')
    if have_tau:
        print(f"  Best: n={best_n}, tau={best_ratio:.4f}, error={best_err:.1f}%")
    else:
        print(f"  No valid tau (hidden sum ~ 0)")

    print(f"\n[TAU AT KEY POSITIONS]")
    for n_vis in (35, 42, 43):
        if n_vis < n_eigs:
            found = tau_at(n_vis)
            if found is not None:
                print(f"  n={n_vis}: tau={found[0]:.4f}, error={found[1]:.1f}%")

    # ---- Verdict ---------------------------------------------------------
    largest_gap_idx = np.argmax(gaps) if n_gaps > 0 else 0
    n_visible = largest_gap_idx + 1

    gap_35_ratio = relative(gaps[34] if n_gaps > 34 else 0)
    gap_43_ratio = relative(gaps[42] if n_gaps > 42 else 0)

    print(f"\n[VERDICT]")
    print(f"  Largest gap at: {largest_gap_idx}->{largest_gap_idx+1}")
    print(f"  Suggested n_visible: {n_visible}")
    print(f"  Gap at 35: {gap_35_ratio:.2f}x mean")
    print(f"  Gap at 43: {gap_43_ratio:.2f}x mean")

    report = {
        'n_visible': int(n_visible),
        'nonzero_count': int(nonzero),
        'largest_gap_idx': int(largest_gap_idx),
        'gap_35_ratio': float(gap_35_ratio),
        'gap_43_ratio': float(gap_43_ratio),
    }
    if have_tau:
        report['tau_best_n'] = int(best_n)
        report['tau_estimate'] = float(best_ratio)
        report['tau_error_pct'] = float(best_err)
    else:
        # Sentinels for "no split left a non-empty hidden sector".
        report['tau_best_n'] = -1
        report['tau_estimate'] = 0.0
        report['tau_error_pct'] = -1.0
    return report

# Analyze. Pass the configured target explicitly so the tau search uses the
# same value that is later written to metrics.json as 'tau_target'.
analysis = analyze_spectrum(yukawa['eigenvalues'], tau_target=config.tau_target)

In [None]:
# @title Save Outputs
# @markdown Save all results: models, Yukawa tensor, metrics, samples.

# Writes five artefacts into config.output_dir: yukawa.npz, models.pt,
# metrics.json, eigenvalues.csv and samples.npz. The `metrics` dict built
# here is also read by the final-summary cell.
os.makedirs(config.output_dir, exist_ok=True)
print(f"Saving outputs to: {config.output_dir}")

# 1. Yukawa tensor and spectrum
np.savez(f"{config.output_dir}/yukawa.npz",
         Y=yukawa['Y'],
         M=yukawa['M'],
         eigenvalues=yukawa['eigenvalues'],
         eigenvectors=yukawa['eigenvectors'])
print("  yukawa.npz")

# 2. Models
torch.save({
    'h2': h2_model.state_dict(),
    'h3': h3_model.state_dict(),
    'config': asdict(config),
}, f"{config.output_dir}/models.pt")
print("  models.pt")

# 3. Metrics JSON
# Verdict windows: +/-2 around the 43 and 35 splits.
is_43 = 41 <= analysis['n_visible'] <= 45
is_35 = 33 <= analysis['n_visible'] <= 37
# tau_error_pct is -1 when no valid tau was found; an exact match (0% error)
# must count as success, hence the inclusive lower bound.
tau_ok = 0 <= analysis['tau_error_pct'] < 15

det_g_mean = torch.det(data['metric']).mean().item()

metrics = {
    'version': '2.0',
    'tcs_structure': True,
    'geometry': {
        'det_g_mean': float(det_g_mean),
        'det_g_target': float(config.target_det_g),
        'det_g_error_pct': float(100 * abs(det_g_mean - config.target_det_g) / config.target_det_g),
    },
    'training': {
        'h2_epochs': int(config.n_epochs_h2),
        'h3_epochs': int(config.n_epochs_h3),
        # None when the corresponding loss history is empty.
        'h2_final_loss': float(h2_losses['total'][-1]) if h2_losses['total'] else None,
        'h3_final_loss': float(h3_losses['total'][-1]) if h3_losses['total'] else None,
    },
    'tcs_modes': {
        'n_local': int(config.b3_local),
        'n_left': int(config.n_left),
        'n_right': int(config.n_right),
        'n_neck': int(config.n_neck),
    },
    'yukawa': {
        'n_visible': int(analysis['n_visible']),
        'nonzero_count': int(analysis['nonzero_count']),
        'largest_gap_idx': int(analysis['largest_gap_idx']),
        'gap_35_ratio': float(analysis['gap_35_ratio']),
        'gap_43_ratio': float(analysis['gap_43_ratio']),
        'tau_best_n': int(analysis['tau_best_n']),
        'tau_estimate': float(analysis['tau_estimate']),
        'tau_target': float(config.tau_target),
        # The -1 sentinel is stored as null in the JSON.
        'tau_error_pct': float(analysis['tau_error_pct']) if analysis['tau_error_pct'] >= 0 else None,
    },
    'verdict': {
        '43_77_structure': bool(is_43),
        '35_42_structure': bool(is_35),
        'tau_emerged': bool(tau_ok),
    },
}

with open(f"{config.output_dir}/metrics.json", 'w') as f:
    json.dump(metrics, f, indent=2)
print("  metrics.json")

# 4. Eigenvalues CSV
with open(f"{config.output_dir}/eigenvalues.csv", 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['index', 'eigenvalue', 'cumulative', 'gap'])
    cumsum = np.cumsum(yukawa['eigenvalues'])
    gaps = np.abs(np.diff(yukawa['eigenvalues']))
    for i, ev in enumerate(yukawa['eigenvalues']):
        # The last eigenvalue has no successor, so its gap is written as 0.
        g = float(gaps[i]) if i < len(gaps) else 0.0
        writer.writerow([i, float(ev), float(cumsum[i]), g])
print("  eigenvalues.csv")

# 5. Samples with forms
with torch.no_grad():
    sample_coords = data['coords'][:1000].to(device)
    sample_omega = h2_model(sample_coords).cpu().numpy()
    sample_Phi = h3_model(sample_coords).cpu().numpy()

# detach().cpu() is a no-op for plain CPU tensors but keeps the export working
# when the dataset tensors live on the GPU or carry autograd history.
np.savez(f"{config.output_dir}/samples.npz",
         coords=data['coords'][:1000].detach().cpu().numpy(),
         metric=data['metric'][:1000].detach().cpu().numpy(),
         phi=data['phi'][:1000].detach().cpu().numpy(),
         omega=sample_omega,
         Phi=sample_Phi)
print("  samples.npz")

print("\nAll outputs saved.")

In [None]:
# @title Final Summary
# @markdown Complete report of v2.0 TCS Hodge training results.

# Prints a human-readable report from the `metrics` dict assembled by the
# save-outputs cell, then a one-line success/failure verdict.
print("=" * 70)
print("K7 TCS HODGE v2.0 - FINAL REPORT")
print("=" * 70)

print(f"\n[GEOMETRY]")
print(f"  det(g): {metrics['geometry']['det_g_mean']:.6f} (target: {metrics['geometry']['det_g_target']}, error: {metrics['geometry']['det_g_error_pct']:.2f}%)")

print(f"\n[TCS MODE STRUCTURE]")
print(f"  Local (fiber): {metrics['tcs_modes']['n_local']} modes")
print(f"  Left-weighted: {metrics['tcs_modes']['n_left']} modes")
print(f"  Right-weighted: {metrics['tcs_modes']['n_right']} modes")
print(f"  Neck-coupled: {metrics['tcs_modes']['n_neck']} modes")
print(f"  Total: {config.b3_K7} modes")

print(f"\n[TRAINING]")
# The final losses are stored as None when a loss history is empty; formatting
# None with ':.2e' would raise TypeError, so print a placeholder instead.
h2_final = metrics['training']['h2_final_loss']
h3_final = metrics['training']['h3_final_loss']
h2_final_txt = f"{h2_final:.2e}" if h2_final is not None else "n/a"
h3_final_txt = f"{h3_final:.2e}" if h3_final is not None else "n/a"
print(f"  H2: {metrics['training']['h2_epochs']} epochs, final loss: {h2_final_txt}")
print(f"  H3: {metrics['training']['h3_epochs']} epochs, final loss: {h3_final_txt}")

print(f"\n[YUKAWA SPECTRUM]")
print(f"  Non-zero eigenvalues: {metrics['yukawa']['nonzero_count']}")
print(f"  Largest gap at: {metrics['yukawa']['largest_gap_idx']}->{metrics['yukawa']['largest_gap_idx']+1}")
print(f"  Suggested n_visible: {metrics['yukawa']['n_visible']}")
print(f"  Gap at 35: {metrics['yukawa']['gap_35_ratio']:.2f}x mean")
print(f"  Gap at 43: {metrics['yukawa']['gap_43_ratio']:.2f}x mean")

print(f"\n[TAU PARAMETER]")
print(f"  Target: {metrics['yukawa']['tau_target']:.4f} (3472/891)")
# tau_error_pct is None when no valid visible/hidden split was found.
if metrics['yukawa']['tau_error_pct'] is not None and metrics['yukawa']['tau_error_pct'] >= 0:
    print(f"  Best match at n={metrics['yukawa']['tau_best_n']}: tau={metrics['yukawa']['tau_estimate']:.4f}")
    print(f"  Error: {metrics['yukawa']['tau_error_pct']:.1f}%")
else:
    print(f"  Could not compute tau (hidden sum ~ 0)")

print(f"\n[VERDICT]")
v = metrics['verdict']
print(f"  43/77 visible/hidden structure: {'YES' if v['43_77_structure'] else 'NO'}")
print(f"  35/42 local/global structure: {'YES' if v['35_42_structure'] else 'NO'}")
print(f"  Tau = 3472/891 emerged: {'YES' if v['tau_emerged'] else 'NO'}")

print(f"\n[OUTPUT FILES]")
print(f"  {config.output_dir}/models.pt")
print(f"  {config.output_dir}/yukawa.npz")
print(f"  {config.output_dir}/metrics.json")
print(f"  {config.output_dir}/eigenvalues.csv")
print(f"  {config.output_dir}/samples.npz")

print("=" * 70)

# Success criteria: any one of the three verdict flags is enough.
success = v['43_77_structure'] or v['35_42_structure'] or v['tau_emerged']
if success:
    print("\nTCS structure has influence on Yukawa spectrum.")
else:
    print("\nTCS structure not yet visible. May need more training or parameter tuning.")