# G2 Metric Training - v0.6 with TCS Neck-Like Ansatz
## K7 construction with GIFT-parametrized TCS geometry
### Testing hypothesis: GIFT constants encode TCS moduli structure

**Major Shift from v0.5:**

This version transitions from the fully periodic T⁷ geometry to a TCS-inspired neck structure with boundaries:

[−T,T] × (S¹)² × T⁴

where:
- [−T,T]: NON-periodic neck direction (finite interval)
- (S¹)²: Fiber circles (periodic)
- T⁴: K3-like base (periodic, with complex structure hints)

**GIFT Parameters as TCS Moduli:**

- τ = 3.897 → T/R (neck stretching: length vs radius ratio)
- ξ = 0.982 → Gluing rotation angle (how the two Fano building blocks twist relative to each other)
- γ = 0.578 → Asymptotic decay rate: torsion ∼ exp(−γ|t|)

**Memory Constraints:**

- Target hardware: A100 80 GB; previous runs still crashed
- Optimizations: smaller batches with gradient accumulation and a simplified Hodge star (gradient checkpointing is planned but not yet implemented in this notebook)
- Skip the expensive Riemann-curvature computation (it caused a crash); focus on torsion and boundary matching

**Expected Outcomes:**

- Manifold: [−T,T] × (S¹)² × T⁴ structure validated
- Boundaries: t = ±T with low torsion
- Decay: torsion ∼ exp(−γ|t|) verified
- GIFT moduli: τ, ξ, γ physically interpreted
- Metrics: Bulk torsion ~10⁻⁷ to 10⁻⁸, boundary torsion < 10⁻⁵
- b₂=21: det(Gram) > 0.8
- b₃=77: spectral estimate of 70–75 on a 12⁷ grid

**Scientific Goal:**

Show that GIFT parameters have natural TCS interpretation, advancing from T⁷ toward genuine neck geometry.


# Section 1: Setup, Imports, and GIFT Geometric Parameters


In [None]:
import os
import json
import time
from pathlib import Path
from IPython.display import clear_output

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# Pick the compute device (GPU when available) and report what was found.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")

# All checkpoints and analysis artifacts are written below this directory.
OUTPUT_DIR = Path('v06_outputs')
OUTPUT_DIR.mkdir(exist_ok=True)
print(f"\nOutput directory: {OUTPUT_DIR}/")
print("=" * 70)


In [None]:
# GIFT FRAMEWORK PARAMETERS AS TCS MODULI
print("="*70)
print("GIFT TCS NECK PARAMETERS")
print("="*70)

# Fundamental GIFT constants from E8×E8 heterotic string theory
GIFT_PARAMS = {
    # Neck modulus: T/R (length-to-radius ratio)
    'tau': 10416 / 2673,           # τ = (dim(E8×E8) × b₂) / (dim(J₃(O)) × H*)
    
    # Gluing rotation angle (how Fano₁ and Fano₂ twist relative to each other)
    'xi': 5 * np.pi / 16,           # ξ = (Weyl_factor/2) × π/rank(E8)
    
    # Phase parameter (from rank structure)
    'beta0': np.pi / 8,             # β₀ = π/rank(E8) = π/8
    
    # Secondary twist angle (from Weyl factor)
    'delta': 2 * np.pi / 25,        # δ = 2π/Weyl_factor²
    
    # Asymptotic decay rate: torsion ∼ exp(−γ|t|)
    'gamma_GIFT': 511 / 884,        # γ = M₉ / (10×dim(G₂) + 3×dim(E8))
    
    # Golden ratio (from E8 McKay correspondence, for K3-like structure)
    'phi': (1 + np.sqrt(5)) / 2,   # φ = (1+√5)/2 = 1.618034
    
    # Topological invariants
    'b2': 21,                       # Second Betti number (gauge sector)
    'b3': 77,                       # Third Betti number (matter sector)
    'dim_K7': 7,
    'dim_G2': 14,
    'dim_E8': 248,
    'rank_E8': 8,
}

print("\nTCS NECK MODULI (reinterpreted GIFT parameters):")
print(f"  τ (neck modulus T/R):    {GIFT_PARAMS['tau']:.6f}  [Neck stretching]")
print(f"  ξ (gluing rotation):     {GIFT_PARAMS['xi']:.6f} rad = {GIFT_PARAMS['xi']*180/np.pi:.1f}°  [Fano twist]")
print(f"  γ (decay rate):          {GIFT_PARAMS['gamma_GIFT']:.6f}  [Asymptotic behavior]")
print(f"  φ (golden ratio):        {GIFT_PARAMS['phi']:.6f}  [K3 hierarchy]")

print("\nTOPOLOGICAL INVARIANTS:")
print(f"  b₂(K₇) = {GIFT_PARAMS['b2']}  [Harmonic 2-forms]")
print(f"  b₃(K₇) = {GIFT_PARAMS['b3']}  [Harmonic 3-forms]")

# Verify formulas
print("\n" + "="*70)
print("VERIFICATION OF TCS MODULI FORMULAS")
print("="*70)

# Recompute each modulus from the invariants stored in GIFT_PARAMS (rather
# than from retyped literals) so the check genuinely ties the constants together.
DIM_E8_E8 = 2 * GIFT_PARAMS['dim_E8']                  # dim(E8×E8) = 496
DIM_J3O = 27                                           # dim(J₃(O)), exceptional Jordan algebra
H_STAR = 1 + GIFT_PARAMS['b2'] + GIFT_PARAMS['b3']     # H* = 1 + b₂ + b₃ = 99
WEYL_FACTOR = 5                                        # appears in ξ and δ above
MERSENNE_9 = 2**9 - 1                                  # M₉ = 511

tau_computed = (DIM_E8_E8 * GIFT_PARAMS['b2']) / (DIM_J3O * H_STAR)
tau_ok = abs(tau_computed - GIFT_PARAMS['tau']) < 1e-6
print("\nτ formula:")
print("  τ = (dim(E8×E8) × b₂) / (dim(J₃(O)) × H*)")
print(f"  τ = ({DIM_E8_E8} × {GIFT_PARAMS['b2']}) / ({DIM_J3O} × {H_STAR}) = {tau_computed:.6f}")
print(f"  Expected: {GIFT_PARAMS['tau']:.6f}")
print(f"  Match: {tau_ok}")

xi_computed = (WEYL_FACTOR / 2) * (np.pi / GIFT_PARAMS['rank_E8'])
xi_ok = abs(xi_computed - GIFT_PARAMS['xi']) < 1e-6
print("\nξ formula:")
print(f"  ξ = ({WEYL_FACTOR}/2) × π/{GIFT_PARAMS['rank_E8']} = {xi_computed:.6f} rad")
print(f"  Expected: {GIFT_PARAMS['xi']:.6f}")
print(f"  Match: {xi_ok}")

gamma_denominator_g2 = 10 * GIFT_PARAMS['dim_G2']
gamma_denominator_e8 = 3 * GIFT_PARAMS['dim_E8']
gamma_computed = MERSENNE_9 / (gamma_denominator_g2 + gamma_denominator_e8)
gamma_ok = abs(gamma_computed - GIFT_PARAMS['gamma_GIFT']) < 1e-6
print("\nγ formula:")
print("  γ = M₉ / (10×dim(G₂) + 3×dim(E8))")
print(f"  γ = {MERSENNE_9} / ({gamma_denominator_g2} + {gamma_denominator_e8}) = {gamma_computed:.6f}")
print(f"  Expected: {GIFT_PARAMS['gamma_GIFT']:.6f}")
print(f"  Match: {gamma_ok}")

# Only claim success when every individual check actually passed.
all_verified = bool(tau_ok and xi_ok and gamma_ok)
print("\n" + "="*70)
print("ALL FORMULAS VERIFIED" if all_verified else "FORMULA VERIFICATION FAILED")
print("="*70)


# Section 2: TCS Neck Manifold Geometry


In [None]:
class TCSNeckManifold:
    """
    TCS-inspired neck: [−T,T] × (S¹)² × T⁴

    Coordinates (7D):
    - coord[0]:   t ∈ [−T, +T]       (neck direction, NON-periodic)
    - coord[1:3]: θ₁, θ₂ ∈ [0, 2π)   (fiber circles, period 2π)
    - coord[3:7]: x₁..x₄ ∈ [0, L_i)  (K3-like T⁴, period L_i = K3_radii[i])

    GIFT parameters determine geometry:
    - T_neck = τ × R_fiber with R_fiber = 2π (neck length from the modulus τ)
    - Gluing twist via ξ (rotation of the fiber coordinates along t)
    - Boundary profile via γ (see ``boundary_decay_factor``)

    Note: ``fiber_radii`` and ``K3_radii`` are used as coordinate *periods*
    (``sample_points`` draws x_K3 uniformly in [0, K3_radii)), not as radii.
    """

    def __init__(self, gift_params, device='cpu', max_freq=8):
        """
        Args:
            gift_params: dict providing 'tau', 'xi', 'gamma_GIFT' and 'phi'.
            device: torch device for every tensor created by this object.
            max_freq: highest Fourier order per coordinate (default 8).
        """
        self.device = device
        self.dim = 7

        # Extract GIFT moduli
        tau = gift_params['tau']
        self.xi_gluing = gift_params['xi']
        self.gamma_decay = gift_params['gamma_GIFT']
        phi_golden = gift_params['phi']

        # Neck length (from τ modulus)
        R_fiber = 2 * np.pi  # Reference scale
        self.T_neck = tau * R_fiber  # ≈ 3.9 × 2π ≈ 24.5

        # Fiber circles (S¹ × S¹): period of θ₁, θ₂
        self.fiber_radii = torch.tensor([R_fiber, R_fiber], device=device, dtype=torch.float32)

        # K3-like base (T⁴ with golden-ratio hierarchy): period of x₁..x₄
        self.K3_radii = torch.tensor([
            2 * np.pi,
            2 * np.pi,
            2 * np.pi / phi_golden,
            2 * np.pi / phi_golden
        ], device=device, dtype=torch.float32)

        print(f"\nTCS Neck Geometry:")
        print(f"  t-direction: [−{self.T_neck:.2f}, +{self.T_neck:.2f}] (NON-periodic)")
        print(f"  Fiber (S¹×S¹): radii = [{self.fiber_radii[0]:.2f}, {self.fiber_radii[1]:.2f}]")
        print(f"  K3-like (T⁴): radii = {self.K3_radii.cpu().numpy()}")
        print(f"  Total: [−T,T] × (S¹)² × T⁴ (7D)")
        print(f"  Gluing angle ξ: {self.xi_gluing:.3f} rad")
        print(f"  Decay rate γ: {self.gamma_decay:.3f}")

        # Precompute Fourier modes for encoding
        self._setup_fourier_modes(max_freq)

    def _setup_fourier_modes(self, max_freq=8):
        """Build axis-aligned Fourier frequencies, ``max_freq`` per coordinate.

        Each row of ``self.frequencies`` excites exactly one coordinate, so the
        encoding resolves all 7 directions independently (previously every row
        was a multiple of one fixed vector, so only a single linear combination
        of the coordinates reached the network):

        - t (non-periodic): k = n·π / (2T). The lowest mode has period 4T, longer
          than [−T, T], so t = −T and t = +T receive distinct encodings.
        - periodic coordinate with period L: k = 2π·n / L, so the encoding is
          exactly periodic in that coordinate.

        Only n = 1..max_freq is used; negative n would merely duplicate the cos
        features and negate the sin features.
        """
        if max_freq < 1:
            raise ValueError(f"max_freq must be >= 1, got {max_freq}")

        # Periods of coordinates 1..6 (fibers, then K3 directions).
        periods = torch.cat([self.fiber_radii, self.K3_radii])
        orders = torch.arange(1, max_freq + 1, device=self.device, dtype=torch.float32)

        blocks = []
        for axis in range(self.dim):
            block = torch.zeros(max_freq, self.dim, device=self.device, dtype=torch.float32)
            if axis == 0:
                block[:, 0] = orders * (np.pi / (2 * self.T_neck))
            else:
                block[:, axis] = orders * (2 * np.pi) / periods[axis - 1]
            blocks.append(block)

        self.frequencies = torch.cat(blocks, dim=0)  # (dim * max_freq, 7)
        self.n_modes = self.frequencies.shape[0]
        print(f"  Fourier modes: {self.n_modes}")

    def sample_points(self, n_batch):
        """
        Sample uniformly on neck × fibers × K3.

        CRITICAL: t is NON-periodic (uniform in [−T,T]).

        Returns:
            (n_batch, 7) tensor of coordinates.
        """
        # t ∈ [−T, T] - NECK DIRECTION (non-periodic!)
        t = (torch.rand(n_batch, 1, device=self.device) * 2 - 1) * self.T_neck

        # θ₁, θ₂ ∈ [0, 2π) - FIBER CIRCLES (periodic)
        theta = torch.rand(n_batch, 2, device=self.device) * 2 * np.pi

        # x_K3 ∈ [0, K3_radii) - K3-LIKE BASE (periodic)
        x_K3 = torch.rand(n_batch, 4, device=self.device) * self.K3_radii.unsqueeze(0)

        return torch.cat([t, theta, x_K3], dim=1)

    def is_near_boundary(self, coords, threshold=0.1):
        """Boolean mask of points with T − |t| < threshold·T (near t = ±T)."""
        t = coords[:, 0]
        dist_to_boundary = self.T_neck - torch.abs(t)
        return dist_to_boundary < threshold * self.T_neck

    def boundary_decay_factor(self, coords):
        """
        Boundary profile exp(−γ · d / T) with d = T − |t| the distance to the
        nearer boundary.

        Equals 1 at t = ±T and exp(−γ) at the neck centre t = 0.

        Returns:
            (batch, 1) tensor.
        """
        t = coords[:, 0]
        dist_to_boundary = self.T_neck - torch.abs(t)

        # Exponential profile controlled by γ, with distance normalized by T
        decay = torch.exp(-self.gamma_decay * dist_to_boundary / self.T_neck)

        return decay.unsqueeze(-1)  # (batch, 1)

    def apply_gluing_rotation(self, coords):
        """
        Rotate the fiber angles (θ₁, θ₂) by ξ·tanh(3t/T).

        This models how Fano₁ and Fano₂ are rotated relative to each other:
        the rotation angle runs smoothly from ≈ −ξ at t = −T through 0 at
        t = 0 to ≈ +ξ at t = +T (tanh(3) ≈ 0.995). Rotated angles are wrapped
        into [0, 2π).

        NOTE(review): rotating the angle pair as a vector in R² does not respect
        the 2π identifications of the torus (θ and θ + 2π·e₁ rotate to points that
        differ by a non-lattice vector) unless the angle is a multiple of π/2;
        confirm this is acceptable as a model of the gluing.

        Returns:
            Copy of ``coords`` with columns 1:3 replaced by the rotated angles.
        """
        t = coords[:, 0]

        # Smooth transition: ≈−1 at t=−T, 0 at t=0, ≈+1 at t=+T
        transition = torch.tanh(t / (self.T_neck / 3))
        rotation_angle = self.xi_gluing * transition  # (batch,)

        cos_rot = torch.cos(rotation_angle)
        sin_rot = torch.sin(rotation_angle)
        theta1 = coords[:, 1]
        theta2 = coords[:, 2]

        # Apply SO(2) rotation to (θ₁, θ₂)
        theta_rotated = torch.stack([
            theta1 * cos_rot - theta2 * sin_rot,
            theta1 * sin_rot + theta2 * cos_rot
        ], dim=1)

        # Wrap into [0, 2π); torch.fmod would keep negative angles negative.
        theta_rotated = torch.remainder(theta_rotated, 2 * np.pi)

        coords_rotated = coords.clone()
        coords_rotated[:, 1:3] = theta_rotated

        return coords_rotated

    def fourier_encoding(self, x):
        """
        Fourier encoding adapted to TCS geometry.

        Args:
            x: (batch, 7) coordinates.

        Returns:
            (batch, 2 * n_modes) tensor [cos(x·kᵀ), sin(x·kᵀ)].
        """
        phases = torch.matmul(x, self.frequencies.T)
        return torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)

# Initialize neck manifold
manifold = TCSNeckManifold(GIFT_PARAMS, device=device)
print("\nTCS Neck Manifold initialized.")


# Section 3: Neural Networks (Memory Optimized for TCS)


In [None]:
class G2PhiNetwork_TCS(nn.Module):
    """φ network for the TCS neck with boundary awareness.

    Pipeline: gluing rotation → Fourier encoding → MLP (Linear/SiLU/LayerNorm
    stacks) → 35 components of the 3-form φ, renormalized per sample and then
    softly modulated by the manifold's boundary profile.
    """

    def __init__(self, manifold, hidden_dims=(256, 256, 128)):
        """
        Args:
            manifold: object providing ``device``, ``fourier_encoding``,
                ``apply_gluing_rotation`` and ``boundary_decay_factor``.
            hidden_dims: widths of the hidden layers (REDUCED for memory).
                A tuple default avoids the shared-mutable-default pitfall;
                lists are still accepted.
        """
        super().__init__()
        self.manifold = manifold

        # Input width is whatever the manifold's encoding produces.
        probe = torch.zeros(1, 7, device=manifold.device)
        encoding_dim = manifold.fourier_encoding(probe).shape[-1]

        layers = []
        prev_dim = encoding_dim
        for h_dim in hidden_dims:
            layers += [nn.Linear(prev_dim, h_dim), nn.SiLU(), nn.LayerNorm(h_dim)]
            prev_dim = h_dim

        self.mlp = nn.Sequential(*layers)
        self.output = nn.Linear(prev_dim, 35)  # 35 components for 3-form φ

        # Initialize output layer small
        with torch.no_grad():
            self.output.weight.mul_(0.01)
            self.output.bias.zero_()

    def forward(self, coords):
        """
        Args:
            coords: (batch, 7) points on the neck.

        Returns:
            (batch, 35) φ components with Euclidean norm
            sqrt(7)·(1 − 0.5·decay), where decay comes from
            ``manifold.boundary_decay_factor``. The norm therefore depends only
            on the coordinates, not on the network weights.
        """
        coords_rotated = self.manifold.apply_gluing_rotation(coords)
        x = self.manifold.fourier_encoding(coords_rotated)
        x = self.mlp(x)
        phi = self.output(x)

        # Rescale every sample to norm sqrt(7)
        phi_norm = torch.norm(phi, dim=-1, keepdim=True)
        phi = phi * (np.sqrt(7.0) / (phi_norm + 1e-8))

        # Soft boundary modulation (not a hard BC). With TCSNeckManifold the
        # decay factor is 1 at t = ±T, so |φ| drops to sqrt(7)/2 there, not to 0.
        decay = self.manifold.boundary_decay_factor(coords)
        phi = phi * (1 - decay * 0.5)

        return phi


class Harmonic2FormsNetwork_TCS(nn.Module):
    """Candidate harmonic 2-forms (REDUCED hidden dims for memory).

    ``n_forms`` independent small MLPs (default 21, the target b₂) each map the
    Fourier encoding of the gluing-rotated coordinates to the 21 components
    of a 2-form in 7D (C(7,2) = 21).
    """

    def __init__(self, manifold, hidden_dims=(96, 96), n_forms=21):
        """
        Args:
            manifold: object providing ``device``, ``fourier_encoding`` and
                ``apply_gluing_rotation``.
            hidden_dims: widths of the two hidden layers (tuple default avoids a
                shared mutable default; lists are still accepted).
            n_forms: number of 2-forms to represent.
        """
        super().__init__()
        self.n_forms = n_forms
        self.manifold = manifold

        probe = torch.zeros(1, 7, device=manifold.device)
        encoding_dim = manifold.fourier_encoding(probe).shape[-1]

        # One narrow MLP per form, for memory
        self.networks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(encoding_dim, hidden_dims[0]),
                nn.SiLU(),
                nn.Linear(hidden_dims[0], hidden_dims[1]),
                nn.SiLU(),
                nn.Linear(hidden_dims[1], 21)  # 21 components for 2-form
            )
            for _ in range(n_forms)
        ])

        # Small output init so all forms start near zero
        with torch.no_grad():
            for net in self.networks:
                net[-1].weight.mul_(0.01)
                net[-1].bias.zero_()

    def forward(self, coords):
        """Return (batch, n_forms, 21) 2-form components at ``coords``."""
        coords_rotated = self.manifold.apply_gluing_rotation(coords)
        features = self.manifold.fourier_encoding(coords_rotated)

        forms = [net(features) for net in self.networks]
        return torch.stack(forms, dim=1)  # (batch, n_forms, 21)

    def compute_gram_matrix(self, coords, forms, metric):
        """
        Volume-weighted Gram matrix (simplified: Euclidean component inner product).

        G[α, β] = mean_b( ⟨forms[b, α], forms[b, β]⟩ · sqrt(|det g_b|) )

        Computed with one einsum instead of an O(n_forms²) Python loop of
        indexed writes (which also built n_forms² autograd nodes).

        Args:
            coords: (batch, 7) sample points (unused apart from the interface).
            forms: (batch, n_forms, n_components) tensor.
            metric: (batch, 7, 7) metric tensors.

        Returns:
            (n_forms, n_forms) exactly symmetric tensor.
        """
        vol = torch.sqrt(torch.abs(torch.det(metric)) + 1e-10)  # (batch,)
        gram = torch.einsum('bik,bjk,b->ij', forms, forms, vol) / forms.shape[0]
        # Symmetrize to remove any floating-point asymmetry from the contraction.
        return 0.5 * (gram + gram.T)


# Build both networks on the target device (sizes chosen for memory).
phi_network = G2PhiNetwork_TCS(manifold, hidden_dims=[256, 256, 128]).to(device)
harmonic_network = Harmonic2FormsNetwork_TCS(manifold, hidden_dims=[96, 96]).to(device)

# Trainable-parameter counts, one per network.
phi_params, harmonic_params = (
    sum(p.numel() for p in net.parameters())
    for net in (phi_network, harmonic_network)
)

print("\nNetworks (memory optimized):")
print(f"  φ-network: {phi_params:,} params")
print(f"  Harmonic:  {harmonic_params:,} params")
print(f"  Total:     {phi_params + harmonic_params:,} params")


# Section 4: Geometry Operations (from v0.5, stable implementations)


In [None]:
def metric_from_phi_simplified(phi):
    """
    Simplified metric reconstruction (from v0.5 - works and memory efficient).

    Builds g = I + 0.1·S, where the symmetric matrix S takes its 28
    upper-triangular entries (row-major, diagonal included) from the first
    28 components of φ. Components 28 onward are unused. g is then projected
    to SPD by clamping eigenvalues at 0.1, and rescaled to det(g) ≈ 1.

    The triangle is filled with one vectorized index write instead of a
    Python double loop of 28 indexed assignments. The output dtype follows
    ``phi`` instead of being forced to float32.

    Args:
        phi: (batch, >=28) tensor of 3-form components (35 in this notebook).

    Returns:
        (batch, 7, 7) symmetric positive-definite metric with unit
        determinant (up to the 1e-8 regularizer).

    Raises:
        ValueError: if ``phi`` has fewer than 28 components.
    """
    batch_size = phi.shape[0]
    rows, cols = torch.triu_indices(7, 7, device=phi.device)  # 28 pairs, row-major
    n_upper = rows.numel()
    if phi.shape[-1] < n_upper:
        raise ValueError(f"phi needs at least {n_upper} components, got {phi.shape[-1]}")

    # Upper triangle (incl. diagonal) carries the 0.1·φ perturbation.
    upper = torch.zeros(batch_size, 7, 7, device=phi.device, dtype=phi.dtype)
    upper[:, rows, cols] = 0.1 * phi[:, :n_upper]

    # Mirror into a symmetric matrix; subtract the diagonal once (it was added twice).
    g = upper + upper.transpose(-2, -1) - torch.diag_embed(torch.diagonal(upper, dim1=-2, dim2=-1))
    g = g + torch.eye(7, device=phi.device, dtype=phi.dtype)

    # Project to SPD (eigenvalues >= 0.1)
    eigvals, eigvecs = torch.linalg.eigh(g)
    eigvals = torch.clamp(eigvals, min=0.1)
    g = eigvecs @ torch.diag_embed(eigvals) @ eigvecs.transpose(-2, -1)

    # Normalize volume: det(s·g) = s⁷·det(g), so s = det(g)^(−1/7)
    det_g = torch.det(g)
    scale = (1.0 / (det_g + 1e-8)) ** (1.0 / 7.0)
    return g * scale.view(-1, 1, 1)


def hodge_star_simplified(phi, metric):
    """
    Simplified stand-in for the Hodge dual *φ of a 3-form (from v0.5 - stable).

    Scales φ by the volume factor sqrt(|det g|), then rescales each row to
    Euclidean norm sqrt(7). Because the volume factor is a per-sample scalar
    that is always positive, the final rescaling cancels it. The output is
    effectively sqrt(7)·φ/|φ|, up to the 1e-10 and 1e-8 epsilons.

    NOTE(review): this is not a true metric Hodge star (there is no index
    contraction with g). Confirm that is acceptable wherever it is used.

    Args:
        phi: 3-form components, presumably (batch, 35) — TODO confirm.
        metric: metric tensors, presumably (batch, 7, 7); only det(metric) is used.

    Returns:
        Tensor shaped like ``phi`` whose rows have norm ≈ sqrt(7).
    """
    sqrt_det = torch.sqrt(torch.abs(torch.det(metric)) + 1e-10)
    weighted = sqrt_det.unsqueeze(-1) * phi
    weighted_norm = weighted.norm(dim=-1, keepdim=True)
    return weighted / (weighted_norm + 1e-8) * np.sqrt(7.0)


def compute_torsion_simplified(phi, coords, metric, network=None, n_components=10):
    """
    Torsion proxy: mean gradient norm ||∂φ_i/∂x|| of the φ-network.

    φ is recomputed from a fresh, grad-enabled copy of ``coords``. The first
    ``n_components`` components are differentiated with ``torch.autograd.grad``,
    and the per-point gradient norms are averaged over components and points.
    This stands in for the true torsion (dφ, d*φ).

    Safe to call under ``torch.no_grad()``: the gradients w.r.t. ``coords``
    are computed inside ``torch.enable_grad()``. Without that,
    ``autograd.grad`` raises a RuntimeError. In that case no graph is kept,
    so the result is not differentiable w.r.t. the network parameters.

    Args:
        phi: unused; kept for call-site compatibility.
        coords: (batch, 7) sample points.
        metric: unused; kept for call-site compatibility.
        network: φ-network to differentiate; defaults to the module-level
            ``phi_network``.
        n_components: number of leading φ components to sample (default 10).

    Returns:
        0-d tensor. It carries a graph for backprop only when called with grad
        mode enabled.
    """
    if network is None:
        network = phi_network

    # The caller's grad mode decides whether a backprop-able graph is needed.
    keep_graph = torch.is_grad_enabled()

    with torch.enable_grad():
        coords_grad = coords.detach().clone().requires_grad_(True)
        phi_grad = network(coords_grad)

        grad_norms = []
        for i in range(min(n_components, phi_grad.shape[1])):
            (grad_i,) = torch.autograd.grad(
                phi_grad[:, i].sum(),
                coords_grad,
                create_graph=keep_graph,
                retain_graph=True,
            )
            grad_norms.append(grad_i.norm(dim=1))

        torsion = torch.stack(grad_norms, dim=1).mean(dim=1).mean()

    return torsion


print("Geometry operations loaded (simplified, memory-efficient versions).")


# Section 5: TCS-Specific Loss Functions (New for v0.6)


In [None]:
def compute_boundary_loss(phi, coords, manifold, network=None):
    """
    Penalize torsion and φ amplitude near the t = ±T boundaries.

    Goal: φ should become "Fano-like" at t=±T. This is simplified to enforcing
    low torsion at the boundaries.

    Points are "near the boundary" when ``manifold.is_near_boundary(coords,
    threshold=0.15)`` is True. For those points the loss is:

        mean_i,b ||∂φ_i/∂x||  (first 5 components)  +  0.5 · mean_b |φ_b|

    Gradients w.r.t. coords are taken inside ``torch.enable_grad()``, so the
    function also works under ``torch.no_grad()``. A graph for backprop is
    kept only when the caller has grad mode enabled.

    NOTE(review): with G2PhiNetwork_TCS as defined in this notebook, |φ| is
    fixed by renormalization and depends only on t. If so, the amplitude term
    gives no gradient to the weights — confirm this is intended.

    Args:
        phi: (batch, n_components) φ values already computed at ``coords``.
        coords: (batch, 7) sample points.
        manifold: object providing ``is_near_boundary``.
        network: φ-network to differentiate; defaults to the module-level
            ``phi_network``.

    Returns:
        0-d tensor; 0.0 when no point lies near a boundary.
    """
    near_boundary = manifold.is_near_boundary(coords, threshold=0.15)

    if not near_boundary.any():
        return torch.tensor(0.0, device=coords.device)

    if network is None:
        network = phi_network

    phi_boundary = phi[near_boundary]
    coords_boundary = coords[near_boundary]

    keep_graph = torch.is_grad_enabled()
    with torch.enable_grad():
        coords_boundary_grad = coords_boundary.detach().clone().requires_grad_(True)
        phi_boundary_grad = network(coords_boundary_grad)

        # Sample a few components for efficiency
        grad_norms = []
        for i in range(min(5, phi_boundary_grad.shape[1])):
            (grad_i,) = torch.autograd.grad(
                phi_boundary_grad[:, i].sum(),
                coords_boundary_grad,
                create_graph=keep_graph,
                retain_graph=True,
            )
            grad_norms.append(grad_i.norm(dim=1))

        grad_norm = torch.stack(grad_norms, dim=1).mean()

    # Also penalize large φ values at the boundary
    phi_amplitude_boundary = torch.norm(phi_boundary, dim=1).mean()

    return grad_norm + phi_amplitude_boundary * 0.5


def compute_asymptotic_decay_loss(phi, coords, manifold):
    """
    Penalize deviation of |φ| from the profile exp(−γ·|t|/T).

    The target profile is 1 at the neck centre (t = 0) and exp(−γ) at
    t = ±T. The penalty is the mean absolute error (an L1 loss, not MSE)
    between the per-point φ norm and that profile.

    NOTE(review): the docstring intent is "torsion decays exponentially", but
    the code compares the φ *amplitude*. With G2PhiNetwork_TCS as defined in
    this notebook, that amplitude is fixed by renormalization and depends only
    on t. If so, this loss carries no gradient to the weights — confirm.

    Args:
        phi: φ values, presumably (batch, 35) — TODO confirm.
        coords: (batch, 7) points; only coords[:, 0] (t) is used.
        manifold: object providing ``gamma_decay`` and ``T_neck``.

    Returns:
        0-d tensor.
    """
    t_abs = coords[:, 0].abs()
    target_profile = torch.exp(-manifold.gamma_decay * t_abs / manifold.T_neck)
    amplitude = phi.norm(dim=1)
    return (amplitude - target_profile).abs().mean()


print("TCS-specific loss functions loaded:")
print("  - Boundary loss (enforce low torsion at t=±T)")
print("  - Asymptotic decay loss (enforce exp(-γ|t|) profile)")


# Section 6: Training Configuration


In [None]:
def _curriculum_phase(start, end, torsion, volume, harmonic_ortho, harmonic_det, boundary, decay):
    """One curriculum phase: the epoch range [start, end) and its loss weights."""
    return {
        'range': [start, end],
        'weights': {
            'torsion': torsion,
            'volume': volume,
            'harmonic_ortho': harmonic_ortho,
            'harmonic_det': harmonic_det,
            'boundary': boundary,
            'decay': decay,
        },
    }


CONFIG = {
    # Architecture
    'phi_hidden_dims': [256, 256, 128],
    'harmonic_hidden_dims': [96, 96],
    'n_harmonic_forms': 21,

    # Geometry
    'geometry': 'TCS_neck',
    'gift_params': GIFT_PARAMS,

    # Training (batch reduced from 2048 for memory)
    'epochs': 10000,
    'batch_size': 1536,
    'grad_accumulation_steps': 2,
    'lr': 1e-4,
    'weight_decay': 1e-4,
    'grad_clip': 1.0,

    # Mode
    'exterior_derivative_mode': 'simplified',
    'metric_reconstruction_mode': 'simplified',

    # Curriculum: 4 phases with increasing torsion / boundary emphasis
    'curriculum': {
        'phase1': _curriculum_phase(0, 2000, torsion=0.1, volume=1.0, harmonic_ortho=1.0,
                                    harmonic_det=0.5, boundary=0.5, decay=0.1),
        'phase2': _curriculum_phase(2000, 5000, torsion=2.0, volume=0.5, harmonic_ortho=0.5,
                                    harmonic_det=0.2, boundary=1.0, decay=0.3),
        'phase3': _curriculum_phase(5000, 8000, torsion=5.0, volume=0.1, harmonic_ortho=0.3,
                                    harmonic_det=0.1, boundary=2.0, decay=0.5),
        'phase4': _curriculum_phase(8000, 10000, torsion=3.0, volume=0.1, harmonic_ortho=0.5,
                                    harmonic_det=0.2, boundary=1.5, decay=0.5),
    },

    # Validation
    'validation_interval': 1000,
    'checkpoint_interval': 1000,
    'seed': 47,
}

print("Training Configuration:")
print(f"  Epochs: {CONFIG['epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['lr']}")
print("  4-phase curriculum with TCS boundary losses")
print(f"  Seed: {CONFIG['seed']}")


# Section 7: Training Loop (10,000 epochs with TCS curriculum)

This section implements the full training loop with gradient accumulation, 4-phase curriculum, and TCS-specific losses.


In [None]:
# Reproducibility: seed both torch and numpy from the config.
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

# One AdamW optimizer over the φ-network followed by the harmonic network.
optimizer = optim.AdamW(
    [*phi_network.parameters(), *harmonic_network.parameters()],
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
)

# Cosine annealing of the learning rate over the full run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'], eta_min=1e-6)

# Per-epoch training metrics plus validation metrics.
history = {
    key: []
    for key in (
        'epoch', 'loss', 'torsion', 'volume', 'det_gram', 'boundary',
        'decay', 'lr', 'test_torsion', 'test_det_gram',
    )
}

def get_phase_weights(epoch):
    """Return the curriculum loss weights for ``epoch``.

    Phases are scanned in CONFIG['curriculum'] order. The first phase whose
    half-open range [start, end) contains ``epoch`` wins. Epochs outside every
    range fall back to phase4's weights.
    """
    curriculum = CONFIG['curriculum']
    match = next(
        (
            phase['weights']
            for phase in curriculum.values()
            if phase['range'][0] <= epoch < phase['range'][1]
        ),
        None,
    )
    if match is not None:
        return match
    return curriculum['phase4']['weights']

print("="*70)
print("STARTING TRAINING")
print("="*70)
print(f"Total epochs: {CONFIG['epochs']}")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Gradient accumulation: {CONFIG['grad_accumulation_steps']} steps")
print(f"Effective batch: {CONFIG['batch_size'] * CONFIG['grad_accumulation_steps']}")
print("="*70)

start_time = time.time()

# Fixed validation set drawn under its own seed. The training seed is then
# restored, so training batches do not depend on the test-set draw.
torch.manual_seed(99999)
test_coords = manifold.sample_points(1000)
torch.manual_seed(CONFIG['seed'])

n_accum = CONFIG['grad_accumulation_steps']
trainable_params = list(phi_network.parameters()) + list(harmonic_network.parameters())

try:
    for epoch in range(CONFIG['epochs']):
        phi_network.train()
        harmonic_network.train()

        # Curriculum weights for this epoch
        weights = get_phase_weights(epoch)

        # Gradient accumulation
        optimizer.zero_grad()

        accumulated_loss = 0
        accumulated_torsion = 0
        accumulated_volume = 0
        accumulated_det_gram = 0
        accumulated_boundary = 0
        accumulated_decay = 0

        for accum_step in range(n_accum):
            coords = manifold.sample_points(CONFIG['batch_size'])

            # Forward pass
            phi = phi_network(coords)
            h_forms = harmonic_network(coords)
            metric = metric_from_phi_simplified(phi)

            # Geometric losses
            torsion_loss = compute_torsion_simplified(phi, coords, metric)
            volume_loss = (torch.det(metric) - 1.0).abs().mean()

            # Harmonic losses
            gram = harmonic_network.compute_gram_matrix(coords, h_forms, metric)
            det_gram = torch.det(gram).abs()
            harmonic_loss_det = -torch.log(det_gram + 1e-8)

            # Scale-free orthogonality: Frobenius-normalized Gram vs identity,
            # with the identity sized from gram instead of a hard-coded 21.
            gram_norm = gram / (torch.norm(gram) + 1e-8)
            identity = torch.eye(gram.shape[0], device=gram.device)
            harmonic_loss_ortho = torch.norm(gram_norm - identity)

            # TCS-specific losses
            boundary_loss = compute_boundary_loss(phi, coords, manifold)
            decay_loss = compute_asymptotic_decay_loss(phi, coords, manifold)

            # Total loss
            loss = (weights['torsion'] * torsion_loss +
                    weights['volume'] * volume_loss +
                    weights['harmonic_ortho'] * harmonic_loss_ortho +
                    weights['harmonic_det'] * harmonic_loss_det +
                    weights['boundary'] * boundary_loss +
                    weights['decay'] * decay_loss)

            # Scale so the accumulated gradient is the mean over micro-batches
            loss = loss / n_accum
            loss.backward()

            accumulated_loss += loss.item()
            accumulated_torsion += torsion_loss.item() / n_accum
            accumulated_volume += volume_loss.item() / n_accum
            accumulated_det_gram += det_gram.item() / n_accum
            accumulated_boundary += boundary_loss.item() / n_accum
            accumulated_decay += decay_loss.item() / n_accum

        torch.nn.utils.clip_grad_norm_(trainable_params, CONFIG['grad_clip'])

        optimizer.step()
        scheduler.step()

        # Record history
        history['epoch'].append(epoch)
        history['loss'].append(accumulated_loss)
        history['torsion'].append(accumulated_torsion)
        history['volume'].append(accumulated_volume)
        history['det_gram'].append(accumulated_det_gram)
        history['boundary'].append(accumulated_boundary)
        history['decay'].append(accumulated_decay)
        history['lr'].append(scheduler.get_last_lr()[0])

        # Validation on the fixed test set
        if epoch % CONFIG['validation_interval'] == 0 or epoch == CONFIG['epochs'] - 1:
            phi_network.eval()
            harmonic_network.eval()

            with torch.no_grad():
                phi_test = phi_network(test_coords)
                h_forms_test = harmonic_network(test_coords)
                metric_test = metric_from_phi_simplified(phi_test)
                gram_test = harmonic_network.compute_gram_matrix(test_coords, h_forms_test, metric_test)
                det_gram_test = torch.det(gram_test).abs()

            # Outside no_grad on purpose: the torsion proxy calls
            # torch.autograd.grad w.r.t. coords, which raises under no_grad.
            torsion_test = compute_torsion_simplified(phi_test, test_coords, metric_test)

            history['test_torsion'].append(torsion_test.item())
            history['test_det_gram'].append(det_gram_test.item())
            # Epochs at which the test_* entries were recorded
            history.setdefault('test_epoch', []).append(epoch)

            phi_network.train()
            harmonic_network.train()

        # Logging
        if epoch % 100 == 0:
            elapsed = time.time() - start_time
            eta = elapsed / (epoch + 1) * (CONFIG['epochs'] - epoch - 1)

            clear_output(wait=True)
            print(f"Epoch {epoch}/{CONFIG['epochs']}")
            print(f"  Loss: {accumulated_loss:.4e}")
            print(f"  Torsion: {accumulated_torsion:.4e}")
            print(f"  det(Gram): {accumulated_det_gram:.4f}")
            print(f"  Boundary: {accumulated_boundary:.4e}")
            print(f"  Decay: {accumulated_decay:.4e}")
            print(f"  LR: {scheduler.get_last_lr()[0]:.2e}")
            print(f"  Elapsed: {elapsed/3600:.2f}h, ETA: {eta/3600:.2f}h")

        # Periodic checkpoints
        if epoch % CONFIG['checkpoint_interval'] == 0 and epoch > 0:
            checkpoint = {
                'epoch': epoch,
                'phi_network': phi_network.state_dict(),
                'harmonic_network': harmonic_network.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': CONFIG,
            }
            torch.save(checkpoint, OUTPUT_DIR / f'checkpoint_epoch_{epoch}.pt')

except KeyboardInterrupt:
    print("\nTraining interrupted by user.")

# Always persist the final state. The periodic condition (epoch % interval == 0)
# need not fire on the last epoch (it does not for epochs=10000, interval=1000),
# so without this the fully trained model would be lost.
if history['epoch']:
    final_checkpoint = {
        'epoch': history['epoch'][-1],
        'phi_network': phi_network.state_dict(),
        'harmonic_network': harmonic_network.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'config': CONFIG,
        'history': history,
    }
    torch.save(final_checkpoint, OUTPUT_DIR / 'checkpoint_final.pt')
    print(f"Final checkpoint saved to {OUTPUT_DIR / 'checkpoint_final.pt'}")

total_time = time.time() - start_time
print("\n" + "="*70)
print("TRAINING COMPLETE")
print("="*70)
print(f"Total time: {total_time/3600:.2f} hours")
if history['epoch']:
    print(f"Final torsion: {history['torsion'][-1]:.2e}")
    print(f"Final det(Gram): {history['det_gram'][-1]:.4f}")
    print(f"Final boundary loss: {history['boundary'][-1]:.2e}")
    print(f"Final decay loss: {history['decay'][-1]:.2e}")
else:
    # Interrupted before the first epoch finished: nothing to summarize.
    print("No completed epochs; no final metrics to report.")


# Section 8: TCS-Specific Validation


In [None]:
def validate_tcs_neck_structure():
    """Validate TCS neck properties: exponential decay, boundary torsion and
    the gluing-rotation profile.

    Uses the module-level ``manifold``, ``phi_network``, ``device``,
    ``metric_from_phi_simplified`` and ``compute_torsion_simplified``.

    Returns:
        JSON-serializable dict containing:
        - the fitted and expected decay rate;
        - the torsion at both boundaries;
        - the torsion-vs-t profile;
        - the measured and expected gluing rotation angles.
    """

    def _torsion_at(coords):
        """Torsion proxy at ``coords`` as a Python float."""
        with torch.no_grad():
            phi_c = phi_network(coords)
            metric_c = metric_from_phi_simplified(phi_c)
        # Must run with grad mode enabled: the torsion proxy differentiates φ
        # w.r.t. coords via torch.autograd.grad, which raises under no_grad.
        return compute_torsion_simplified(phi_c, coords, metric_c).item()

    def _wrap_angle(angle):
        """Map angles into [−π, π)."""
        return torch.remainder(angle + np.pi, 2 * np.pi) - np.pi

    print("="*70)
    print("TCS NECK VALIDATION")
    print("="*70)

    phi_network.eval()

    # 1. Exponential decay analysis
    print("\n1. Exponential Decay Analysis")
    print("-" * 70)

    t_samples = torch.linspace(-manifold.T_neck, manifold.T_neck, 100, device=device)
    torsions_vs_t = []

    for t_val in t_samples:
        coords_t = manifold.sample_points(500)
        coords_t[:, 0] = t_val
        torsions_vs_t.append(_torsion_at(coords_t))

    torsions_np = np.array(torsions_vs_t)
    t_abs_np = torch.abs(t_samples).cpu().numpy()

    # Fit log(torsion) = a + b|t|; torsion ∼ exp(−γ|t|/T) implies b ≈ −γ/T
    valid_idx = torsions_np > 1e-10
    if valid_idx.sum() > 10:
        log_torsion = np.log(torsions_np[valid_idx] + 1e-10)
        t_for_fit = t_abs_np[valid_idx]

        slope, intercept = np.polyfit(t_for_fit, log_torsion, 1)
        fitted_gamma = float(-slope * manifold.T_neck)

        print(f"  Fitted decay rate γ: {fitted_gamma:.4f}")
        print(f"  Expected γ: {manifold.gamma_decay:.4f}")
        print(f"  Relative error: {abs(fitted_gamma - manifold.gamma_decay) / manifold.gamma_decay * 100:.2f}%")
        print(f"  Validation: {'PASS' if abs(fitted_gamma - manifold.gamma_decay) < 0.2 else 'PARTIAL'}")
    else:
        print("  Insufficient data for exponential fit")
        fitted_gamma = None

    # 2. Boundary torsion check
    print("\n2. Boundary Torsion Check")
    print("-" * 70)

    # Single threshold for both the printed target and the PASS criterion
    # (previously the target said 1e-5 but PASS used 1e-4).
    boundary_target = 1e-5

    coords_left = manifold.sample_points(1000)
    coords_left[:, 0] = -manifold.T_neck * 0.95

    coords_right = manifold.sample_points(1000)
    coords_right[:, 0] = manifold.T_neck * 0.95

    torsion_left = _torsion_at(coords_left)
    torsion_right = _torsion_at(coords_right)
    boundary_ok = torsion_left < boundary_target and torsion_right < boundary_target

    print(f"  Left boundary (t ≈ −T):  {torsion_left:.4e}")
    print(f"  Right boundary (t ≈ +T): {torsion_right:.4e}")
    print("  Target: < 1e-5")
    print(f"  Symmetry: {abs(torsion_left - torsion_right) / (torsion_left + 1e-10) * 100:.2f}% difference")
    print(f"  Validation: {'PASS' if boundary_ok else 'PARTIAL'}")

    # 3. Gluing rotation check
    print("\n3. Gluing Rotation Analysis")
    print("-" * 70)

    t_check = torch.tensor([-manifold.T_neck * 0.9, 0.0, manifold.T_neck * 0.9], device=device)
    coords_check = manifold.sample_points(3)
    coords_check[:, 0] = t_check
    # Fixed small fiber angles (|θ| ≈ 1.12 < π). The rotated pair then lies in
    # (−π, π) once any 2π wrap inside apply_gluing_rotation is undone, so the
    # applied angle can be recovered exactly.
    coords_check[:, 1] = 1.0
    coords_check[:, 2] = 0.5

    coords_rotated = manifold.apply_gluing_rotation(coords_check)

    theta_original = coords_check[:, 1:3]
    theta_rotated = _wrap_angle(coords_rotated[:, 1:3])

    # Applied rotation = change in the polar angle of the (θ₁, θ₂) vector.
    rotation_angles = _wrap_angle(
        torch.atan2(theta_rotated[:, 1], theta_rotated[:, 0])
        - torch.atan2(theta_original[:, 1], theta_original[:, 0])
    )
    # Reference profile ξ·tanh(t / (T/3)), as used by apply_gluing_rotation.
    expected_angles = manifold.xi_gluing * torch.tanh(t_check / (manifold.T_neck / 3))
    rotation_ok = torch.allclose(rotation_angles, expected_angles, atol=1e-3)

    print(f"  Rotation at t=-0.9T: {rotation_angles[0].item():.4f} rad (expected {expected_angles[0].item():.4f})")
    print(f"  Rotation at t=0:     {rotation_angles[1].item():.4f} rad (should be ~0)")
    print(f"  Rotation at t=+0.9T: {rotation_angles[2].item():.4f} rad (expected {expected_angles[2].item():.4f})")
    print(f"  Expected max: ±{manifold.xi_gluing:.4f} rad")
    print(f"  Validation: {'PASS' if rotation_ok else 'PARTIAL'}")

    print("\n" + "="*70)

    validation_results = {
        'fitted_gamma': fitted_gamma if fitted_gamma is not None else 0.0,
        'expected_gamma': manifold.gamma_decay,
        'boundary_torsion_left': torsion_left,
        'boundary_torsion_right': torsion_right,
        'torsion_profile': {'t': t_samples.cpu().numpy().tolist(), 'torsion': torsions_vs_t},
        'rotation_angles': rotation_angles.detach().cpu().tolist(),
        'expected_rotation_angles': expected_angles.detach().cpu().tolist(),
    }

    return validation_results

# Run the TCS neck validation, then persist its summary as JSON.
validation_results = validate_tcs_neck_structure()

results_path = OUTPUT_DIR / 'tcs_decay_analysis.json'
with results_path.open('w') as results_file:
    json.dump(validation_results, results_file, indent=2)

print("\nValidation results saved.")
