In [None]:
# =============================================================================
# OPTIONAL: YUKAWA TENSOR EXTRACTION (EXPENSIVE - RUN SEPARATELY)
# =============================================================================
#
# Prerequisites for the commented-out code below: the training cell must have
# produced `results` (it reads the 'phi_net', 'h2_extractor' and 'h3_extractor'
# entries), and CONFIG, DEVICE and OUTPUT_DIR must be defined.

# Uncomment to run Yukawa extraction (takes ~10-30 minutes):

# yukawa_extractor = YukawaCouplingExtractor(n_samples=5000)
# Y = yukawa_extractor.compute_yukawa_tensor(
#     results['phi_net'], results['h2_extractor'], results['h3_extractor'],
#     CONFIG.geometry, DEVICE
# )
# yukawa_analysis = yukawa_extractor.analyze_yukawa(Y)
# print(f"Yukawa tensor shape: {Y.shape}")
# NOTE: the "/ 77" below is hard-coded; 77 is the size of the last tensor axis
# (b3) for the default geometry, i.e. the maximum possible rank of the
# flattened (21*21, 77) matrix.
# print(f"Effective rank: {yukawa_analysis['effective_rank']} / 77")
# print(f"Top singular values: {yukawa_analysis['top_singular_values'][:5]}")
#
# # Save Yukawa tensor
# torch.save(Y, OUTPUT_DIR / "yukawa_tensor.pt")
# with open(OUTPUT_DIR / "yukawa_analysis.json", 'w') as f:
#     json.dump(yukawa_analysis, f, indent=2)

# As shipped, this cell only prints a reminder; nothing is computed or saved.
print("Yukawa extraction cell ready (uncomment to run)")

In [None]:
# =============================================================================
# RUN POST-HOC ANALYSIS
# =============================================================================
#
# Requires `results` from the training cell (`run_full_training`) and the
# globals DEVICE, OUTPUT_DIR and CONFIG.

# Fits the metric ansatz, estimates the breathing period (when enabled),
# renders the plots and writes analysis_summary.json.
analysis = run_post_hoc_analysis(results, DEVICE)

# The list below is a static reminder of the files the training and analysis
# steps write into OUTPUT_DIR; it does not check that the files exist.
print("\n" + "=" * 70)
print("K7 BREATHING v1.8 - COMPLETE")
print("=" * 70)
print(f"\nOutputs saved to: {OUTPUT_DIR}")
print("  - final_model.pt: Trained model checkpoint")
print("  - training_history.json: Full training metrics")
print("  - config.json: Experiment configuration")
print("  - metric_fit.json: Analytic metric approximation")
print("  - analysis_summary.json: Final results summary")
print("  - training_history.png: Training curves")
# The flux profile plot is only produced when breathing diagnostics are enabled.
if CONFIG.breathing.enable:
    print("  - flux_profile.png: Breathing flux visualization")

In [None]:
# =============================================================================
# RUN FULL TRAINING
# =============================================================================
#
# Requires CONFIG, DEVICE and OUTPUT_DIR to be defined by earlier cells.

# Save configuration (written before training starts)
CONFIG.save(str(OUTPUT_DIR / "config.json"))

# Run training. `results` is a notebook-level variable that the post-hoc
# analysis cell and the optional Yukawa cell read.
results = run_full_training(CONFIG, DEVICE)

print("\nTraining complete!")

In [None]:
# =============================================================================
# QUICK TEST CONFIGURATION (comment out for full training)
# =============================================================================
#
# All overrides below are commented out as shipped, so running this cell
# unchanged leaves CONFIG at its existing values (full training).

# Uncomment these lines for quick testing:
# CONFIG.training.n_epochs_core = 100
# CONFIG.training.n_epochs_breathing = 50
# CONFIG.logging.log_every = 10

# For pure geometry mode (no breathing diagnostics):
# CONFIG.breathing.enable = False

print("Configuration ready. Run next cell to start training.")

---
## Section 7: Run Training (Execute Cells Below)

**Quick Run Configuration:**
- For testing: Set `CONFIG.training.n_epochs_core = 100`
- For full training: Use default settings (5000 + 2000 epochs)

**Two Modes:**
1. **Pure Geometry**: Set `CONFIG.breathing.enable = False`
2. **Geometry + Breathing**: Keep `CONFIG.breathing.enable = True` (default)
---

In [None]:
# =============================================================================
# SECTION 6.3: COMPLETE POST-HOC ANALYSIS
# =============================================================================

def run_post_hoc_analysis(results: Dict, device: torch.device) -> Dict:
    """
    Run the post-hoc analysis on the output of ``run_full_training``.

    Steps:
    1. Metric fitting (``fit_metric_ansatz``)
    2. Breathing period estimation (only when ``config.breathing.enable``)
    3. Visualization (training curves; flux profile when breathing is enabled)
    4. Write ``analysis_summary.json`` into ``OUTPUT_DIR``

    Yukawa extraction is NOT performed here; it is expensive and has to be
    run explicitly through ``YukawaCouplingExtractor``.

    Args:
        results: Dictionary with the keys 'phi_net', 'flux_net',
            'h3_extractor', 'mode_proj', 'history', 'config' and
            'final_core_metrics'; 'final_breathing_metrics' is read when
            breathing is enabled.
        device: Device on which evaluation tensors are created.

    Returns:
        Dictionary with 'metric_fit' and, when breathing is enabled,
        'breathing_period' ({'estimated', 'target', 'info'}).
    """
    print("\n" + "=" * 70)
    print("POST-HOC ANALYSIS")
    print("=" * 70)

    phi_net = results['phi_net']
    flux_net = results['flux_net']
    h3_extractor = results['h3_extractor']
    mode_proj = results['mode_proj']
    history = results['history']
    config = results['config']

    analysis = {}

    # 1. Metric fitting
    print("\n--- Metric Ansatz Fitting ---")
    metric_fit = fit_metric_ansatz(phi_net, config.geometry, device)
    analysis['metric_fit'] = metric_fit

    # 2. Breathing period estimation
    if config.breathing.enable:
        print("\n--- Breathing Period Estimation ---")
        phi_net.eval()

        with torch.no_grad():
            n_lambda_pts = 100
            lam_vals = torch.linspace(0, 1, n_lambda_pts, device=device)

            # For each lambda, sample transverse coordinates and average
            observables = []
            for lam_val in lam_vals:
                coords = sample_TCS_coords(500, device)
                coords[:, 0] = lam_val  # Fix lambda, keep the other columns random

                phi = phi_net(coords)
                modes = h3_extractor(phi, coords)
                obs = compute_breathing_observable(modes, coords, mode_proj)
                observables.append(obs.mean().item())

            observables = torch.tensor(observables, device=device)

            period_est, period_info = estimate_breathing_period(observables, lam_vals)

            print(f"Estimated period: {period_est:.4f}")
            print(f"Target period (tau): {config.breathing.target_period:.4f}")

            analysis['breathing_period'] = {
                'estimated': period_est,
                'target': config.breathing.target_period,
                'info': period_info
            }

    # 3. Visualization
    print("\n--- Generating Plots ---")

    # Training history
    plot_training_history(history, str(OUTPUT_DIR / "training_history.png"))

    # Flux profile
    if config.breathing.enable:
        plot_flux_profile(flux_net, device, str(OUTPUT_DIR / "flux_profile.png"))

    # 4. Save complete analysis
    core_metrics = results['final_core_metrics']
    analysis_summary = {
        # float(): keep the JSON payload independent of NumPy scalar types
        'metric_fit_r2': float(metric_fit['avg_r_squared']),
        'final_kappa_T': core_metrics['kappa_T'],
        'final_kappa_T_error_pct': core_metrics['kappa_T_error_pct'],
        'final_det_g': core_metrics['det_g_mean'],
        'final_det_g_error_pct': core_metrics['det_g_error_pct'],
    }

    if config.breathing.enable:
        breathing_metrics = results['final_breathing_metrics']
        analysis_summary['final_flux_integral'] = breathing_metrics['flux_integral']
        analysis_summary['final_occupation_vis'] = breathing_metrics['occupation_vis']
        if 'breathing_period' in analysis:
            # The estimate was already formatted with ':.4f' above, so it is
            # number-like; float() makes it serializable even if it is a
            # NumPy or tensor scalar.
            analysis_summary['estimated_period'] = float(
                analysis['breathing_period']['estimated']
            )

    with open(OUTPUT_DIR / "analysis_summary.json", 'w') as f:
        json.dump(analysis_summary, f, indent=2)

    print(f"\nAnalysis complete. Results saved to {OUTPUT_DIR}")
    return analysis


print("Post-hoc analysis function defined")

In [None]:
# =============================================================================
# SECTION 6.2: VISUALIZATION
# =============================================================================

def plot_training_history(history: TrainingHistory, save_path: Optional[str] = None):
    """
    Draw a 2x3 grid of training curves from ``history.history``.

    Panels: total loss (log scale), kappa_T, det(g), kappa_T error (log
    scale), flux integral and visible occupation ratio. The last two panels
    only draw data for epochs whose logged value is not NaN.

    Args:
        history: Object exposing the per-epoch lists in ``history.history``.
        save_path: If given (and non-empty), the figure is written there
            before being shown.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    log = history.history
    epochs = log['epoch']

    def _decorate(ax, ylabel, title):
        # Axis cosmetics shared by every panel.
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    def _tracked_panel(ax, key, label, target, target_label, ylabel, title):
        # Quantity logged every epoch, drawn with its target as a dashed line.
        ax.plot(epochs, log[key], 'b-', label=label)
        ax.axhline(y=target, color='r', linestyle='--', label=target_label)
        ax.legend()
        _decorate(ax, ylabel, title)

    def _optional_panel(ax, key, label, target, target_label, ylabel, title):
        # Quantity that is NaN for epochs where it was not measured.
        series = log[key]
        values = [v for v in series if not np.isnan(v)]
        if values:
            kept_epochs = [e for e, v in zip(epochs, series) if not np.isnan(v)]
            ax.plot(kept_epochs, values, 'b-', label=label)
            ax.axhline(y=target, color='r', linestyle='--', label=target_label)
            ax.legend()
        _decorate(ax, ylabel, title)

    # Loss curve
    loss_ax = axes[0, 0]
    loss_ax.semilogy(epochs, log['total_loss'], 'b-', label='Total Loss')
    loss_ax.legend()
    _decorate(loss_ax, 'Loss', 'Training Loss')

    # kappa_T and det(g) against their targets
    _tracked_panel(axes[0, 1], 'kappa_T', 'kappa_T', 1/61, 'Target (1/61)',
                   'kappa_T', 'Torsion Magnitude')
    _tracked_panel(axes[0, 2], 'det_g', 'det(g)', 65/32, 'Target (65/32)',
                   'det(g)', 'Metric Determinant')

    # kappa_T error
    err_ax = axes[1, 0]
    err_ax.semilogy(epochs, log['kappa_T_error_pct'], 'b-')
    _decorate(err_ax, 'Error (%)', 'kappa_T Error')

    # Breathing diagnostics (if available)
    _optional_panel(axes[1, 1], 'flux_integral', 'Flux Integral', -0.5,
                    'Target (-1/2)', 'Flux Integral', 'Breathing Flux')
    _optional_panel(axes[1, 2], 'occupation_vis', 'Visible Occupation', 43/77,
                    'Target (43/77)', 'Occupation Ratio', 'Mode Distribution')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved training history plot to {save_path}")

    plt.show()


def plot_flux_profile(flux_net: nn.Module, device: torch.device,
                      save_path: Optional[str] = None):
    """
    Plot ``flux_net`` evaluated on 100 evenly spaced lambda values in [0, 1].

    The figure is annotated with ``flux_net.compute_flux_integral()`` next to
    the -0.5 target. ``flux_net`` is switched to eval mode and left there.

    Args:
        flux_net: Module called with a 1-D tensor of lambda values; must also
            provide ``compute_flux_integral()``.
        device: Device on which the lambda grid is created.
        save_path: If given (and non-empty), the figure is written there
            before being shown.
    """
    flux_net.eval()

    with torch.no_grad():
        lam = torch.linspace(0, 1, 100, device=device)
        flux = flux_net(lam).cpu().numpy()

    lam_np = lam.cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(lam_np, flux, 'b-', linewidth=2)
    ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    ax.fill_between(lam_np, flux, 0, alpha=0.3)
    ax.set_xlabel('lambda (neck coordinate)')
    ax.set_ylabel('Flux(lambda)')
    ax.set_title('Breathing Flux Profile Along TCS Neck')
    ax.grid(True, alpha=0.3)

    # Annotate with the integral reported by the network itself
    integral = flux_net.compute_flux_integral().item()
    annotation = f'Integral = {integral:.4f}\nTarget = -0.5'
    box_style = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.95, annotation,
            transform=ax.transAxes, verticalalignment='top',
            bbox=box_style)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()


print("Visualization functions defined")

In [None]:
# =============================================================================
# SECTION 6.1: METRIC FITTING (POST-HOC)
# =============================================================================

def fit_metric_ansatz(phi_net: nn.Module, config: GeometryConfig,
                      device: torch.device, n_samples: int = 5000) -> Dict:
    """
    Fit learned metric g_ij to analytic ansatz in (lambda, xi).

    Ansatz:
        g_ij(x) = A_ij + B_ij*lambda + C_ij*lambda^2
                + D_ij*sin(pi*lambda) + E_ij*cos(pi*lambda)
                + sum_k T_ijk*xi_k + sum_k X_ijk*lambda*xi_k

    Each of the 28 upper-triangular components (i <= j) is fitted
    independently by ordinary least squares on the same 17-column design
    matrix. This is EXPLORATORY - not fed back into training.

    Side effects: switches ``phi_net`` to eval mode and writes
    ``metric_fit.json`` into ``OUTPUT_DIR``.

    Args:
        phi_net: Network mapping coordinates to phi.
        config: Geometry configuration (accepted for interface symmetry; not
            read by this function).
        device: Device used for sampling and the forward pass.
        n_samples: Number of sample points used for the regression.

    Returns:
        Dictionary with 'coefficients' (per component, ordered like
        'feature_names'), 'r_squared' (per component), 'avg_r_squared' and
        'feature_names'. All numbers are plain Python floats.
    """
    print("Fitting metric to analytic ansatz...")

    phi_net.eval()

    with torch.no_grad():
        # Sample points
        coords = sample_TCS_coords(n_samples, device)
        phi = phi_net(coords)
        metric = project_spd(metric_from_phi(phi))  # (n_samples, 7, 7)

        lam = coords[:, 0].cpu().numpy()
        xi = coords[:, 1:].cpu().numpy()
        g = metric.cpu().numpy()

    # Build design matrix for linear regression
    # Features: [1, lam, lam^2, sin(pi*lam), cos(pi*lam), xi_1..6, lam*xi_1..6]
    feature_names = ['const', 'lam', 'lam2', 'sin_pi_lam', 'cos_pi_lam',
                     'xi1', 'xi2', 'xi3', 'xi4', 'xi5', 'xi6',
                     'lam_xi1', 'lam_xi2', 'lam_xi3', 'lam_xi4', 'lam_xi5', 'lam_xi6']
    n_features = len(feature_names)  # = 17

    # Allocated as float64 so the regression runs in double precision even
    # when the network output is float32.
    X = np.zeros((n_samples, n_features))
    X[:, 0] = 1.0                          # Constant
    X[:, 1] = lam                          # lambda
    X[:, 2] = lam ** 2                     # lambda^2
    X[:, 3] = np.sin(np.pi * lam)          # sin(pi*lambda)
    X[:, 4] = np.cos(np.pi * lam)          # cos(pi*lambda)
    X[:, 5:11] = xi                        # xi_1..6
    X[:, 11:17] = lam[:, np.newaxis] * xi  # lambda*xi_1..6

    # Fit each metric component (upper triangle only; the metric is symmetric)
    coefficients = {}
    r_squared = {}

    for i in range(7):
        for j in range(i, 7):
            y = g[:, i, j]

            # Least squares fit
            coeffs, residuals, rank, s = np.linalg.lstsq(X, y, rcond=None)

            # Compute R^2 (the epsilon guards against a constant component)
            y_pred = X @ coeffs
            ss_res = np.sum((y - y_pred) ** 2)
            ss_tot = np.sum((y - y.mean()) ** 2)
            r2 = 1 - ss_res / (ss_tot + 1e-10)

            coefficients[f'g_{i}{j}'] = coeffs.tolist()
            r_squared[f'g_{i}{j}'] = float(r2)

    # Summary
    avg_r2 = float(np.mean(list(r_squared.values())))
    # Report the component that actually has the highest R^2 instead of
    # always labelling g_00 as the best fit.
    best_key = max(r_squared, key=r_squared.get)

    print(f"Metric fit complete:")
    print(f"  Average R^2: {avg_r2:.4f}")
    print(f"  Best fit: {best_key} with R^2 = {r_squared[best_key]:.4f}")

    results = {
        'coefficients': coefficients,
        'r_squared': r_squared,
        'avg_r_squared': avg_r2,
        'feature_names': feature_names
    }

    # Save results
    with open(OUTPUT_DIR / "metric_fit.json", 'w') as f:
        json.dump(results, f, indent=2)

    return results


print("Metric fitting function defined")

---
## Section 6: Post-Hoc Analysis

After training, we perform:
1. **Metric fitting**: Analytic approximation of g_ij(lambda, xi)
2. **Yukawa extraction** (optional and expensive; run separately in its own cell): Compute full Y_{ijk} tensor
3. **Breathing diagnostics**: Period estimation, flux profile
4. **Visualization**: Training curves, metric structure, mode distributions
---

In [None]:
# =============================================================================
# SECTION 5.4: MAIN TRAINING ORCHESTRATOR
# =============================================================================

def run_full_training(config: ExperimentConfig, device: torch.device) -> Dict:
    """
    Run complete multi-phase training.

    Builds all networks, runs Phase 1 (geometry only) and Phase 2 (geometry +
    breathing; a no-op inside ``train_phase2_breathing`` when breathing is
    disabled), evaluates the final model on a 10000-point batch, and writes
    ``training_history.json`` and ``final_model.pt`` into ``OUTPUT_DIR``.

    Args:
        config: Experiment configuration (geometry, breathing, training and
            logging sections are read).
        device: Device on which the models are created and trained.

    Returns:
        results: Dictionary with the keys 'phi_net', 'flux_net',
            'h3_extractor', 'mode_proj', 'h2_extractor', 'history',
            'final_core_metrics', 'final_breathing_metrics' (None when
            breathing is disabled) and 'config'.
    """
    def _detach_info(info):
        # Drop autograd references from any tensor-valued metric so the
        # returned dict does not keep the evaluation graph alive.
        if not isinstance(info, dict):
            return info
        return {k: (v.detach() if torch.is_tensor(v) else v) for k, v in info.items()}

    print("=" * 70)
    print("K7 BREATHING v1.8 - FULL TRAINING")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Geometry: b2={config.geometry.b2_K7}, b3={config.geometry.b3_K7}")
    print(f"Breathing: enabled={config.breathing.enable}")
    print()

    # Initialize models
    phi_net = PhiNetwork().to(device)
    flux_net = NeckTransferNet().to(device)
    h3_extractor = H3ModeExtractor().to(device)
    mode_proj = ModeProjection(
        n_total=config.geometry.b3_K7,
        n_visible=config.breathing.n_visible
    ).to(device)
    # Not used during training; carried in the results for later analysis.
    h2_extractor = H2BasisExtractor().to(device)

    print(f"PhiNetwork parameters: {sum(p.numel() for p in phi_net.parameters()):,}")
    print(f"FluxNet parameters: {sum(p.numel() for p in flux_net.parameters()):,}")
    print()

    # Phase 1: Geometry-only
    phi_net, history = train_phase1_geometry(phi_net, config, device)

    # Phase 2: Geometry + Breathing
    phi_net, flux_net, history = train_phase2_breathing(
        phi_net, flux_net, mode_proj, h3_extractor, config, device, history
    )

    # Save history
    history.save(str(OUTPUT_DIR / "training_history.json"))

    # Final evaluation
    print("\n" + "=" * 70)
    print("FINAL EVALUATION")
    print("=" * 70)

    phi_net.eval()
    flux_net.eval()

    # Evaluate with autograd ENABLED. Both training phases call
    # compute_core_loss with coords that require grad and a phi attached to
    # them; under torch.no_grad() phi would be detached from coords, so any
    # derivative of phi w.r.t. coords taken inside the loss could not be
    # computed. NOTE(review): compute_core_loss is defined elsewhere -
    # presumably it differentiates phi w.r.t. coords; if it does not, this
    # only costs some extra memory and gives the same numbers.
    with torch.enable_grad():
        # Sample large batch for final metrics
        coords = sample_coords(10000, config.geometry, device)
        coords.requires_grad_(True)

        phi = phi_net(coords)
        _, final_core_info = compute_core_loss(phi, coords, config.geometry)

        modes = h3_extractor(phi, coords)
        _, final_breathing_info = compute_breathing_loss(
            flux_net, mode_proj, modes, config.breathing
        )

    final_core_info = _detach_info(final_core_info)
    final_breathing_info = _detach_info(final_breathing_info)

    print(f"kappa_T: {final_core_info['kappa_T']:.6f} (target: {1/61:.6f}, error: {final_core_info['kappa_T_error_pct']:.2f}%)")
    print(f"det(g): {final_core_info['det_g_mean']:.6f} (target: {65/32:.6f}, error: {final_core_info['det_g_error_pct']:.2f}%)")

    if config.breathing.enable:
        print(f"Flux integral: {final_breathing_info['flux_integral']:.4f} (target: -0.5)")
        print(f"Occupation ratio: {final_breathing_info['occupation_vis']:.4f} (target: {43/77:.4f})")

    # Compile results
    results = {
        'phi_net': phi_net,
        'flux_net': flux_net,
        'h3_extractor': h3_extractor,
        'mode_proj': mode_proj,
        'h2_extractor': h2_extractor,
        'history': history,
        'final_core_metrics': final_core_info,
        'final_breathing_metrics': final_breathing_info if config.breathing.enable else None,
        'config': config
    }

    # Save final checkpoint
    torch.save({
        'phi_net_state': phi_net.state_dict(),
        'flux_net_state': flux_net.state_dict(),
        'final_core_metrics': final_core_info,
        'final_breathing_metrics': final_breathing_info if config.breathing.enable else None,
        'config': config.to_dict()
    }, OUTPUT_DIR / "final_model.pt")

    print(f"\nResults saved to {OUTPUT_DIR}")
    return results


print("Main training orchestrator defined")

In [None]:
# =============================================================================
# SECTION 5.3: PHASE 2 - GEOMETRY + BREATHING
# =============================================================================

def train_phase2_breathing(phi_net: nn.Module, flux_net: nn.Module,
                           mode_proj: nn.Module, h3_extractor: nn.Module,
                           config: ExperimentConfig, device: torch.device,
                           history: TrainingHistory) -> Tuple[nn.Module, nn.Module, TrainingHistory]:
    """
    Phase 2: Add breathing diagnostics with small-weight losses.

    Option 1 (freeze_phi=True): Only train flux_net, phi_net frozen
    Option 2 (freeze_phi=False): Fine-tune both with reduced LR

    Returns immediately (nothing trained, nothing logged) when
    ``config.breathing.enable`` is false.

    Args:
        phi_net: Geometry network from Phase 1.
        flux_net: Flux network trained in this phase.
        mode_proj: Mode projection passed to ``compute_breathing_loss``.
        h3_extractor: Extractor producing the H3 modes from (phi, coords).
        config: Experiment configuration.
        device: Device used for sampling.
        history: History object; one entry is appended per epoch, with the
            epoch index offset by ``config.training.n_epochs_core``.

    Returns:
        ``(phi_net, flux_net, history)``. The ``requires_grad`` flags of
        ``phi_net`` are restored to their pre-call values, even if training
        is interrupted.
    """
    print("\n" + "=" * 70)
    print("PHASE 2: Geometry + Breathing Diagnostics")
    print("=" * 70)

    if not config.breathing.enable:
        print("Breathing disabled - skipping Phase 2")
        return phi_net, flux_net, history

    phi_params = list(phi_net.parameters())
    flux_params = list(flux_net.parameters())
    freeze_phi = config.training.freeze_phi_in_breathing
    # Remember the original flags so that unfreezing restores them exactly
    # instead of forcing every parameter to be trainable.
    phi_requires_grad = [p.requires_grad for p in phi_params]

    best_loss = float('inf')

    try:
        # Setup optimizers
        if freeze_phi:
            # Freeze phi_net, only train flux_net
            for param in phi_params:
                param.requires_grad = False
            params = flux_params
            trainable_params = flux_params
            print("PhiNetwork frozen, training FluxNet only")
        else:
            # Fine-tune both with reduced LR for phi_net
            params = [
                {'params': phi_params, 'lr': config.training.lr_breathing * 0.1},
                {'params': flux_params, 'lr': config.training.lr_breathing}
            ]
            trainable_params = phi_params + flux_params
            print("Fine-tuning both PhiNetwork and FluxNet")

        optimizer = torch.optim.AdamW(params, lr=config.training.lr_breathing,
                                       weight_decay=config.training.weight_decay)
        scheduler = create_scheduler(optimizer, config.training)

        for epoch in range(config.training.n_epochs_breathing):
            phi_net.train()
            flux_net.train()

            # Sample coordinates
            coords = sample_coords(config.training.batch_size, config.geometry, device)
            coords.requires_grad_(True)

            # Forward pass
            phi = phi_net(coords)

            # Core loss
            core_loss, core_info = compute_core_loss(phi, coords, config.geometry)

            # Extract H3 modes
            modes = h3_extractor(phi, coords)

            # Breathing loss
            breathing_loss, breathing_info = compute_breathing_loss(
                flux_net, mode_proj, modes, config.breathing
            )

            # Total loss (core dominates)
            total_loss = core_loss + breathing_loss
            total_value = total_loss.item()
            core_value = core_loss.item()

            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            # Clip only the parameters being optimized. When phi_net is
            # frozen its parameters may still hold stale .grad tensors from
            # Phase 1; including them would inflate the norm and shrink the
            # flux_net update for no reason.
            torch.nn.utils.clip_grad_norm_(
                trainable_params,
                config.training.max_grad_norm
            )
            optimizer.step()
            scheduler.step(total_value)

            # Logging. Work on a copy so that core_info keeps the core loss;
            # 'total_loss' carries the combined loss and 'core_loss' the
            # geometry-only part.
            current_lr = optimizer.param_groups[0]['lr']
            log_info = dict(core_info)
            log_info['total_loss'] = total_value
            log_info['core_loss'] = core_value
            history.log(epoch + config.training.n_epochs_core, 'breathing',
                       log_info, breathing_info, current_lr)

            # Track best
            if total_value < best_loss:
                best_loss = total_value

            # Print progress
            if epoch % config.logging.log_every == 0 or epoch == config.training.n_epochs_breathing - 1:
                print(f"Epoch {epoch:5d} | Loss: {total_value:.4f} | "
                      f"Core: {core_value:.4f} | "
                      f"Breathing: {breathing_info['total_breathing_loss']:.6f} | "
                      f"Flux: {breathing_info['flux_integral']:.4f} | "
                      f"Occ: {breathing_info['occupation_vis']:.4f}")
    finally:
        # Unfreeze phi_net if it was frozen (also on interruption, so the
        # notebook kernel is not left with a silently frozen network).
        if freeze_phi:
            for param, flag in zip(phi_params, phi_requires_grad):
                param.requires_grad = flag

    # Save checkpoint
    if config.logging.save_checkpoints:
        checkpoint_path = Path(config.logging.checkpoint_dir) / "checkpoint_phase2.pt"
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            'phi_net_state': phi_net.state_dict(),
            'flux_net_state': flux_net.state_dict(),
            'best_loss': best_loss,
            'config': config.to_dict()
        }, checkpoint_path)
        print(f"Saved Phase 2 checkpoint to {checkpoint_path}")

    print(f"\nPhase 2 Complete: Best loss = {best_loss:.4f}")
    return phi_net, flux_net, history


print("Phase 2 training function defined")

In [None]:
# =============================================================================
# SECTION 5.2: PHASE 1 - GEOMETRY-ONLY TRAINING
# =============================================================================

def train_phase1_geometry(phi_net: nn.Module, config: ExperimentConfig,
                          device: torch.device) -> Tuple[nn.Module, TrainingHistory]:
    """
    Phase 1: Train PhiNetwork for core geometry only.

    Goals:
    - kappa_T -> 1/61
    - det(g) -> 65/32
    - Torsion minimization

    The weights that produced the lowest mini-batch loss are restored at the
    end and, when ``config.logging.save_checkpoints`` is set, written to
    ``checkpoint_phase1.pt`` in ``config.logging.checkpoint_dir``.

    Args:
        phi_net: Network to train (modified in place).
        config: Experiment configuration.
        device: Device used for sampling.

    Returns:
        ``(phi_net, history)`` where ``history`` is a fresh
        ``TrainingHistory`` with one entry per epoch.
    """
    print("=" * 70)
    print("PHASE 1: Geometry-Only Training")
    print("=" * 70)

    history = TrainingHistory()
    optimizer = create_optimizer(phi_net, config.training, lr=config.training.lr_core)
    scheduler = create_scheduler(optimizer, config.training)

    best_loss = float('inf')
    best_state = None

    for epoch in range(config.training.n_epochs_core):
        phi_net.train()

        # Sample coordinates
        coords = sample_coords(config.training.batch_size, config.geometry, device)
        coords.requires_grad_(True)

        # Forward pass
        phi = phi_net(coords)

        # Compute core loss
        loss, info = compute_core_loss(phi, coords, config.geometry)
        loss_value = loss.item()

        # Track best model. The snapshot is taken BEFORE the optimizer step
        # so that it contains the weights that actually produced
        # `loss_value`, not the weights after the update.
        if loss_value < best_loss:
            best_loss = loss_value
            best_state = {k: v.clone() for k, v in phi_net.state_dict().items()}

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(phi_net.parameters(), config.training.max_grad_norm)

        optimizer.step()
        scheduler.step(loss_value)

        # Logging
        current_lr = optimizer.param_groups[0]['lr']
        history.log(epoch, 'geometry', info, None, current_lr)

        # Print progress
        if epoch % config.logging.log_every == 0 or epoch == config.training.n_epochs_core - 1:
            print(f"Epoch {epoch:5d} | Loss: {info['total_loss']:.4f} | "
                  f"kappa_T: {info['kappa_T']:.6f} ({info['kappa_T_error_pct']:.2f}%) | "
                  f"det(g): {info['det_g_mean']:.4f} ({info['det_g_error_pct']:.2f}%) | "
                  f"LR: {current_lr:.2e}")

    # Restore best model
    if best_state is not None:
        phi_net.load_state_dict(best_state)

    # Save checkpoint
    if config.logging.save_checkpoints:
        checkpoint_path = Path(config.logging.checkpoint_dir) / "checkpoint_phase1.pt"
        # Make sure the target directory exists before writing.
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            'phi_net_state': phi_net.state_dict(),
            'best_loss': best_loss,
            'config': config.to_dict()
        }, checkpoint_path)
        print(f"Saved Phase 1 checkpoint to {checkpoint_path}")

    print(f"\nPhase 1 Complete: Best loss = {best_loss:.4f}")
    return phi_net, history


print("Phase 1 training function defined")

In [None]:
# =============================================================================
# SECTION 5.1: TRAINING INFRASTRUCTURE
# =============================================================================

class TrainingHistory:
    """
    Track training metrics across epochs.

    ``history`` maps each metric name to a list with one entry per logged
    epoch; every call to ``log`` appends exactly one value to every list.
    """

    def __init__(self):
        self.history = {
            'epoch': [],
            'phase': [],
            'total_loss': [],
            'core_loss': [],
            'breathing_loss': [],
            'kappa_T': [],
            'kappa_T_error_pct': [],
            'det_g': [],
            'det_g_error_pct': [],
            'flux_integral': [],
            'occupation_vis': [],
            'learning_rate': []
        }

    def log(self, epoch: int, phase: str, core_info: Dict,
            breathing_info: Optional[Dict] = None, lr: float = 0.0):
        """
        Log metrics for one epoch.

        Args:
            epoch: Epoch index to record.
            phase: Phase label (e.g. 'geometry' or 'breathing').
            core_info: Must contain 'total_loss', 'kappa_T',
                'kappa_T_error_pct', 'det_g_mean' and 'det_g_error_pct'.
                An optional 'core_loss' entry is recorded as the core loss;
                without it the core loss falls back to 'total_loss' (the two
                coincide when no other loss term is active).
            breathing_info: Breathing metrics. They are recorded only when
                the dict has a truthy 'breathing_enabled' entry; otherwise
                the breathing loss is logged as 0.0 and the flux integral
                and occupation ratio as NaN.
            lr: Learning rate to record.
        """
        self.history['epoch'].append(epoch)
        self.history['phase'].append(phase)
        self.history['total_loss'].append(core_info['total_loss'])
        # Keep the core loss distinct from the total when the caller
        # provides it, instead of always duplicating the total.
        self.history['core_loss'].append(
            core_info.get('core_loss', core_info['total_loss'])
        )
        self.history['kappa_T'].append(core_info['kappa_T'])
        self.history['kappa_T_error_pct'].append(core_info['kappa_T_error_pct'])
        self.history['det_g'].append(core_info['det_g_mean'])
        self.history['det_g_error_pct'].append(core_info['det_g_error_pct'])
        self.history['learning_rate'].append(lr)

        if breathing_info and breathing_info.get('breathing_enabled', False):
            self.history['breathing_loss'].append(breathing_info['total_breathing_loss'])
            self.history['flux_integral'].append(breathing_info['flux_integral'])
            self.history['occupation_vis'].append(breathing_info['occupation_vis'])
        else:
            self.history['breathing_loss'].append(0.0)
            self.history['flux_integral'].append(float('nan'))
            self.history['occupation_vis'].append(float('nan'))

    def to_dict(self) -> Dict:
        """Return the underlying history dict (the live object, not a copy)."""
        return self.history

    def save(self, path: str):
        """Write the history to ``path`` as indented JSON."""
        with open(path, 'w') as f:
            json.dump(self.history, f, indent=2)


def create_optimizer(model: nn.Module, config: TrainingConfig,
                     lr: Optional[float] = None) -> torch.optim.Optimizer:
    """
    Build an AdamW optimizer over all parameters of ``model``.

    Args:
        model: Module whose parameters are optimized.
        config: Training configuration; ``weight_decay`` is always read,
            ``lr_core`` only when ``lr`` is None.
        lr: Learning rate override.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    learning_rate = config.lr_core if lr is None else lr
    return torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=config.weight_decay,
    )


def create_scheduler(optimizer: torch.optim.Optimizer,
                     config: TrainingConfig) -> torch.optim.lr_scheduler.ReduceLROnPlateau:
    """
    Create a ReduceLROnPlateau learning rate scheduler in 'min' mode.

    Args:
        optimizer: Optimizer whose learning rate is reduced.
        config: Training configuration; ``scheduler_factor`` and
            ``scheduler_patience`` are read.

    Returns:
        The scheduler. Callers must pass the monitored loss to ``step()``.
    """
    # `verbose` is intentionally not passed: it is deprecated in PyTorch 2.2+
    # (any explicit value triggers a warning) and later releases removed the
    # argument. Omitting it keeps the same silent behaviour on every version.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config.scheduler_factor,
        patience=config.scheduler_patience
    )


print("Training infrastructure defined")

---
## Section 5: Multi-Phase Training Loop

Training strategy:
1. **Phase 1 (Geometry-only)**: Train PhiNetwork to minimize L_core
2. **Phase 2 (Geometry + Breathing)**: Add breathing diagnostics with small weights
3. **Phase 3 (Cohomology)**: Extract harmonic forms and compute Yukawa (performed post-hoc; `run_full_training` runs Phases 1 and 2 only)

Key principle: **Core geometry must converge first** before adding speculative terms.
---

In [None]:
# =============================================================================
# SECTION 4.2: YUKAWA COUPLING EXTRACTION
# =============================================================================

class YukawaCouplingExtractor:
    """
    Compute Yukawa couplings from harmonic forms on K7.

    In M-theory compactification, Yukawa couplings arise from:
        Y_{ijk} = integral_{K7} omega_i ^ omega_j ^ Omega_k

    where:
    - omega_i, omega_j are harmonic 2-forms (from H2, b2=21)
    - Omega_k are harmonic 3-forms (from H3, b3=77)

    This gives a tensor of shape (21, 21, 77).

    NOTE: the wedge product implemented here is a simplified proxy
    (component-wise product, linear interpolation 21 -> 35, inner product),
    not a fully antisymmetrized wedge product.
    """

    # Number of 3-form components used for the overlap: C(7, 3) = 35.
    _N_3FORM_COMPONENTS = 35

    def __init__(self, n_samples: int = 10000):
        # Number of Monte Carlo sample points used by compute_yukawa_tensor.
        self.n_samples = n_samples

    def _expand_2form_product(self, omega_i: torch.Tensor,
                              omega_j: torch.Tensor) -> torch.Tensor:
        """Multiply two 2-forms component-wise and resample the result to 35 components."""
        omega_ij = omega_i * omega_j  # Component-wise product

        # Map to the 3-form dimension by linear interpolation along the
        # component axis (interpolate expects a channel dimension).
        return F.interpolate(
            omega_ij.unsqueeze(1), size=self._N_3FORM_COMPONENTS,
            mode='linear', align_corners=True
        ).squeeze(1)

    def compute_wedge_2_2_3(self, omega_i: torch.Tensor, omega_j: torch.Tensor,
                            Omega_k: torch.Tensor) -> torch.Tensor:
        """
        Compute a proxy for the wedge product omega_i ^ omega_j ^ Omega_k.

        For 2-form ^ 2-form ^ 3-form in 7D, the result is a 7-form (top form)
        with a single scalar coefficient. The full computation requires
        antisymmetrization over all indices; here the inner product of the
        resampled component-wise product with Omega_k is used instead.

        Args:
            omega_i: Shape (batch, 21) - 2-form components
            omega_j: Shape (batch, 21) - 2-form components
            Omega_k: Shape (batch, 35) - 3-form components

        Returns:
            scalar: Shape (batch,) - wedge product coefficient (proxy)
        """
        omega_ij_expanded = self._expand_2form_product(omega_i, omega_j)

        # Inner product with Omega_k
        return (omega_ij_expanded * Omega_k).sum(dim=1)

    def compute_yukawa_tensor(self, phi_net: nn.Module, h2_extractor: H2BasisExtractor,
                              h3_extractor: H3ModeExtractor, config: GeometryConfig,
                              device: torch.device) -> torch.Tensor:
        """
        Compute full Yukawa tensor via Monte Carlo integration.

        Y_{ijk} = (1/N) * sum_x [ wedge(omega_i, omega_j, Omega_k)(x) * vol(x) ]

        i.e. the mean over the N = ``n_samples`` sample points of the wedge
        proxy weighted by ``volume_form(metric)``; no division by the total
        volume is performed.

        The k-th 3-form is a fixed base vector times a scalar weight:
        ``phi[:, :35] * (k + 1) / 35`` for k < 35 and the first 35 columns of
        ``h3_modes[:, 35:]`` times ``(k - 35 + 1) / (b3 - 35)`` otherwise.
        Because the estimate is linear in that weight, the overlap is
        computed once per (i, j) pair and scaled for all k, which matches the
        entry-by-entry evaluation up to floating-point rounding.

        Args:
            phi_net: Network mapping coordinates to phi.
            h2_extractor: Called as ``h2_extractor(metric, coords)``.
            h3_extractor: Called as ``h3_extractor(phi, coords)``.
            config: Geometry configuration; ``b2_K7`` and ``b3_K7`` are read.
            device: Device for sampling and for the result.

        Returns:
            Y: Shape (b2, b2, b3) - Yukawa tensor, symmetric in its first two
               indices and detached from the autograd graph.
        """
        n_h2 = config.b2_K7  # 21
        n_h3 = config.b3_K7  # 77
        n_comp = self._N_3FORM_COMPONENTS

        Y = torch.zeros(n_h2, n_h2, n_h3, device=device)

        # Sample points on K7
        coords = sample_TCS_coords(self.n_samples, device)
        # Kept from the original implementation: the extractors are defined
        # elsewhere and may differentiate w.r.t. the coordinates.
        coords.requires_grad_(True)

        # Compute phi and metric
        phi = phi_net(coords)
        metric = project_spd(metric_from_phi(phi))
        vol = volume_form(metric)  # (n_samples,)

        # Extract harmonic forms
        omega_2forms = h2_extractor(metric, coords)  # (n_samples, 21, 21)
        h3_modes = h3_extractor(phi, coords)  # (n_samples, 77)

        # Everything below is pure post-processing: detach so that filling Y
        # does not build (and retain) an autograd graph for every entry.
        omega_2forms = omega_2forms.detach()
        vol = vol.detach()
        local_base = phi[:, :n_comp].detach()  # Local 3-form components

        n_local = min(n_h3, n_comp)
        n_global = n_h3 - n_local

        # Per-mode weights (mode index / number of modes in the group)
        local_scales = torch.arange(
            1, n_local + 1, device=device, dtype=Y.dtype
        ) / n_comp

        global_base = None
        global_scales = None
        if n_global > 0:
            # Global mode contribution: first 35 columns of the global block,
            # zero-padded on the right if the block is narrower than 35.
            global_block = h3_modes[:, n_comp:].detach()
            deficit = n_comp - global_block.shape[1]
            if deficit > 0:
                global_block = F.pad(global_block, (0, deficit))
            global_base = global_block[:, :n_comp]
            global_scales = torch.arange(
                1, n_global + 1, device=device, dtype=Y.dtype
            ) / n_global

        # Compute Yukawa tensor entries
        # This is O(21 * 21 * 77) = O(34k) entries
        print(f"Computing Yukawa tensor ({n_h2}x{n_h2}x{n_h3} = {n_h2*n_h2*n_h3} entries)...")

        with torch.no_grad():
            for i in range(n_h2):
                omega_i = omega_2forms[:, i, :]  # (n_samples, 21)
                for j in range(i, n_h2):  # Symmetry: Y_{ijk} = Y_{jik}
                    omega_j = omega_2forms[:, j, :]  # (n_samples, 21)
                    expanded = self._expand_2form_product(omega_i, omega_j)

                    # Volume-weighted average of the overlap with the base
                    # 3-form, then scaled by the per-mode weights.
                    local_overlap = ((expanded * local_base).sum(dim=1) * vol).mean()
                    Y[i, j, :n_local] = local_overlap * local_scales

                    if n_global > 0:
                        global_overlap = ((expanded * global_base).sum(dim=1) * vol).mean()
                        Y[i, j, n_local:] = global_overlap * global_scales

                    Y[j, i, :] = Y[i, j, :]  # Symmetry

        return Y

    def analyze_yukawa(self, Y: torch.Tensor) -> Dict:
        """
        Analyze Yukawa tensor structure.

        The tensor is flattened to a matrix whose columns are indexed by the
        last axis (for the default geometry: (21*21, 77)) and decomposed by
        SVD. The effective rank counts singular values above 1e-3 times the
        largest one.

        Args:
            Y: Yukawa tensor; the last axis indexes the 3-form modes.

        Returns:
            Dictionary with 'shape', 'effective_rank', 'top_singular_values'
            (at most 10), 'frobenius_norm', 'max_entry' and 'mean_entry'
            (the last two are taken over absolute values).
        """
        Y = Y.detach()

        # Flatten for eigenanalysis (derived from Y so other b2/b3 values work)
        Y_flat = Y.reshape(-1, Y.shape[-1])

        # SVD to find effective rank
        U, S, Vh = torch.linalg.svd(Y_flat, full_matrices=False)

        # Effective rank (number of significant singular values)
        threshold = S.max() * 1e-3
        effective_rank = (S > threshold).sum().item()

        # Top singular values
        top_singular = S[:10].tolist()

        # Frobenius norm
        frob_norm = torch.norm(Y).item()

        return {
            'shape': list(Y.shape),
            'effective_rank': effective_rank,
            'top_singular_values': top_singular,
            'frobenius_norm': frob_norm,
            'max_entry': Y.abs().max().item(),
            'mean_entry': Y.abs().mean().item()
        }


# Note: Full Yukawa computation is expensive - it is not run automatically
print("YukawaCouplingExtractor defined (computation deferred to post-hoc analysis)")

In [None]:
# =============================================================================
# SECTION 4.1: H2 BASIS (HARMONIC 2-FORMS)
# =============================================================================

class H2BasisExtractor(nn.Module):
    """
    Build a set of ``n_modes`` (default b2 = 21) candidate 2-forms on K7.

    Each 2-form is expressed in the 21-component basis dx^i ^ dx^j (i < j)
    and is derived from the off-diagonal entries of the metric.

    NOTE(review): every form produced here is the same off-diagonal metric
    vector multiplied by a scalar sin((m + 1) * pi / 21). After the per-form
    normalization the forms are therefore parallel, and the last default form
    (m = 20, weight sin(pi) ~ 1e-16) is numerically zero. The construction is
    kept as-is so existing results are unchanged; a genuinely independent
    basis would need a different construction.
    """

    def __init__(self, n_modes: int = 21):
        super().__init__()
        self.n_modes = n_modes
        self.n_2form_components = 21  # C(7,2) = 21 for 2-forms in 7D

    def compute_2forms(self, metric: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        """
        Compute the metric-weighted 2-form set.

        Component c of form m is ``metric[:, i, j] * sin((m + 1) * pi / 21)``
        where (i, j) is the c-th pair with i < j in row-major order; each
        form is then rescaled to unit norm.

        Args:
            metric: Shape (batch, 7, 7)
            coords: Shape (batch, 7). Unused; kept for interface symmetry
                    with the other extractors.

        Returns:
            omega: Shape (batch, n_modes, 21), in the default torch dtype
                   (same as the zero-initialized buffer used previously).
        """
        device = metric.device

        # Index pairs (i, j), i < j, in row-major order: (0,1), (0,2), ..., (5,6).
        pair_idx = torch.triu_indices(7, 7, offset=1, device=device)
        off_diagonal = metric[:, pair_idx[0], pair_idx[1]]  # (batch, 21)

        # One scalar weight per form, evaluated exactly as in the original
        # per-element loop so the values are bit-identical.
        mode_weights = torch.tensor(
            [np.sin((m + 1) * np.pi / 21) for m in range(self.n_modes)],
            dtype=metric.dtype,
            device=device,
        )

        # Broadcast (batch, 1, 21) * (1, n_modes, 1) -> (batch, n_modes, 21),
        # replacing n_modes * 21 Python-level indexed assignments.
        omega = off_diagonal.unsqueeze(1) * mode_weights.view(1, -1, 1)
        omega = omega.to(torch.get_default_dtype())

        return self._orthonormalize(omega)

    def _orthonormalize(self, omega: torch.Tensor) -> torch.Tensor:
        """
        Rescale every 2-form to unit Euclidean norm over its components.

        Despite the name, no orthogonalization (Gram-Schmidt / QR) is
        performed: only a per-form normalization with a 1e-8 floor in the
        denominator to avoid division by zero.

        Args:
            omega: Shape (batch, n_forms, n_components)

        Returns:
            Tensor of the same shape with each form normalized.
        """
        norms = torch.norm(omega, dim=2, keepdim=True)
        return omega / (norms + 1e-8)

    def forward(self, metric: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        """Return the (batch, n_modes, 21) 2-form set; see ``compute_2forms``."""
        return self.compute_2forms(metric, coords)


# Smoke test for the H2 basis extractor.
# `test_metric_spd` and `test_coords` are not created in this cell; they must
# already exist in the kernel (this cell cannot run on its own).
# The "(16, ...)" in the message presumably reflects the batch size of
# `test_coords` - TODO confirm against the cell that creates it.
h2_extractor = H2BasisExtractor().to(DEVICE)
test_omega = h2_extractor(test_metric_spd, test_coords)
print(f"H2 basis shape: {test_omega.shape} (expected: (16, 21, 21))")

---
## Section 4: Cohomology + Yukawa Module

This section implements clean numerical extraction of Yukawa couplings:
- H2 (b2=21) and H3 (b3=77) harmonic form bases
- Numerical wedge product integration
- Yukawa tensor Y_{ijk} = integral(omega_i ^ omega_j ^ Omega_k)

**No speculative physics here** - just pure geometry extraction.
---

In [None]:
# =============================================================================
# SECTION 3.5: BREATHING LOSS (SOFT DIAGNOSTICS)
# =============================================================================

def compute_breathing_loss(flux_net: NeckTransferNet,
                           mode_proj: ModeProjection,
                           modes: torch.Tensor,
                           config: BreathingConfig) -> Tuple[torch.Tensor, Dict]:
    """
    Evaluate the soft "breathing" diagnostic loss.

    The loss is a weighted sum of squared deviations from configured targets:

    - flux term:   (flux_integral - config.target_flux)^2
    - ratio term:  (occupation_vis - config.target_visible_ratio)^2
    - period term: constant zero here; the period is only reported as NaN
      and is not estimated inside this function.

    Args:
        flux_net: Module providing ``compute_flux_integral(n_points, device)``
        mode_proj: Module providing ``compute_occupation_ratio(modes)``
        modes: H3 modes of shape (batch, 77)
        config: Breathing configuration (``enable``, targets, loss weights)

    Returns:
        total_loss: Scalar tensor; exactly 0 when ``config.enable`` is false
        info: Diagnostics dictionary. When disabled it only contains
              ``{'breathing_enabled': False}``.
    """
    target_device = modes.device

    # Feature switch: return a zero loss living on the same device as `modes`.
    if not config.enable:
        return torch.tensor(0.0, device=target_device), {'breathing_enabled': False}

    flux_target = config.target_flux
    ratio_target = config.target_visible_ratio

    # Squared deviation of the integrated flux (100-point quadrature).
    measured_flux = flux_net.compute_flux_integral(n_points=100, device=target_device)
    flux_term = (measured_flux - flux_target) ** 2

    # Squared deviation of the visible occupation ratio.
    measured_ratio = mode_proj.compute_occupation_ratio(modes)
    ratio_term = (measured_ratio - ratio_target) ** 2

    # The period is tracked but never trained on: placeholder values only.
    period_placeholder = float('nan')
    period_term = torch.tensor(0.0, device=target_device)

    weighted_sum = (
        config.loss_weight_flux * flux_term +
        config.loss_weight_ratio * ratio_term +
        config.loss_weight_period * period_term
    )

    # Pull Python scalars out once and reuse them for the error fields.
    flux_value = measured_flux.item()
    ratio_value = measured_ratio.item()

    diagnostics = {
        'breathing_enabled': True,
        'total_breathing_loss': weighted_sum.item(),
        'loss_flux': flux_term.item(),
        'loss_ratio': ratio_term.item(),
        'flux_integral': flux_value,
        'flux_target': flux_target,
        'flux_error': abs(flux_value - flux_target),
        'occupation_vis': ratio_value,
        'occupation_target': ratio_target,
        'occupation_error': abs(ratio_value - ratio_target),
        'period_est': period_placeholder,
        'period_target': config.target_period
    }

    return weighted_sum, diagnostics


# Smoke test for the breathing loss.
# Relies on `flux_net`, `mode_proj`, `test_modes` and `CONFIG` already being
# present in the kernel; none of them is created in this cell.
# The keys read below only exist when CONFIG.breathing.enable is true; with
# breathing disabled the info dict is just {'breathing_enabled': False}.
breathing_loss, breathing_info = compute_breathing_loss(
    flux_net, mode_proj, test_modes, CONFIG.breathing
)
print(f"Breathing loss: {breathing_info['total_breathing_loss']:.6f}")
print(f"  Flux integral: {breathing_info['flux_integral']:.4f} (target: {breathing_info['flux_target']:.4f})")
print(f"  Occupation ratio: {breathing_info['occupation_vis']:.4f} (target: {breathing_info['occupation_target']:.4f})")

In [None]:
# =============================================================================
# SECTION 3.4: BREATHING PERIOD ESTIMATION
# =============================================================================

def estimate_breathing_period(observable: torch.Tensor, lam: torch.Tensor,
                               method: str = 'fft') -> Tuple[float, Dict]:
    """
    Estimate the dominant period of an observable O(lambda) along the neck.

    Args:
        observable: Shape (n_points,) - observable values at each lambda
        lam: Shape (n_points,) - lambda values. The FFT method derives a
             single sample spacing from the first and last entry, so it
             assumes uniformly spaced, increasing samples.
        method: 'fft' for FFT-based, 'fit' for sinusoidal least-squares fit

    Returns:
        period: Estimated period as a Python float. ``inf`` when no
                oscillation can be identified (fewer than 2 samples, flat
                signal, or a failed fit).
        info: Method-specific details, all plain Python floats:
              - 'fft': 'dominant_freq' and 'amplitude' (peak one-sided
                spectrum magnitude divided by n, i.e. half the amplitude of
                a pure sinusoid).
              - 'fit': 'amplitude', 'period', 'phase', 'offset', or
                ``{'error': 'fit_failed'}`` if the fit raised.

    Raises:
        ValueError: If ``method`` is neither 'fft' nor 'fit'.
    """
    n = len(lam)

    # numpy / scipy work on host arrays.
    values = observable.detach().cpu().numpy()
    lam_np = lam.detach().cpu().numpy()

    if method == 'fft':
        # Defaults describe "no oscillation found".
        period = float('inf')
        dominant_freq = 0.0
        amplitude = 0.0

        # With fewer than two samples there is neither a sample spacing nor a
        # non-DC frequency bin. The previous version crashed here by calling
        # .max() on an empty magnitude array.
        if n >= 2:
            centered = values - values.mean()
            spectrum = np.fft.rfft(centered)
            spacing = (lam_np[-1] - lam_np[0]) / (n - 1)
            freqs = np.fft.rfftfreq(n, d=spacing)

            # Skip bin 0 (DC); the mean was removed anyway.
            magnitudes = np.abs(spectrum[1:])
            if magnitudes.size > 0:
                peak = float(magnitudes.max())
                amplitude = peak / n
                if peak > 1e-10:
                    dominant_idx = int(np.argmax(magnitudes)) + 1
                    dominant_freq = float(freqs[dominant_idx])
                    if dominant_freq > 1e-10:
                        period = 1.0 / dominant_freq

        info = {
            'method': 'fft',
            'dominant_freq': dominant_freq,
            'amplitude': amplitude
        }

    elif method == 'fit':
        # Sinusoidal fit: O(lam) = A*cos(2*pi*lam/T + phi) + C
        from scipy.optimize import curve_fit

        def sinusoid(x, A, T, phi, C):
            return A * np.cos(2 * np.pi * x / T + phi) + C

        try:
            # Initial guess: half the peak-to-peak range, unit period,
            # zero phase, mean offset.
            initial = [(values.max() - values.min()) / 2, 1.0, 0.0, values.mean()]
            popt, _ = curve_fit(
                sinusoid, lam_np, values, p0=initial,
                bounds=([0, 0.1, -np.pi, -np.inf], [np.inf, 10, np.pi, np.inf])
            )
            period = float(popt[1])
            info = {
                'method': 'fit',
                'amplitude': float(popt[0]),
                'period': float(popt[1]),
                'phase': float(popt[2]),
                'offset': float(popt[3])
            }
        except Exception:
            # Best-effort: any fitting failure is reported as "no period",
            # but KeyboardInterrupt / SystemExit are no longer swallowed.
            period = float('inf')
            info = {'method': 'fit', 'error': 'fit_failed'}

    else:
        raise ValueError(f"Unknown method: {method}")

    return period, info


def compute_breathing_observable(modes: torch.Tensor, coords: torch.Tensor,
                                  mode_proj: ModeProjection) -> torch.Tensor:
    """
    Per-sample breathing observable: |a_vis|^2 - |a_hid|^2.

    The modes are split by ``mode_proj`` into two amplitude blocks; each
    block's squared amplitudes are summed along dim 1 and the hidden energy
    is subtracted from the visible one.

    Args:
        modes: Mode amplitudes passed straight to ``mode_proj``
        coords: Not used by this function (kept in the signature)
        mode_proj: Callable returning ``(a_vis, a_hid)`` for ``modes``

    Returns:
        Tensor with one value per sample (dim 1 reduced).
    """
    visible, hidden = mode_proj(modes)
    return visible.pow(2).sum(dim=1) - hidden.pow(2).sum(dim=1)


# Smoke test for the FFT period estimator on a synthetic noisy sine wave.
# NOTE: the noise comes from torch.randn with no seed set in this cell, so the
# printed numbers can change from run to run.
n_test_pts = 50
test_lam_dense = torch.linspace(0, 1, n_test_pts, device=DEVICE)
# Synthetic oscillation with period 0.4 plus Gaussian noise (std 0.1).
# With 50 samples on [0, 1] the FFT bins sit at k * 49/50 cycles per unit, so
# the estimator can only return 1/bin values (about 0.51 or 0.34 around the
# true 0.4), not 0.4 itself.
test_observable = torch.sin(2 * np.pi * test_lam_dense / 0.4) + 0.1 * torch.randn(n_test_pts, device=DEVICE)

period_est, period_info = estimate_breathing_period(test_observable, test_lam_dense, method='fft')
print(f"Estimated period: {period_est:.4f} (synthetic input: 0.4)")
print(f"Period info: {period_info}")

In [None]:
# =============================================================================
# SECTION 3.3: NECK TRANSFER NETWORK (FLUX)
# =============================================================================

class NeckTransferNet(nn.Module):
    """
    Scalar MLP flux(lambda) along the TCS neck coordinate.

    Architecture: Linear(1 -> h) - Tanh - Linear(h -> h) - Tanh - Linear(h -> 1).
    The sign of the output is unconstrained.
    """

    def __init__(self, hidden_dim: int = 32):
        super().__init__()

        stack = [
            nn.Linear(1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*stack)

        # Small-gain Xavier weights and zero biases so the initial flux is
        # close to zero everywhere.
        for module in stack:
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight, gain=0.1)
            nn.init.zeros_(module.bias)

    def forward(self, lam: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the flux at the given lambda values.

        Args:
            lam: Shape (batch,) or (batch, 1)

        Returns:
            flux: Shape (batch,)
        """
        # The MLP expects a trailing feature axis of size 1.
        column = lam.unsqueeze(-1) if lam.dim() == 1 else lam
        return self.net(column).squeeze(-1)

    def compute_flux_integral(self, n_points: int = 100, device: torch.device = None) -> torch.Tensor:
        """
        Trapezoidal estimate of the integral of flux over [0, 1].

        Args:
            n_points: Number of equally spaced sample points (end points
                      included). Must be at least 2; the step size is
                      1 / (n_points - 1).
            device: Device for the sample grid; defaults to the device of
                    the module parameters.

        Returns:
            flux_integral: Scalar tensor (differentiable w.r.t. parameters)
        """
        grid_device = device if device is not None else next(self.parameters()).device

        grid = torch.linspace(0, 1, n_points, device=grid_device)
        samples = self.forward(grid)

        # Trapezoid rule: sum of neighbouring pairs times half the step.
        step = 1.0 / (n_points - 1)
        return (samples[:-1] + samples[1:]).sum() * step / 2


# Smoke test for NeckTransferNet: evaluate the untrained network on 20 points
# and integrate it with the default 100-point trapezoid rule.
# The weights are freshly (randomly) initialized with a small gain, so the
# printed integral is expected to be near 0 rather than near the -0.5 target.
flux_net = NeckTransferNet().to(DEVICE)
test_lam = torch.linspace(0, 1, 20, device=DEVICE)
test_flux = flux_net(test_lam)
flux_integral = flux_net.compute_flux_integral()

print(f"Flux shape: {test_flux.shape}")
print(f"Flux range: [{test_flux.min():.4f}, {test_flux.max():.4f}]")
print(f"Flux integral: {flux_integral.item():.4f} (target: -0.5)")

In [None]:
# =============================================================================
# SECTION 3.2: VISIBLE/HIDDEN MODE SPLIT
# =============================================================================

class ModeProjection(nn.Module):
    """
    Split a mode vector into a "visible" and a "hidden" block.

    The projectors are the first ``n_visible`` and the remaining
    ``n_total - n_visible`` rows of the identity matrix, so by default the
    split simply selects the leading / trailing mode indices. The 43 / 34
    default is a configurable hypothesis, not something enforced elsewhere in
    this class.
    """

    def __init__(self, n_total: int = 77, n_visible: int = 43, learnable: bool = False):
        super().__init__()
        self.n_total = n_total
        self.n_visible = n_visible
        self.n_hidden = n_total - n_visible  # 34 with the defaults

        visible_rows = torch.eye(n_total)[:n_visible, :]  # (n_visible, n_total)
        hidden_rows = torch.eye(n_total)[n_visible:, :]   # (n_hidden, n_total)

        if learnable:
            # Trainable projectors, initialized at the fixed selection.
            self.P_vis = nn.Parameter(visible_rows)
            self.P_hid = nn.Parameter(hidden_rows)
        else:
            # Fixed projectors: stored as buffers so they follow .to(device)
            # without being optimized.
            self.register_buffer('P_vis', visible_rows)
            self.register_buffer('P_hid', hidden_rows)

    def forward(self, modes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply both projectors.

        Args:
            modes: Shape (batch, n_total)

        Returns:
            a_vis: Shape (batch, n_visible)
            a_hid: Shape (batch, n_hidden)
        """
        return modes.matmul(self.P_vis.t()), modes.matmul(self.P_hid.t())

    def compute_occupation_ratio(self, modes: torch.Tensor) -> torch.Tensor:
        """
        Visible share of the mean squared amplitude.

        Computes E_vis / (E_vis + E_hid + 1e-8), where each E is the mean of
        the squared amplitudes over *all* entries of the block (batch and
        mode index). Because these are per-entry means, equal energy in
        every mode yields 0.5 regardless of the block sizes.

        NOTE(review): the stated target of 43/77 ~ 0.558 would correspond to
        summed rather than averaged energies under equipartition - confirm
        which normalization is intended.

        Returns:
            Scalar tensor.
        """
        visible, hidden = self.forward(modes)

        energy_vis = visible.pow(2).mean()
        energy_hid = hidden.pow(2).mean()

        return energy_vis / (energy_vis + energy_hid + 1e-8)


# Smoke test for the visible/hidden split with the default 43 / 34 partition.
# `test_modes` is not created in this cell; it must already exist in the
# kernel.
mode_proj = ModeProjection(n_total=77, n_visible=43).to(DEVICE)
a_vis, a_hid = mode_proj(test_modes)
ratio = mode_proj.compute_occupation_ratio(test_modes)
print(f"Visible amplitudes: {a_vis.shape}")
print(f"Hidden amplitudes: {a_hid.shape}")
print(f"Occupation ratio: {ratio.item():.4f} (target: {43/77:.4f})")

In [None]:
# =============================================================================
# SECTION 3.1: H3 MODE EXTRACTION (SVD-ORTHONORMAL BASIS)
# =============================================================================

class H3ModeExtractor(nn.Module):
    """
    Assemble the 77 "H3 modes" used downstream: 35 local + 42 global.

    - Local modes: the 35 components of phi, passed through unchanged.
    - Global modes: 42 profile functions of the coordinates, obtained by
      projecting a pool of hand-written candidate functions onto the leading
      eigenvectors of the pool's batch Gram matrix.
    """

    def __init__(self, n_local: int = 35, n_global: int = 42):
        super().__init__()
        self.n_local = n_local
        self.n_global = n_global
        self.n_total = n_local + n_global  # 77

        # Nominal candidate-pool size. NOTE(review): compute_global_profiles
        # currently generates 104 candidates, not 110 (see its docstring);
        # this attribute is not read anywhere in this class.
        self.n_candidates = 110

    def compute_local_modes(self, phi: torch.Tensor) -> torch.Tensor:
        """
        Return the local modes, which are the phi components themselves.

        Args:
            phi: 3-form of shape (batch, 35)

        Returns:
            local_modes: Shape (batch, 35) - the same tensor object as ``phi``
        """
        return phi

    def compute_global_profiles(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Build the candidate pool and reduce it to ``n_global`` profiles.

        Candidate pool, in column order (counts are what the code produces):
         1. Constant + lambda powers 1..4                      (5)
         2. All 7 coordinates                                  (7)
         3. Smooth region indicators chi_L, chi_R, chi_neck    (3)
         4. Region x lambda powers 1..4                        (12)
         5. Region x coordinates                               (21)
         6. (chi_L - chi_R) x xi_i and x lambda                (7)
         7. Lambda x coordinates                               (7)
         8. Products xi_i * xi_j, i < j, transverse only       (15)
         9. sin/cos of 2*pi*k*lambda and pi*k*lambda, k = 1, 2 (8)
        10. Region x sin/cos(2*pi*k*lambda), k = 1, 2          (12)
        11. |xi|^2 and the six xi_i^2                          (7)
        Total: 104 columns.

        NOTE(review): the design notes quote 21 coordinate products and 110
        candidates overall; item 8 only yields 15, which accounts for the
        difference of 6. Left unchanged here because altering the pool would
        change every extracted profile - confirm the intended pool.

        The reduction depends on the batch: the Gram matrix is averaged over
        the samples passed in and each output column is normalized over the
        batch dimension, so the profile values at a point depend on which
        other points are in ``coords``.

        Args:
            coords: Shape (batch, 7); column 0 is used as lambda, columns
                    1..6 as transverse coordinates xi.

        Returns:
            profiles: Shape (batch, n_global)
        """
        n_samples = coords.shape[0]
        dev = coords.device

        lam = coords[:, 0]  # Neck parameter
        xi = coords[:, 1:]  # Transverse coordinates

        # 1. Constant + lambda powers
        pool = [torch.ones(n_samples, device=dev)]
        pool += [lam, lam ** 2, lam ** 3, lam ** 4]

        # 2. Raw coordinates
        pool += [coords[:, a] for a in range(7)]

        # 3. Smooth region indicators (sigmoid steps at lambda = 0.3 / 0.7)
        chi_L = torch.sigmoid(10 * (0.3 - lam))
        chi_R = torch.sigmoid(10 * (lam - 0.7))
        chi_neck = 1.0 - chi_L - chi_R
        regions = [chi_L, chi_R, chi_neck]
        pool += regions

        # 4. Region x lambda powers
        pool += [chi * (lam ** p) for chi in regions for p in (1, 2, 3, 4)]

        # 5. Region x coordinates
        pool += [chi * coords[:, a] for chi in regions for a in range(7)]

        # 6. Antisymmetric under L <-> R
        left_minus_right = chi_L - chi_R
        pool += [left_minus_right * xi[:, a] for a in range(6)]
        pool.append(left_minus_right * lam)

        # 7. Lambda x coordinates
        pool += [lam * coords[:, a] for a in range(7)]

        # 8. Pairwise products of transverse coordinates (i < j)
        pool += [xi[:, a] * xi[:, b] for a in range(6) for b in range(a + 1, 6)]

        # 9. Fourier modes in lambda
        for k in (1, 2):
            pool += [
                torch.sin(2 * np.pi * k * lam),
                torch.cos(2 * np.pi * k * lam),
                torch.sin(np.pi * k * lam),
                torch.cos(np.pi * k * lam),
            ]

        # 10. Region-windowed Fourier modes
        for chi in regions:
            for k in (1, 2):
                pool += [
                    chi * torch.sin(2 * np.pi * k * lam),
                    chi * torch.cos(2 * np.pi * k * lam),
                ]

        # 11. Radial terms
        pool.append((xi ** 2).sum(dim=1))
        pool += [xi[:, a] ** 2 for a in range(6)]

        # (batch, n_pool) design matrix and its batch-averaged Gram matrix.
        design = torch.stack(pool, dim=1)
        gram = design.T @ design / n_samples

        # eigh returns eigenvalues in ascending order, so the trailing
        # columns belong to the largest eigenvalues.
        _, eigvecs = torch.linalg.eigh(gram)
        leading = eigvecs[:, -self.n_global:]  # (n_pool, n_global)

        projected = design @ leading  # (batch, n_global)

        # Unit norm per profile over the batch dimension (1e-8 floor).
        return projected / (torch.norm(projected, dim=0, keepdim=True) + 1e-8)

    def forward(self, phi: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        """
        Concatenate local and global modes.

        Returns:
            modes: Shape (batch, n_local + n_global), local block first
        """
        return torch.cat(
            [self.compute_local_modes(phi), self.compute_global_profiles(coords)],
            dim=1,
        )


# Smoke test for H3 mode extraction.
# `phi_test` and `test_coords` are not created in this cell; they must already
# exist in the kernel.
# The "(16, 77)" in the message presumably reflects the batch size of
# `test_coords` - TODO confirm against the cell that creates it.
h3_extractor = H3ModeExtractor().to(DEVICE)
test_modes = h3_extractor(phi_test, test_coords)
print(f"H3 modes shape: {test_modes.shape} (expected: (16, 77))")
print(f"Local modes: {test_modes[:, :35].shape}")
print(f"Global profiles: {test_modes[:, 35:].shape}")

---
## Section 3: Breathing / Flux Module

This section implements speculative diagnostics for visible/hidden mode transfer:
- Mode extraction from H3 (77 harmonic 3-forms)
- Visible/hidden split (43/34 by default)
- Flux network along TCS neck
- Period estimation via FFT/fitting

**IMPORTANT**: These are SOFT diagnostics with small-weight losses.
They do NOT dominate the core geometry training.
---

In [None]:
# =============================================================================
# SECTION 2.5: CORE LOSS FUNCTION
# =============================================================================

def compute_core_loss(phi: torch.Tensor, coords: torch.Tensor, 
                      config: GeometryConfig,
                      weights: Optional[Dict[str, float]] = None) -> Tuple[torch.Tensor, Dict]:
    """
    Compute the core geometric loss for the G2 structure.

    The total is a weighted sum of six terms:
      torsion      - batch mean of the total torsion
      kappa_T_abs  - (kappa_T - target)^2
      kappa_T_rel  - (kappa_T / target - 1)^2
      det_g        - mean squared deviation of det(g) from its target
      positivity   - mean of relu(1e-4 - eigenvalue) over metric eigenvalues
      phi_norm     - mean squared deviation of ||phi||^2 from 7

    Args:
        phi: 3-form of shape (batch, 35)
        coords: Coordinates of shape (batch, 7); phi must have been computed
                from them with gradients enabled for the torsion term
        config: GeometryConfig providing ``target_kappa_T`` and
                ``target_det_g``
        weights: Optional per-term weights keyed by the names listed above.
                 Any key that is omitted falls back to its default, so a
                 partial dictionary such as ``{'torsion': 10.0}`` is valid
                 (previously this raised KeyError).

    Returns:
        total_loss: Scalar tensor
        info: Dictionary of Python floats with every loss component plus
              kappa_T / det(g) / eigenvalue diagnostics
    """
    default_weights = {
        'torsion': 100.0,
        'kappa_T_abs': 200.0,
        'kappa_T_rel': 500.0,
        'det_g': 50.0,
        'positivity': 10.0,
        'phi_norm': 1.0
    }
    # Caller-supplied entries override the defaults; missing ones are filled in.
    weights = {**default_weights, **(weights or {})}

    # Build metric from phi and project it onto the SPD cone.
    metric = metric_from_phi(phi)
    metric = project_spd(metric)

    # === Torsion loss ===
    torsion_info = compute_torsion(phi, coords, metric)
    kappa_T = torsion_info['kappa_T']

    # Absolute error from target
    loss_kappa_abs = (kappa_T - config.target_kappa_T) ** 2

    # Relative error (scale-free version of the same deviation)
    loss_kappa_rel = (kappa_T / config.target_kappa_T - 1.0) ** 2

    # Raw torsion penalty
    loss_torsion = torsion_info['torsion_total'].mean()

    # === Metric determinant loss ===
    det_g = compute_det_g(metric)
    loss_det_g = ((det_g - config.target_det_g) ** 2).mean()

    # === Positivity loss (penalize eigenvalues below 1e-4) ===
    eigenvalues = torch.linalg.eigvalsh(metric)
    loss_positivity = torch.relu(1e-4 - eigenvalues).mean()

    # === Phi normalization (||phi||^2 = 7) ===
    phi_norm_sq = (phi ** 2).sum(dim=1)
    loss_phi_norm = ((phi_norm_sq - 7.0) ** 2).mean()

    # === Total loss ===
    total_loss = (
        weights['torsion'] * loss_torsion +
        weights['kappa_T_abs'] * loss_kappa_abs +
        weights['kappa_T_rel'] * loss_kappa_rel +
        weights['det_g'] * loss_det_g +
        weights['positivity'] * loss_positivity +
        weights['phi_norm'] * loss_phi_norm
    )

    # Convert the reused scalars once instead of calling .item() repeatedly.
    kappa_value = kappa_T.item()
    det_mean_value = det_g.mean().item()

    info = {
        'total_loss': total_loss.item(),
        'loss_torsion': loss_torsion.item(),
        'loss_kappa_abs': loss_kappa_abs.item(),
        'loss_kappa_rel': loss_kappa_rel.item(),
        'loss_det_g': loss_det_g.item(),
        'loss_positivity': loss_positivity.item(),
        'loss_phi_norm': loss_phi_norm.item(),
        'kappa_T': kappa_value,
        'kappa_T_target': config.target_kappa_T,
        'kappa_T_error_pct': abs(kappa_value - config.target_kappa_T) / config.target_kappa_T * 100,
        'det_g_mean': det_mean_value,
        'det_g_target': config.target_det_g,
        'det_g_error_pct': abs(det_mean_value - config.target_det_g) / config.target_det_g * 100,
        'phi_norm_sq_mean': phi_norm_sq.mean().item(),
        'min_eigenvalue': eigenvalues.min().item()
    }

    return total_loss, info


# Smoke test for the core loss with the default weights.
# `phi_test`, `coords_grad` and `CONFIG` are not created in this cell; they
# must already exist in the kernel.
loss, info = compute_core_loss(phi_test, coords_grad, CONFIG.geometry)
print(f"Core loss: {info['total_loss']:.4f}")
print(f"  kappa_T: {info['kappa_T']:.6f} (error: {info['kappa_T_error_pct']:.2f}%)")
print(f"  det(g): {info['det_g_mean']:.4f} (error: {info['det_g_error_pct']:.2f}%)")
print(f"  min eigenvalue: {info['min_eigenvalue']:.6f}")

In [None]:
# =============================================================================
# SECTION 2.4: EXTERIOR DERIVATIVES AND TORSION
# =============================================================================

def exterior_derivative_3form(phi: torch.Tensor, coords: torch.Tensor,
                               create_graph: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Gradient-norm proxy for the exterior derivative d(phi) via autograd.

    For each of the 35 components of phi, the gradient with respect to
    ``coords`` is taken and its Euclidean norm over the 7 coordinate
    directions is stored. No antisymmetrization is performed, so the result
    measures how strongly each component varies rather than the actual
    4-form components.

    Args:
        phi: Shape (batch, 35); must depend on ``coords`` through the
             autograd graph for the result to be non-zero
        coords: Shape (batch, 7). If it does not require grad it is switched
                to require grad in place.
        create_graph: Keep the graph of the gradients so the result can be
                      differentiated again (needed when used in a loss)

    Returns:
        d_phi: Shape (batch, 35) - per-component gradient norms; a component
               whose gradient is reported as unused stays zero
        d_phi_norm_sq: Shape (batch,) - sum of squares of ``d_phi``
    """
    n_samples = phi.shape[0]

    if not coords.requires_grad:
        coords = coords.requires_grad_(True)

    d_phi = torch.zeros(n_samples, 35, device=phi.device)

    for component in range(35):
        column = phi[:, component]

        # Summing over the batch (all-ones grad_outputs) yields per-sample
        # gradients as long as each sample depends only on its own coords.
        (grad_column,) = torch.autograd.grad(
            outputs=column,
            inputs=coords,
            grad_outputs=torch.ones_like(column),
            create_graph=create_graph,
            retain_graph=True,
            allow_unused=True
        )

        if grad_column is None:
            continue
        d_phi[:, component] = grad_column.norm(dim=1)

    return d_phi, (d_phi ** 2).sum(dim=1)


def hodge_star_3form(phi: torch.Tensor, metric: torch.Tensor) -> torch.Tensor:
    """
    Simplified stand-in for the Hodge dual *phi (3-form -> 4-form).

    phi is multiplied by the per-sample volume element and then rescaled so
    that each sample has norm sqrt(7). No Levi-Civita contraction is done.
    Because the volume element is a positive per-sample scalar, the final
    rescaling cancels it up to the 1e-8 floor in the denominator.

    Args:
        phi: 3-form of shape (batch, 35)
        metric: Metric tensor of shape (batch, 7, 7)

    Returns:
        phi_dual: Shape (batch, 35)
    """
    # (batch, 1) volume factor broadcast over the 35 components.
    scaled = phi * volume_form(metric).unsqueeze(-1)

    # Fix the per-sample norm to sqrt(7).
    per_sample_norm = torch.norm(scaled, dim=1, keepdim=True)
    return scaled / (per_sample_norm + 1e-8) * np.sqrt(7.0)


def compute_torsion(phi: torch.Tensor, coords: torch.Tensor, 
                    metric: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Torsion diagnostic: ||d(phi)||^2 + ||d(*phi)||^2 per sample.

    Both derivative terms come from ``exterior_derivative_3form``; the dual
    form comes from ``hodge_star_3form``. A torsion-free G2 structure would
    make both terms vanish.

    Returns:
        Dictionary with
          'd_phi_norm_sq'      - (batch,) squared norm of d(phi)
          'd_phi_dual_norm_sq' - (batch,) squared norm of d(*phi)
          'torsion_total'      - (batch,) sum of the two
          'kappa_T'            - scalar, batch mean of 'torsion_total'
    """
    _, closure_sq = exterior_derivative_3form(phi, coords)

    star_phi = hodge_star_3form(phi, metric)
    _, coclosure_sq = exterior_derivative_3form(star_phi, coords)

    per_sample_total = closure_sq + coclosure_sq

    return {
        'd_phi_norm_sq': closure_sq,
        'd_phi_dual_norm_sq': coclosure_sq,
        'torsion_total': per_sample_total,
        'kappa_T': per_sample_total.mean()  # Scalar torsion magnitude
    }


# Smoke test for the torsion computation.
# `test_coords` and `phi_net` are not created in this cell; they must already
# exist in the kernel.
# The coordinates are cloned with requires_grad enabled *before* the network
# is evaluated so that phi_test is connected to coords_grad in the autograd
# graph (compute_torsion differentiates phi with respect to the coordinates).
coords_grad = test_coords.clone().requires_grad_(True)
phi_test = phi_net(coords_grad)
metric_test = project_spd(metric_from_phi(phi_test))

torsion_info = compute_torsion(phi_test, coords_grad, metric_test)
print(f"Torsion ||d(phi)||^2: {torsion_info['d_phi_norm_sq'].mean():.6f}")
print(f"Torsion ||d(*phi)||^2: {torsion_info['d_phi_dual_norm_sq'].mean():.6f}")
print(f"kappa_T (mean torsion): {torsion_info['kappa_T']:.6f}")
print(f"Target kappa_T: {1/61:.6f}")

In [None]:
# =============================================================================
# SECTION 2.3: METRIC RECONSTRUCTION FROM PHI
# =============================================================================

def build_triple_to_idx() -> Dict[Tuple[int,int,int], int]:
    """
    Map each index triple (i, j, k) with 0 <= i < j < k < 7 to its position.

    Triples are enumerated in lexicographic order, giving the 35 positions
    0..34 used to address the components of phi.
    """
    ordered_triples = [
        (i, j, k)
        for i in range(7)
        for j in range(i + 1, 7)
        for k in range(j + 1, 7)
    ]
    return {triple: position for position, triple in enumerate(ordered_triples)}

# Built once at definition time and shared by the functions in this cell.
TRIPLE_TO_IDX = build_triple_to_idx()


def metric_from_phi(phi: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """
    Build a heuristic metric g_ij from the components of the 3-form phi.

    This is an algebraic approximation, not the exact G2 formula:

    - Diagonal:     g_ii = sqrt(sum of phi_ijk^2 over the 15 triples that
                    contain i, plus epsilon) / sqrt(5)
    - Off-diagonal: g_ij = 0.1 * sqrt(mean of phi_ijk^2 over the 5 triples
                    that contain both i and j, plus epsilon)

    Off-diagonal entries are therefore always positive. The result is
    symmetric by construction, but positive definiteness is not guaranteed
    by this function itself.

    Args:
        phi: 3-form of shape (batch, 35), components ordered as in
             ``TRIPLE_TO_IDX``
        epsilon: Added under the square roots for numerical stability

    Returns:
        metric: Shape (batch, 7, 7), symmetric
    """
    batch_size = phi.shape[0]
    device = phi.device

    metric = torch.zeros(batch_size, 7, 7, device=device)

    # Every entry only needs squared components; square once up front.
    phi_sq = phi ** 2

    def component_sq(i: int, j: int, k: int) -> torch.Tensor:
        # Any three distinct indices in 0..6 sort to a key of TRIPLE_TO_IDX.
        return phi_sq[:, TRIPLE_TO_IDX[tuple(sorted((i, j, k)))]]

    # Diagonal: accumulate over all pairs (j, k), j < k, that avoid i.
    for i in range(7):
        others = [a for a in range(7) if a != i]
        contrib = torch.zeros(batch_size, device=device)
        for pos, j in enumerate(others):
            for k in others[pos + 1:]:
                contrib += component_sq(i, j, k)
        metric[:, i, i] = torch.sqrt(contrib + epsilon) / np.sqrt(5.0)

    # Off-diagonal: average over the 5 remaining indices k.
    for i in range(7):
        for j in range(i + 1, 7):
            third = [k for k in range(7) if k != i and k != j]
            contrib = torch.zeros(batch_size, device=device)
            for k in third:
                contrib += component_sq(i, j, k)
            metric[:, i, j] = torch.sqrt(contrib / len(third) + epsilon) * 0.1
            metric[:, j, i] = metric[:, i, j]

    return metric


def project_spd(metric: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """
    Symmetrize a (batch of) square matrix and floor its eigenvalues at epsilon.

    Args:
        metric: Tensor whose last two dimensions form square matrices
        epsilon: Lower bound applied to every eigenvalue

    Returns:
        Tensor of the same shape, rebuilt as V @ diag(lambda) @ V^T from the
        clamped spectrum.
    """
    # Average with the transpose so eigh receives a symmetric input.
    symmetric = (metric + metric.transpose(-2, -1)) * 0.5

    spectrum, basis = torch.linalg.eigh(symmetric)

    # Raise any eigenvalue below epsilon up to epsilon.
    floored = spectrum.clamp(min=epsilon)

    rebuilt = basis @ torch.diag_embed(floored) @ basis.transpose(-2, -1)
    return rebuilt


def compute_det_g(metric: torch.Tensor) -> torch.Tensor:
    """
    Return det(g) for each matrix in the trailing two dimensions of metric.

    Args:
        metric: Tensor whose last two dimensions form square matrices

    Returns:
        Tensor of determinants with the trailing two dimensions removed.
    """
    determinant = torch.linalg.det(metric)
    return determinant


def volume_form(metric: torch.Tensor) -> torch.Tensor:
    """
    Return the volume element sqrt(|det(g)| + 1e-10).

    The small constant inside the root keeps the result (and its gradient)
    finite when the determinant is zero.
    """
    det_g = torch.det(metric)
    magnitude = det_g.abs() + 1e-10
    return magnitude.sqrt()


# Smoke test for metric reconstruction.
# `test_phi` is not defined in this cell: it must already exist in the kernel
# (the PhiNetwork test cell assigns it), so that cell has to run first.
test_metric = metric_from_phi(test_phi)
test_metric_spd = project_spd(test_metric)
test_det = compute_det_g(test_metric_spd)

# The shape and symmetry checks look at the raw reconstruction; the
# determinant statistics are taken from the SPD-projected metric.
print(f"Metric shape: {test_metric.shape}")
print(f"Metric symmetric: {torch.allclose(test_metric, test_metric.transpose(-2,-1), atol=1e-5)}")
print(f"det(g): mean={test_det.mean():.4f}, std={test_det.std():.4f}")
# Reference value printed for comparison only; nothing is asserted here.
print(f"Target det(g): {65/32:.4f}")

In [None]:
# =============================================================================
# SECTION 2.2: NEURAL NETWORK ARCHITECTURE
# =============================================================================

class FourierFeatures(nn.Module):
    """
    Fixed random Fourier encoding of input coordinates.

    A Gaussian frequency matrix B of shape (input_dim, n_modes), scaled by
    `scale`, is drawn once at construction and stored as a buffer, so it
    follows the module across devices but is never trained. The encoding
    concatenates cos and sin of 2*pi*(x @ B), giving 2 * n_modes features.
    """

    def __init__(self, input_dim: int = 7, n_modes: int = 32, scale: float = 1.0):
        super().__init__()
        self.input_dim, self.n_modes = input_dim, n_modes
        self.output_dim = n_modes * 2

        # Frequencies are a buffer, not a Parameter: saved with the module,
        # invisible to the optimizer.
        frequencies = scale * torch.randn(input_dim, n_modes)
        self.register_buffer('B', frequencies)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return [cos(2*pi*x@B), sin(2*pi*x@B)] concatenated on the last axis."""
        phase = torch.matmul(x, self.B) * (2 * np.pi)
        return torch.cat((phase.cos(), phase.sin()), dim=-1)


class PhiNetwork(nn.Module):
    """
    Neural network for learning G2 3-form phi(x) on K7.

    Architecture:
        Input: 7D coordinates -> Fourier encoding -> MLP -> 35 phi components

    Each hidden block is Linear -> SiLU -> LayerNorm (-> Dropout when
    dropout > 0). The metric is reconstructed algebraically from phi
    outside this class.

    Args:
        hidden_dims: Widths of the hidden layers; None selects the default
            [256, 256, 128]
        fourier_modes: Number of Fourier modes passed to FourierFeatures
        fourier_scale: Frequency scale passed to FourierFeatures
        normalize_phi: If True, rescale every output so that ||phi||^2 = 7
        dropout: Dropout probability; no Dropout layers are added when 0
    """

    def __init__(self,
                 hidden_dims: Optional[List[int]] = None,
                 fourier_modes: int = 32,
                 fourier_scale: float = 1.0,
                 normalize_phi: bool = True,
                 dropout: float = 0.0):
        super().__init__()

        # A None sentinel replaces the former list literal default, which was
        # a single list object shared by every call.
        if hidden_dims is None:
            hidden_dims = [256, 256, 128]

        self.normalize_phi = normalize_phi
        self.input_dim = 7
        self.output_dim = 35  # C(7,3) = 35 components for 3-form

        # Fourier encoding
        self.encoding = FourierFeatures(
            input_dim=7,
            n_modes=fourier_modes,
            scale=fourier_scale
        )

        # MLP: Linear -> SiLU -> LayerNorm (-> Dropout) per hidden width
        layers = []
        prev_dim = self.encoding.output_dim

        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.SiLU())
            layers.append(nn.LayerNorm(h_dim))
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev_dim = h_dim

        self.mlp = nn.Sequential(*layers)

        # Output layer
        self.output_layer = nn.Linear(prev_dim, self.output_dim)

        # Initialize output with small weights for stability
        with torch.no_grad():
            self.output_layer.weight.mul_(0.01)
            self.output_layer.bias.zero_()

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Compute phi(x) at given coordinates.

        Args:
            coords: Shape (batch, 7)

        Returns:
            phi: Shape (batch, 35) - 3-form components
        """
        # Encode coordinates
        x = self.encoding(coords)

        # MLP
        x = self.mlp(x)

        # Output phi
        phi = self.output_layer(x)

        # Normalize: ||phi||^2 = 7 (G2 structure). The 1e-8 guards against
        # division by zero when the raw output vanishes.
        if self.normalize_phi:
            phi_norm = torch.norm(phi, dim=-1, keepdim=True)
            phi = phi * (np.sqrt(7.0) / (phi_norm + 1e-8))

        return phi

    @staticmethod
    def get_phi_indices() -> List[Tuple[int, int, int]]:
        """Return the (i, j, k) triple, i < j < k, for each component index in order."""
        return [
            (i, j, k)
            for i in range(7)
            for j in range(i + 1, 7)
            for k in range(j + 1, 7)
        ]


# Smoke test for PhiNetwork with default hyperparameters.
# `test_coords` is not defined in this cell: it must already exist in the
# kernel (the coordinate sampling test cell assigns it).
phi_net = PhiNetwork().to(DEVICE)
test_phi = phi_net(test_coords)
print(f"Phi shape: {test_phi.shape}")
# With normalize_phi=True (the default) the batch mean of ||phi||^2 should print as 7.
print(f"||phi||^2: {(test_phi**2).sum(dim=1).mean():.4f} (target: 7.0)")
# Counts parameters only; the Fourier frequency matrix is a buffer and is excluded.
print(f"Parameters: {sum(p.numel() for p in phi_net.parameters()):,}")

In [None]:
# =============================================================================
# SECTION 2.1: COORDINATE SAMPLING
# =============================================================================

def sample_T7_coords(batch_size: int, device: torch.device) -> torch.Tensor:
    """
    Draw points uniformly from the 7-torus T^7 with period 2*pi.

    Args:
        batch_size: Number of points to draw
        device: Device on which the tensor is created

    Returns:
        coords: Shape (batch_size, 7), each entry uniform in [0, 2*pi)
    """
    unit_draws = torch.rand(batch_size, 7, device=device)
    return unit_draws * 2 * np.pi


def sample_TCS_coords(batch_size: int, device: torch.device,
                      lambda_range: Tuple[float, float] = (0.0, 1.0)) -> torch.Tensor:
    """
    Draw points in the TCS (Twisted Connected Sum) parameterization.

    Layout of each row: (lambda, xi_1, ..., xi_6)
    - lambda: neck parameter, uniform in lambda_range ([0, 1] by default)
    - xi_i:   transverse coordinates, uniform in [0, 2*pi)

    Args:
        batch_size: Number of points to draw
        device: Device on which the tensors are created
        lambda_range: (low, high) bounds for the neck parameter

    Returns:
        coords: Shape (batch_size, 7) with lambda in column 0
    """
    lam_low = lambda_range[0]
    lam_high = lambda_range[1]

    # The neck column is drawn before the transverse block; keeping this
    # order keeps seeded runs reproducible.
    neck = torch.rand(batch_size, device=device) * (lam_high - lam_low) + lam_low
    transverse = torch.rand(batch_size, 6, device=device) * 2 * np.pi

    return torch.cat([neck.unsqueeze(1), transverse], dim=1)


def split_TCS_coords(coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Separate TCS coordinates into the neck column and the transverse block.

    Args:
        coords: Shape (batch, 7), rows laid out as (lambda, xi_1, ..., xi_6)

    Returns:
        lambda_coords: Shape (batch, 1) - neck parameter (column kept 2-D)
        transverse_coords: Shape (batch, 6) - transverse coordinates
    """
    # Slicing with :1 rather than indexing with 0 keeps the column dimension.
    return coords[:, :1], coords[:, 1:]


def sample_coords(batch_size: int, config: GeometryConfig, device: torch.device) -> torch.Tensor:
    """
    Sample coordinates with the sampler selected by config.coordinate_type.

    Args:
        batch_size: Number of points to draw
        config: Geometry configuration; only coordinate_type is read here
        device: Device on which the tensor is created

    Returns:
        coords: Output of sample_T7_coords ("T7") or sample_TCS_coords ("TCS")

    Raises:
        ValueError: If coordinate_type is neither "T7" nor "TCS"
    """
    kind = config.coordinate_type
    if kind == "T7":
        return sample_T7_coords(batch_size, device)
    if kind == "TCS":
        return sample_TCS_coords(batch_size, device)
    raise ValueError(f"Unknown coordinate type: {config.coordinate_type}")


# Smoke test for coordinate sampling: draw 16 TCS points and print the
# observed ranges of the neck and transverse parts.
# `test_coords` is reused by the PhiNetwork test cell.
test_coords = sample_TCS_coords(16, DEVICE)
lam, xi = split_TCS_coords(test_coords)
print(f"TCS coords shape: {test_coords.shape}")
# With the default lambda_range, lambda should fall inside [0, 1].
print(f"Lambda range: [{lam.min():.3f}, {lam.max():.3f}]")
# Transverse coordinates should fall inside [0, 2*pi].
print(f"Transverse range: [{xi.min():.3f}, {xi.max():.3f}]")

---
## Section 2: Geometry Core (from v1.6)

This section contains the core geometric operations:
- Coordinate sampling on T^7 / TCS
- PhiNetwork for learning the G2 3-form
- Metric reconstruction from phi
- Exterior derivatives and Hodge star
- Torsion-free conditions
---

# K7 Breathing Mode PINN v1.8

**Physics-Informed Neural Network for G2 Holonomy Metrics with Breathing/Flux Diagnostics**

## Overview

This notebook extends v1.6 with:
1. **Breathing/Flux Module**: Visible/hidden H3 mode decomposition along TCS neck
2. **Yukawa Extraction**: Clean numerical computation of triple overlap integrals
3. **Post-hoc Metric Fitting**: Analytic approximation of learned g_ij(x)

## Design Principles

- **Core geometry preserved**: All v1.6 targets are kept, both topological (b2=21, b3=77) and metric (det(g)=65/32, kappa_T=1/61)
- **Speculative quantities as diagnostics**: 43/77 ratio, -1/2 flux, tau period are SOFT probes
- **Modular architecture**: Easy to disable breathing losses (weights=0) for pure geometry mode

## References

- v1.6: K7_GIFT_v1_6.ipynb (SVD-orthonormal harmonic basis)
- GIFT Framework: publications/gift_2_2_main.md

---
## Section 1: Imports and Configuration
---

In [None]:
# =============================================================================
# SECTION 1: IMPORTS AND CONFIGURATION
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Any
from pathlib import Path
import json
import time
from fractions import Fraction

# Seed the torch and numpy global RNGs so repeated Restart & Run All runs
# draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)

# Use CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Output directory, relative to the current working directory. The
# `checkpoints` subdirectory is created up front; exist_ok=True makes
# re-running this cell harmless.
OUTPUT_DIR = Path("outputs_v1_8")
OUTPUT_DIR.mkdir(exist_ok=True)
(OUTPUT_DIR / "checkpoints").mkdir(exist_ok=True)

In [None]:
# =============================================================================
# EXPERIMENT CONFIGURATION (Dataclass)
# =============================================================================

@dataclass
class GeometryConfig:
    """Configuration for K7 manifold geometry.

    Holds the coordinate parameterization choice, the topological constants
    and the target values of the metric invariants.
    """
    dim: int = 7
    # Selects the sampler in sample_coords; any value other than "T7" or
    # "TCS" raises ValueError there.
    coordinate_type: str = "TCS"  # "T7" or "TCS"
    has_neck: bool = True
    
    # Topological constants (from GIFT v2.2)
    b2_K7: int = 21          # Second Betti number
    b3_K7: int = 77          # Third Betti number
    # b3_local + b3_global = 35 + 42 = 77 = b3_K7
    b3_local: int = 35       # Local Lambda^3 decomposition
    b3_global: int = 42      # Global TCS-induced modes
    dim_G2: int = 14         # dim(G2)
    p2: int = 2              # Binary duality
    
    # Target metric invariants
    target_det_g: float = 65/32      # = 2.03125
    target_kappa_T: float = 1/61     # = 0.016393...


@dataclass
class BreathingConfig:
    """Configuration for breathing/flux diagnostics.
    
    IMPORTANT: These are SOFT diagnostics, not hard constraints.
    The speculative values (43/77, -1/2, tau) are measured, not imposed.
    """
    # Master switch for the breathing/flux diagnostics.
    enable: bool = True
    
    # Speculative target values (for diagnostics only)
    target_flux: float = -0.5              # -1/p2 hypothesis
    target_visible_ratio: float = 43/77   # ~ 0.558
    target_period: float = 3472/891       # tau = 3.896... (GIFT hierarchy parameter)
    
    # Loss weights (SMALL to avoid dominating geometry)
    loss_weight_flux: float = 1e-3
    loss_weight_ratio: float = 1e-3
    loss_weight_period: float = 1e-4
    
    # Mode split configuration: n_visible + n_hidden = 43 + 34 = 77.
    # n_hidden is a literal, so it must be updated by hand if n_visible changes.
    n_visible: int = 43
    n_hidden: int = 34  # 77 - 43
    
    # Neck parameterization
    lambda_mid: float = 0.5  # Boundary between visible/hidden regions


@dataclass
class TrainingConfig:
    """Training hyperparameters for the two training phases.

    Phase 1 trains on geometry only; phase 2 adds the breathing terms and
    has its own epoch count and learning rate.
    """
    # Phase 1: Geometry-only
    n_epochs_core: int = 5000
    lr_core: float = 1e-3
    
    # Phase 2: Geometry + breathing
    n_epochs_breathing: int = 2000
    lr_breathing: float = 5e-4
    freeze_phi_in_breathing: bool = False  # If True, only train NeckTransferNet
    
    # General (the scheduler_* names suggest a reduce-on-plateau schedule;
    # confirm against the training loop)
    batch_size: int = 2048
    weight_decay: float = 1e-5
    scheduler_patience: int = 500
    scheduler_factor: float = 0.5
    
    # Gradient clipping
    max_grad_norm: float = 1.0


@dataclass
class LoggingConfig:
    """Logging and checkpointing configuration."""
    # Logging interval (presumably in epochs; confirm against the training loop).
    log_every: int = 100
    save_checkpoints: bool = True
    # Same path as OUTPUT_DIR / "checkpoints" from the imports cell, but
    # spelled out as a literal: keep the two in sync if OUTPUT_DIR changes.
    checkpoint_dir: str = "outputs_v1_8/checkpoints"
    verbose: bool = True


@dataclass
class ExperimentConfig:
    """Master configuration combining all sub-configs.

    Each sub-config is built through a default_factory, so every
    ExperimentConfig instance owns its own sub-config objects.
    """
    geometry: GeometryConfig = field(default_factory=GeometryConfig)
    breathing: BreathingConfig = field(default_factory=BreathingConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)

    def to_dict(self) -> Dict:
        """Convert to a dictionary for JSON serialization.

        Returns:
            Dict with keys 'geometry', 'breathing', 'training' and 'logging',
            each mapping to a copy of that sub-config's attributes. Copies
            are returned so that editing the result cannot silently change
            the live configuration.
        """
        return {
            'geometry': dict(vars(self.geometry)),
            'breathing': dict(vars(self.breathing)),
            'training': dict(vars(self.training)),
            'logging': dict(vars(self.logging))
        }

    def save(self, path: str):
        """Save configuration to a JSON file at `path` (overwrites an existing file)."""
        with open(path, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)


# Create the default configuration used by the rest of the notebook and echo
# a few key values as a sanity check.
CONFIG = ExperimentConfig()
print("Configuration loaded:")
print(f"  Geometry: dim={CONFIG.geometry.dim}, b2={CONFIG.geometry.b2_K7}, b3={CONFIG.geometry.b3_K7}")
print(f"  Breathing: enabled={CONFIG.breathing.enable}, target_flux={CONFIG.breathing.target_flux}")
print(f"  Training: core_epochs={CONFIG.training.n_epochs_core}, batch={CONFIG.training.batch_size}")