# K₇ Metric Reconstruction via Neural Networks

This notebook implements a complete pipeline for reconstructing the Riemannian metric on the compact G₂ manifold K₇ using deep learning techniques. The architecture employs curriculum learning across three phases to learn a smooth, Ricci-flat metric consistent with G₂ holonomy.

## Key Features

- Three-phase curriculum training (region-specific → global blending → curvature refinement)
- Automatic checkpoint management with Google Drive integration
- Auto-resume functionality for interrupted training sessions
- Multi-format output (.pt, .npy, .json)
- H³ harmonic forms extraction
- Comprehensive loss composition (6 geometric constraints)

## Architecture Overview

**Phase 1 (Epochs 0-2000)**: Train region-specific networks (M1, Neck, M2)

**Phase 2 (Epochs 2000-5000)**: Global network with smooth blending and H² constraint

**Phase 3 (Epochs 5000-8000)**: Ricci curvature refinement

## References

Based on the GIFT framework for G₂ manifold construction using hierarchical gluing techniques.

## 1. Environment Setup and Google Drive Integration

In [None]:
# Mount Google Drive so checkpoints and outputs are written to persistent storage
from google.colab import drive
import os

drive.mount('/content/drive')

# Every artefact lives under one project folder on the mounted Drive
PROJECT_DIR = '/content/drive/MyDrive/K7_Reconstruction'
CHECKPOINT_DIR = PROJECT_DIR + '/checkpoints'
OUTPUT_DIR = PROJECT_DIR + '/outputs'
DATA_DIR = PROJECT_DIR + '/data'

# makedirs creates PROJECT_DIR implicitly as the parent of each subfolder
for _subdir in (CHECKPOINT_DIR, OUTPUT_DIR, DATA_DIR):
    os.makedirs(_subdir, exist_ok=True)

print(f"Project directory: {PROJECT_DIR}")
print(f"Checkpoints: {CHECKPOINT_DIR}")
print(f"Outputs: {OUTPUT_DIR}")

In [None]:
# Install required packages
# NOTE: the torch wheels are pulled from the cu118 index given by --index-url below;
# the remaining scientific packages come from the default index.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install scipy numpy matplotlib tqdm

In [None]:
# Core imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import json
from pathlib import Path
from typing import Dict, Optional, List, Tuple
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from datetime import datetime
import scipy
from scipy.sparse.linalg import eigsh
import warnings
# NOTE: this silences every warning category for the rest of the session,
# so deprecation and numerical warnings will not be shown.
warnings.filterwarnings('ignore')

# Report the library version and, when CUDA is available, details of GPU 0.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; dividing by 1e9 prints decimal gigabytes.
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration and Hyperparameters

In [None]:
# Training configuration
# Central dictionary of hyperparameters; the trainer and later cells read their
# settings from here rather than from separate globals.
CONFIG = {
    # Training parameters
    'epochs': 8000,
    'batch_size': 1024,
    'lr': 2e-4,
    'weight_decay': 1e-4,
    'grad_clip': 1.0,  # presumably a gradient-norm clipping threshold — confirm in the trainer
    
    # Phase boundaries
    # One (start_epoch, end_epoch) pair per curriculum phase, matching the
    # three phases listed in the architecture overview.
    'phase_boundaries': [(0, 2000), (2000, 5000), (5000, 8000)],
    
    # Network architecture
    'n_fourier': 32,                    # Fourier frequencies for the metric and phi networks
    'hidden_dims': [256, 256, 128],     # MLP widths of each region metric network
    'phi_hidden_dims': [384, 384, 256], # MLP widths of the phi network
    'h2_n_forms': 21,
    'h2_n_fourier': 24,
    'h2_hidden_dim': 128,
    
    # GIFT parameters (from v0.7)
    'gift_params': {
        'tau': 3.8967,
        'xi': 2.5,
        'transition_radius': 2.0
    },
    
    # Checkpoint and logging
    'checkpoint_frequency': 500,  # Save every N epochs
    'log_frequency': 100,          # Print every N epochs
    'auto_resume': True,           # Automatically resume from latest checkpoint
    
    # H³ extraction
    'extract_h3': True,
    'h3_n_forms': 77,
    'h3_n_sample_points': 8192,
    
    # Device
    # Falls back to CPU when no CUDA device is visible.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # Random seed
    'seed': 42
}

# Seed NumPy and torch (CPU and, when present, CUDA) from the configured seed
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])
if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels in addition to seeding the CUDA generator
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(CONFIG['seed'])

# Module-level device handle used as the default by later cells
device = torch.device(CONFIG['device'])
print(f"Using device: {device}")
print("Configuration loaded successfully")

## 3. Data Loading and Snapshot Preparation

In [None]:
# Bring the v0.7 metric snapshot files into the project data directory.
# Expected uploads: metric_data.json, metric_m1.npy, metric_neck.npy, metric_m2.npy

from google.colab import files
import shutil

print(
    "Please upload the following files from G2_ML/0.7:",
    "  - metric_data.json",
    "  - metric_m1.npy",
    "  - metric_neck.npy",
    "  - metric_m2.npy",
    "  - config.json (optional)",
    sep="\n",
)

uploaded = files.upload()

# Relocate each uploaded file from the working directory into DATA_DIR
for filename in uploaded:
    shutil.move(filename, f"{DATA_DIR}/{filename}")
    print(f"Moved {filename} to {DATA_DIR}")

In [None]:
def load_snapshots_from_data(
    data_dir: str = DATA_DIR,
    device: torch.device = device
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Build the three anchor snapshots (M1, Neck, M2) from ``metric_data.json``.

    For each region the mean metric is read from the JSON file and paired with
    one representative point in TCS coordinates (t, θ, x₁, …, x₅): t = -5 for
    M1, t = 0 for the neck and t = +5 for M2, all other coordinates zero.

    Args:
        data_dir: Directory containing ``metric_data.json``.
        device: Device on which the returned tensors are allocated.

    Returns:
        Mapping ``{'m1' | 'neck' | 'm2': {'coords': tensor, 'metric': tensor}}``
        with float32 tensors.
    """
    with (Path(data_dir) / 'metric_data.json').open('r') as handle:
        metric_data = json.load(handle)

    def _as_tensor(values) -> torch.Tensor:
        return torch.tensor(values, dtype=torch.float32, device=device)

    # (snapshot key, JSON section, t-coordinate of the representative point)
    region_specs = (
        ('m1', 'M1', -5.0),
        ('neck', 'Neck', 0.0),
        ('m2', 'M2', 5.0),
    )

    snapshots = {}
    for key, section, t_value in region_specs:
        snapshots[key] = {
            'coords': _as_tensor([t_value, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            'metric': _as_tensor(metric_data[section]['metric_mean']),
        }

    print("Metric snapshots loaded successfully:")
    print(f"  M1 metric shape: {snapshots['m1']['metric'].shape}")
    print(f"  Neck metric shape: {snapshots['neck']['metric'].shape}")
    print(f"  M2 metric shape: {snapshots['m2']['metric'].shape}")

    return snapshots

# Load snapshots
# Uses the function defaults (DATA_DIR and the module-level device); this reads
# metric_data.json, so the upload cell must have been run first.
snapshots = load_snapshots_from_data()
print(f"\nGIFT parameters: τ={CONFIG['gift_params']['tau']:.4f}, ξ={CONFIG['gift_params']['xi']:.4f}")

## 4. Utility Functions and Geometric Operations

In [None]:
class FourierFeatures(nn.Module):
    """Fixed random Fourier encoding of input coordinates.

    A frozen random matrix B ~ N(0, scale²) of shape (n_fourier, n_dim) is drawn
    at construction. Every coordinate x_d is multiplied by every frequency
    B[f, d] separately (the products are NOT summed over d), and the result is
    passed through sin(2π·) and cos(2π·). The encoding therefore has
    ``2 * n_fourier * n_dim`` features per input point.
    """

    def __init__(self, n_dim: int = 7, n_fourier: int = 32, scale: float = 1.0):
        super().__init__()
        self.n_dim = n_dim
        self.n_fourier = n_fourier
        self.scale = scale

        # Registered as a buffer: saved with the module and moved by .to(),
        # but not exposed to the optimizer as a parameter.
        self.register_buffer('B', torch.randn(n_fourier, n_dim) * scale)
        self.output_dim = 2 * n_fourier * n_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (batch, n_dim) as (batch, 2 * n_fourier * n_dim)."""
        # (batch, 1, n_dim) * (1, n_fourier, n_dim) -> (batch, n_fourier, n_dim)
        scaled = x.unsqueeze(1) * self.B.unsqueeze(0)
        phases = 2 * np.pi * scaled.reshape(x.shape[0], -1)

        # All sine features first, then all cosine features
        return torch.cat([torch.sin(phases), torch.cos(phases)], dim=1)

print("Fourier feature encoding module defined")

In [None]:
def ensure_positive_definite(g: torch.Tensor, min_eig: float = 0.1) -> torch.Tensor:
    """Floor the spectrum of symmetric matrices at ``min_eig``.

    The input is eigendecomposed with ``torch.linalg.eigh``, every eigenvalue
    below ``min_eig`` is raised to ``min_eig``, and the matrix is rebuilt from
    the unchanged eigenvectors.
    """
    spectrum, basis = torch.linalg.eigh(g)
    floored = spectrum.clamp(min=min_eig)
    return basis @ torch.diag_embed(floored) @ basis.transpose(-2, -1)

def compute_metric_determinant(g: torch.Tensor) -> torch.Tensor:
    """Return det(g) for every square matrix in the trailing two dimensions of ``g``."""
    determinants = torch.det(g)
    return determinants

def compute_volume_form(g: torch.Tensor) -> torch.Tensor:
    """Return the volume density √|det(g)|.

    The absolute value keeps the square root real if a determinant is negative.
    """
    return compute_metric_determinant(g).abs().sqrt()

def compute_metric_derivatives(
    g: torch.Tensor,
    coords: torch.Tensor,
    create_graph: bool = True
) -> torch.Tensor:
    """Compute ∂_k g_ij via automatic differentiation.

    Args:
        g: Metric batch of shape (batch, n, m), produced from ``coords`` so that
            an autograd graph links the two.
        coords: Points of shape (batch, d) with ``requires_grad=True``.
        create_graph: Keep the graph of the derivatives so that they can be
            differentiated again (needed when a loss is built from them).

    Returns:
        Tensor of shape (batch, n, m, d) in the dtype of ``g`` whose entry
        ``[b, i, j, k]`` is the derivative of ``g[:, i, j].sum()`` with respect
        to ``coords[b, k]``. This equals ∂g_ij/∂x_k at point b whenever each
        sample's metric depends only on that sample's coordinates.
    """
    n_rows, n_cols = g.shape[-2], g.shape[-1]

    # Sizes and dtype follow the inputs, so double-precision metrics are not
    # silently downcast and metrics of any dimension are accepted.
    g_derivs = torch.zeros(
        coords.shape[0], n_rows, n_cols, coords.shape[-1],
        dtype=g.dtype, device=coords.device
    )

    for i in range(n_rows):
        for j in range(n_cols):
            g_ij = g[:, i, j]

            # torch.autograd.grad returns the gradient instead of accumulating
            # it into coords.grad, so coords.grad needs no resetting here.
            grads = torch.autograd.grad(
                outputs=g_ij,
                inputs=coords,
                grad_outputs=torch.ones_like(g_ij),
                create_graph=create_graph,
                retain_graph=True,   # the graph of g is reused for every (i, j)
                allow_unused=True    # components independent of coords yield None
            )[0]

            if grads is not None:
                g_derivs[:, i, j, :] = grads

    return g_derivs

def compute_christoffel(
    g: torch.Tensor,
    g_inv: torch.Tensor,
    g_derivs: torch.Tensor
) -> torch.Tensor:
    """Compute Christoffel symbols Γ^k_ij.

    Evaluates Γ^k_ij = ½ g^{kl} (∂_i g_jl + ∂_j g_il − ∂_l g_ij) for a batch.

    Args:
        g: Metric batch (batch, n, n). Not used in the computation; kept so the
            signature stays unchanged for existing callers.
        g_inv: Inverse metric batch (batch, n, n).
        g_derivs: Metric derivatives (batch, n, n, n) laid out as
            ``g_derivs[b, a, c, k] = ∂_k g_ac``.

    Returns:
        Tensor of shape (batch, n, n, n) with ``gamma[b, k, i, j] = Γ^k_ij``.
    """
    # Each contraction over l is a single einsum rather than n⁴ Python-level
    # tensor operations; shapes and dtype follow the inputs.
    d_i_gjl = torch.einsum('bkl,bjli->bkij', g_inv, g_derivs)
    d_j_gil = torch.einsum('bkl,bilj->bkij', g_inv, g_derivs)
    d_l_gij = torch.einsum('bkl,bijl->bkij', g_inv, g_derivs)

    return 0.5 * (d_i_gjl + d_j_gil - d_l_gij)

def compute_ricci_tensor(
    gamma: torch.Tensor,
    coords: torch.Tensor,
    create_graph: bool = False
) -> torch.Tensor:
    """Compute Ricci tensor (simplified approximation).

    Only the part of the Ricci tensor that is quadratic in the Christoffel
    symbols is evaluated,

        R_ij ≈ Γ^k_{lk} Γ^l_{ij} − Γ^k_{lj} Γ^l_{ik},

    and the terms involving derivatives of Γ are omitted.

    Args:
        gamma: Christoffel symbols (batch, n, n, n), ``gamma[b, k, i, j] = Γ^k_ij``.
        coords: Sample points. Not used by this approximation; kept so the
            signature stays unchanged for existing callers.
        create_graph: Not used by this approximation; kept for the same reason.

    Returns:
        Tensor of shape (batch, n, n) in the dtype of ``gamma``.
    """
    # Σ_k Γ^k_{lk}: diagonal over the upper index and the last lower index.
    contracted = torch.diagonal(gamma, dim1=1, dim2=3).sum(dim=-1)

    trace_term = torch.einsum('bl,blij->bij', contracted, gamma)
    cross_term = torch.einsum('bklj,blik->bij', gamma, gamma)

    return trace_term - cross_term

def exterior_derivative_3form(
    phi: torch.Tensor,
    coords: torch.Tensor,
    create_graph: bool = True
) -> torch.Tensor:
    """Gradient-norm proxy for the exterior derivative dφ of a 3-form φ.

    NOTE: this does not build the antisymmetrised 4-form dφ. For every
    component φ_i it returns the Euclidean norm of the coordinate gradient of
    φ_i, which vanishes exactly when that component is locally constant.

    Args:
        phi: Form components of shape (batch, n_components), computed from
            ``coords`` so that an autograd graph links the two.
        coords: Points of shape (batch, d) with ``requires_grad=True``.
        create_graph: Keep the graph of the gradients so the result can be
            back-propagated through (required when it feeds a training loss).

    Returns:
        Tensor of shape (batch, n_components) in the dtype of ``phi``.
    """
    # The component count and dtype follow phi rather than being fixed to
    # 35 float32 values.
    n_components = phi.shape[1]
    d_phi = torch.zeros(
        coords.shape[0], n_components, dtype=phi.dtype, device=coords.device
    )

    for i in range(n_components):
        phi_i = phi[:, i]

        # torch.autograd.grad returns the gradient instead of accumulating it
        # into coords.grad, so coords.grad needs no resetting here.
        grads = torch.autograd.grad(
            outputs=phi_i,
            inputs=coords,
            grad_outputs=torch.ones_like(phi_i),
            create_graph=create_graph,
            retain_graph=True,   # the graph of phi is reused for every component
            allow_unused=True    # components independent of coords yield None
        )[0]

        if grads is not None:
            d_phi[:, i] = grads.norm(dim=1)

    return d_phi

print("Geometric utility functions defined")

## 5. Neural Network Architectures

In [None]:
class RegionMetricNetwork(nn.Module):
    """Region-specific metric network for Phase 1.

    Maps coordinates x ∈ ℝ⁷ to metric tensor g ∈ S⁺(7).

    The MLP predicts the 28 upper-triangular entries of a symmetric 7×7
    matrix. That matrix is added to 0.5·I and its eigenvalues are floored at
    0.1, so the output is always symmetric positive definite.

    Args:
        n_fourier: Number of Fourier frequencies used by the input encoding.
        hidden_dims: Widths of the MLP layers, in order.
        use_layer_norm: Insert a LayerNorm after every hidden Linear layer.
    """
    
    def __init__(
        self,
        n_fourier: int = 32,
        hidden_dims: list = [256, 256, 128],
        use_layer_norm: bool = True
    ):
        super().__init__()
        self.n_fourier = n_fourier
        # Store a copy: keeping the argument itself would alias the shared
        # default list (or the caller's list) across instances.
        hidden_dims = list(hidden_dims)
        self.hidden_dims = hidden_dims
        
        # Fourier feature encoding
        self.fourier = FourierFeatures(n_dim=7, n_fourier=n_fourier, scale=1.0)
        fourier_dim = self.fourier.output_dim
        
        # MLP backbone
        layers = []
        layers.append(nn.Linear(fourier_dim, hidden_dims[0]))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dims[0]))
        layers.append(nn.SiLU())
        
        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dims[i + 1]))
            layers.append(nn.SiLU())
        
        self.backbone = nn.Sequential(*layers)
        
        # Output layer: 28 values for upper triangular 7×7 matrix
        # Small initial weights and zero bias keep the initial output close to 0.5·I.
        self.output_layer = nn.Linear(hidden_dims[-1], 28)
        nn.init.normal_(self.output_layer.weight, std=0.01)
        nn.init.zeros_(self.output_layer.bias)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the metric batch of shape (batch, 7, 7) for points ``x`` of shape (batch, 7)."""
        batch_size = x.shape[0]
        
        # Fourier features
        features = self.fourier(x)
        hidden = self.backbone(features)
        g_upper = self.output_layer(hidden)
        
        # Reconstruct symmetric 7×7 matrix
        # The scratch tensor follows the dtype of the prediction; a fixed
        # float32 buffer makes the indexed assignment fail once the model is
        # cast to another dtype (e.g. .double()).
        g = torch.zeros(batch_size, 7, 7, dtype=g_upper.dtype, device=x.device)
        idx = torch.triu_indices(7, 7, device=x.device)
        g[:, idx[0], idx[1]] = g_upper
        
        # Symmetrize
        # g + gᵀ doubles the diagonal, so halve it again afterwards.
        g = g + g.transpose(-2, -1)
        diag_idx = torch.arange(7, device=x.device)
        g[:, diag_idx, diag_idx] = g[:, diag_idx, diag_idx] / 2.0
        
        # Add identity for positive definiteness
        eye = torch.eye(7, dtype=g.dtype, device=x.device).unsqueeze(0)
        g = g + 0.5 * eye
        
        # Ensure positive definite
        g = ensure_positive_definite(g, min_eig=0.1)
        
        return g

print("RegionMetricNetwork defined")

In [None]:
class GlobalMetricNetwork(nn.Module):
    """Phase 2 metric: a smooth blend of the three region networks.

    The first coordinate t of every point selects how much each region
    contributes: sigmoid ramps centred at t = -R and t = +R hand over from M1
    to the neck and from the neck to M2, with steepness ``tau``.
    """

    def __init__(
        self,
        region_networks: Dict[str, RegionMetricNetwork],
        tau: float = 3.8967,
        transition_radius: float = 2.0
    ):
        super().__init__()
        self.tau = tau
        self.R = transition_radius

        # Assigning the networks as attributes registers them as submodules.
        self.net_m1 = region_networks['m1']
        self.net_neck = region_networks['neck']
        self.net_m2 = region_networks['m2']

    def region_weights(self, t: torch.Tensor) -> torch.Tensor:
        """Return blending weights stacked on the last axis as (M1, Neck, M2)."""
        left = torch.sigmoid(-self.tau * (t + self.R))
        right = torch.sigmoid(self.tau * (t - self.R))
        # The neck takes what is left over, never a negative share.
        middle = (1.0 - left - right).clamp(min=0.0)

        stacked = torch.stack([left, middle, right], dim=-1)
        # Renormalise so the three weights sum to one (epsilon guards the division).
        return stacked / (stacked.sum(dim=-1, keepdim=True) + 1e-8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the blended metric at points ``x``; t is read from column 0."""
        metric_m1 = self.net_m1(x)
        metric_neck = self.net_neck(x)
        metric_m2 = self.net_m2(x)

        w = self.region_weights(x[:, 0])
        w_m1 = w[:, 0, None, None]
        w_neck = w[:, 1, None, None]
        w_m2 = w[:, 2, None, None]

        blended = w_m1 * metric_m1 + w_neck * metric_neck + w_m2 * metric_m2

        # Floor the eigenvalues of the blend once more before returning it.
        return ensure_positive_definite(blended, min_eig=0.1)

print("GlobalMetricNetwork defined")

In [None]:
class PhiNetwork(nn.Module):
    """MLP predicting the components of the G₂ 3-form φ (associative form).

    A 3-form on ℝ⁷ has C(7,3) = 35 independent components, so the network
    outputs 35 values per input point.
    """

    def __init__(
        self,
        n_fourier: int = 32,
        hidden_dims: list = [384, 384, 256]
    ):
        super().__init__()

        self.fourier = FourierFeatures(n_dim=7, n_fourier=n_fourier, scale=1.0)

        # Linear -> LayerNorm -> SiLU for every consecutive pair of widths,
        # starting from the Fourier feature dimension.
        widths = [self.fourier.output_dim, *hidden_dims]
        stages = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Linear(width_in, width_out),
                nn.LayerNorm(width_out),
                nn.SiLU(),
            ])
        self.backbone = nn.Sequential(*stages)

        # Small initial weights and zero bias for the output head.
        self.output_layer = nn.Linear(hidden_dims[-1], 35)
        nn.init.normal_(self.output_layer.weight, std=0.01)
        nn.init.zeros_(self.output_layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return φ with 35 components per point for input ``x``."""
        return self.output_layer(self.backbone(self.fourier(x)))

print("PhiNetwork defined")

In [None]:
class HarmonicFormsNetwork(nn.Module):
    """Network producing a candidate basis of harmonic 2-forms (H²).

    A shared two-layer trunk feeds ``n_forms`` independent linear heads; each
    head emits the 21 components of one 2-form. Orthonormality of the forms is
    not built in here.
    """

    def __init__(
        self,
        n_forms: int = 21,
        n_fourier: int = 24,
        hidden_dim: int = 128
    ):
        super().__init__()
        self.n_forms = n_forms

        self.fourier = FourierFeatures(n_dim=7, n_fourier=n_fourier, scale=1.0)

        trunk_layers = [
            nn.Linear(self.fourier.output_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        ]
        self.backbone = nn.Sequential(*trunk_layers)

        # One 21-component head per form
        heads = [nn.Linear(hidden_dim, 21) for _ in range(n_forms)]
        self.form_heads = nn.ModuleList(heads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the forms stacked as (batch, n_forms, 21)."""
        trunk_out = self.backbone(self.fourier(x))
        per_form = [head(trunk_out) for head in self.form_heads]
        return torch.stack(per_form, dim=1)

print("HarmonicFormsNetwork defined")

## 6. Loss Functions

In [None]:
class SnapshotAnchoringLoss(nn.Module):
    """Loss 1: pin the network to the known metric at three anchor points.

    For each region (M1, Neck, M2) the network is evaluated at the stored
    anchor coordinates and the squared Frobenius distance to the stored
    metric is taken; the result is the average over the three regions.
    """

    _REGIONS = ('m1', 'neck', 'm2')

    def __init__(self, snapshots: Dict[str, Dict[str, torch.Tensor]]):
        super().__init__()
        # Buffers follow the module across devices and are saved in its state.
        for name in self._REGIONS:
            anchor = snapshots[name]
            # A leading batch axis lets the point be passed straight to the network.
            self.register_buffer(f'coords_{name}', anchor['coords'].unsqueeze(0))
            self.register_buffer(f'metric_{name}', anchor['metric'])

    def forward(self, network: nn.Module) -> torch.Tensor:
        """Return the mean squared Frobenius error over the three anchors."""
        squared_errors = []
        for name in self._REGIONS:
            predicted = network(getattr(self, f'coords_{name}')).squeeze(0)
            target = getattr(self, f'metric_{name}')
            squared_errors.append(torch.norm(predicted - target, p='fro') ** 2)
        return sum(squared_errors) / 3.0


class VolumeNormalizationLoss(nn.Module):
    """Loss 3: penalise deviation of the volume density √|det(g)| from 1."""

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        """Return the batch mean of (√|det g| − 1)²."""
        volume_density = compute_metric_determinant(g).abs().sqrt()
        return ((volume_density - 1.0) ** 2).mean()


class TCSAsymptoticLoss(nn.Module):
    """Loss 4: TCS asymptotics — the (t, θ) block of g should approach the identity.

    Points are drawn from the two outer strips t ∈ [t_min, t_min + 2] and
    t ∈ [t_max − 2, t_max]; on both, the top-left 2×2 block of the predicted
    metric is compared with the 2×2 identity.
    """

    def __init__(self, tau: float = 3.8967, t_min: float = -8.0, t_max: float = 8.0, n_samples: int = 100):
        super().__init__()
        self.tau = tau
        self.t_min = t_min
        self.t_max = t_max
        self.n_samples = n_samples
        self.register_buffer('identity_2', torch.eye(2))

    def sample_far_regions(self, device: torch.device) -> tuple:
        """Return ``(coords_m1, coords_m2)``, each of shape (n_samples, 7)."""
        n = self.n_samples

        # t runs over an evenly spaced grid in each outer strip
        t_left = torch.linspace(self.t_min, self.t_min + 2, n, device=device)
        t_right = torch.linspace(self.t_max - 2, self.t_max, n, device=device)

        # θ is uniform on [0, 2π); the five base coordinates are N(0, 0.5²)
        theta_left = torch.rand(n, device=device) * 2 * np.pi
        theta_right = torch.rand(n, device=device) * 2 * np.pi
        base_left = torch.randn(n, 5, device=device) * 0.5
        base_right = torch.randn(n, 5, device=device) * 0.5

        def assemble(t, theta, base):
            return torch.cat([t.unsqueeze(1), theta.unsqueeze(1), base], dim=1)

        return (
            assemble(t_left, theta_left, base_left),
            assemble(t_right, theta_right, base_right),
        )

    def _fiber_deviation(self, g: torch.Tensor) -> torch.Tensor:
        """Mean squared Frobenius distance of g[:, :2, :2] from the 2×2 identity."""
        return torch.mean(torch.norm(g[:, :2, :2] - self.identity_2, dim=(1, 2)) ** 2)

    def forward(self, network: nn.Module) -> torch.Tensor:
        """Return the summed deviation over both outer strips."""
        # Sample on whichever device the network's parameters live on
        target_device = next(network.parameters()).device
        coords_m1, coords_m2 = self.sample_far_regions(target_device)

        g_m1 = network(coords_m1)
        g_m2 = network(coords_m2)

        return self._fiber_deviation(g_m1) + self._fiber_deviation(g_m2)


class RicciFlatnessLoss(nn.Module):
    """Loss 5: penalise the (approximate) Ricci tensor so that Ric(g) ≈ 0."""

    def __init__(self, use_approximate: bool = True):
        super().__init__()
        self.use_approximate = use_approximate

    def forward(self, network: nn.Module, coords: torch.Tensor) -> torch.Tensor:
        """Return the batch mean of the squared Frobenius norm of the Ricci tensor."""
        # Work on a detached copy that tracks gradients, so the metric can be
        # differentiated with respect to position.
        points = coords.detach()
        points.requires_grad_(True)

        metric = network(points)
        metric_derivs = compute_metric_derivatives(metric, points, create_graph=True)
        christoffel = compute_christoffel(metric, torch.inverse(metric), metric_derivs)
        ricci = compute_ricci_tensor(christoffel, points, create_graph=False)

        return torch.mean(torch.norm(ricci, dim=(1, 2)) ** 2)


class G2TorsionLoss(nn.Module):
    """Loss 6: Enforce G₂ torsion-free condition.

    Penalises the coordinate gradients of the 3-form φ predicted by
    ``phi_network`` (mean over the batch of the squared norm of the per-component
    gradient magnitudes). The loss is identically zero when no φ network is given.

    Args:
        phi_network: Network mapping coordinates to the components of φ, or None.
        use_simplified: Stored on the instance; not consulted by ``forward``.
    """
    
    def __init__(self, phi_network: Optional[nn.Module] = None, use_simplified: bool = True):
        super().__init__()
        self.phi_network = phi_network
        self.use_simplified = use_simplified
    
    def forward(self, metric_network: nn.Module, coords: torch.Tensor) -> torch.Tensor:
        """Return the torsion penalty at ``coords``.

        ``metric_network`` is not used; it is accepted so that every loss term
        shares the same call signature.
        """
        if self.phi_network is None:
            return torch.tensor(0.0, device=coords.device)
        
        # Keep the graph of the gradients whenever the caller is training.
        # With create_graph=False the result is detached from the φ network's
        # parameters and this term contributes no gradient to the optimisation.
        track_graph = torch.is_grad_enabled()
        
        coords = coords.detach()
        coords.requires_grad_(True)
        
        # Differentiating φ with respect to the coordinates needs autograd even
        # when the caller evaluates the loss under torch.no_grad().
        with torch.enable_grad():
            phi = self.phi_network(coords)
            d_phi = exterior_derivative_3form(phi, coords, create_graph=track_graph)
        
        torsion = torch.mean(torch.norm(d_phi, dim=1) ** 2)
        
        return torsion


class H2OrthonormalityLoss(nn.Module):
    """Loss 2: Enforce orthonormality of H² forms.

    Builds the Gram matrix of the predicted forms, averaged over at most the
    first 32 sample points, and penalises its squared Frobenius distance from
    the 21×21 identity. The inner product is the plain Euclidean one on the
    form components; the metric does not enter.

    Args:
        h2_network: Network mapping coordinates to forms of shape
            (batch, 21, n_components), or None to disable the term.
    """
    
    def __init__(self, h2_network: Optional[nn.Module] = None):
        super().__init__()
        self.h2_network = h2_network
        self.register_buffer('identity_21', torch.eye(21))
    
    def forward(self, metric_network: nn.Module, coords: torch.Tensor) -> torch.Tensor:
        """Return the orthonormality penalty at ``coords``.

        ``metric_network`` is accepted for a uniform loss signature but is not
        evaluated: its output never entered the Gram matrix, so running it was
        a wasted forward pass.
        """
        if self.h2_network is None:
            return torch.tensor(0.0, device=coords.device)
        
        h2_forms = self.h2_network(coords)
        
        # Only a prefix of the batch is used for the Gram matrix.
        n_used = min(coords.shape[0], 32)
        sample = h2_forms[:n_used]
        
        # gram[i, j] = mean over points of <form_i, form_j>, in one contraction.
        gram = torch.einsum('bic,bjc->ij', sample, sample) / n_used
        loss = torch.norm(gram - self.identity_21, p='fro') ** 2
        
        return loss


class CompositeLoss(nn.Module):
    """Weighted sum of the six loss terms, with phase-dependent default weights.

    Terms with a zero weight for 'h2', 'ricci' or 'torsion' are skipped and
    reported as 0. The Ricci term additionally requires ``phase >= 3`` and is
    evaluated on at most the first 64 points of the batch.
    """

    def __init__(
        self,
        snapshots: Dict[str, Dict[str, torch.Tensor]],
        phi_network: Optional[nn.Module] = None,
        h2_network: Optional[nn.Module] = None,
        weights: Optional[Dict[str, float]] = None,
        phase: int = 1
    ):
        super().__init__()
        self.phase = phase

        self.snapshot_loss = SnapshotAnchoringLoss(snapshots)
        self.volume_loss = VolumeNormalizationLoss()
        self.asymptotic_loss = TCSAsymptoticLoss()
        self.ricci_loss = RicciFlatnessLoss(use_approximate=True)
        self.torsion_loss = G2TorsionLoss(phi_network, use_simplified=True)
        self.h2_loss = H2OrthonormalityLoss(h2_network)

        self.weights = self._default_weights(phase) if weights is None else weights

    @staticmethod
    def _default_weights(phase: int) -> Dict[str, float]:
        """Return a fresh default weight table; any phase other than 1 or 2 gets the phase-3 table."""
        if phase == 1:
            return {
                'snapshot': 10.0, 'volume': 2.0, 'asymptotic': 3.0,
                'ricci': 0.0, 'torsion': 2.0, 'h2': 0.0
            }
        if phase == 2:
            return {
                'snapshot': 10.0, 'volume': 2.0, 'asymptotic': 3.0,
                'ricci': 0.0, 'torsion': 2.0, 'h2': 5.0
            }
        return {
            'snapshot': 10.0, 'volume': 1.0, 'asymptotic': 2.0,
            'ricci': 1.0, 'torsion': 3.0, 'h2': 5.0
        }

    def forward(self, metric_network: nn.Module, coords: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Evaluate every term and return them by name, plus the weighted 'total'."""

        def skipped() -> torch.Tensor:
            return torch.tensor(0.0, device=coords.device)

        losses = {'snapshot': self.snapshot_loss(metric_network)}
        g = metric_network(coords)

        losses['h2'] = self.h2_loss(metric_network, coords) if self.weights['h2'] > 0 else skipped()

        losses['volume'] = self.volume_loss(g)
        losses['asymptotic'] = self.asymptotic_loss(metric_network)

        if self.weights['ricci'] > 0 and self.phase >= 3:
            # The curvature term is costly, so only a prefix of the batch is used.
            losses['ricci'] = self.ricci_loss(metric_network, coords[:64])
        else:
            losses['ricci'] = skipped()

        losses['torsion'] = (
            self.torsion_loss(metric_network, coords) if self.weights['torsion'] > 0 else skipped()
        )

        losses['total'] = sum(weight * losses[name] for name, weight in self.weights.items())

        return losses

print("Loss functions defined")

## 7. Curriculum Trainer with Checkpoint Management

In [None]:
def sample_training_coordinates(
    batch_size: int,
    phase: int,
    device: torch.device
) -> torch.Tensor:
    """Sample coordinates for training based on phase.

    Only the t coordinate depends on the phase:

    * phase 1 — three Gaussian clusters (σ = 0.5) around t = -5, 0 and +5;
    * phase 2 — uniform on [-6, 6);
    * any other phase — uniform on [-4, 0) for the first half of the batch
      and on [0, 4) for the rest.

    θ is uniform on [0, 2π) and the five base coordinates are standard normal.

    Args:
        batch_size: Number of points to draw.
        phase: Curriculum phase selecting the t distribution.
        device: Device on which the samples are created.

    Returns:
        Tensor of shape (batch_size, 7) with columns (t, θ, x₁, …, x₅).
    """
    
    if phase == 1:
        # Dense sampling near snapshots
        # The last cluster absorbs the remainder so that exactly batch_size
        # points are produced even when batch_size is not a multiple of 3
        # (otherwise the t column is shorter than the other columns).
        n_per_region = batch_size // 3
        n_last_region = batch_size - 2 * n_per_region
        t_m1 = torch.randn(n_per_region, device=device) * 0.5 - 5.0
        t_neck = torch.randn(n_per_region, device=device) * 0.5
        t_m2 = torch.randn(n_last_region, device=device) * 0.5 + 5.0
        t_samples = torch.cat([t_m1, t_neck, t_m2])
    
    elif phase == 2:
        # Uniform sampling across manifold
        t_samples = torch.rand(batch_size, device=device) * 12 - 6
    
    else:
        # Focus on transition regions
        # The second half absorbs the extra point of an odd batch size.
        n_first = batch_size // 2
        n_second = batch_size - n_first
        t_trans1 = torch.rand(n_first, device=device) * 4 - 4
        t_trans2 = torch.rand(n_second, device=device) * 4
        t_samples = torch.cat([t_trans1, t_trans2])
    
    theta = torch.rand(batch_size, device=device) * 2 * np.pi
    base_coords = torch.randn(batch_size, 5, device=device)
    coords = torch.cat([t_samples.unsqueeze(1), theta.unsqueeze(1), base_coords], dim=1)
    
    return coords

print("Coordinate sampling function defined")

In [None]:
class CurriculumTrainer:
    """Trainer with 3-phase curriculum learning and checkpoint management.

    Phase 1 optimises the three region networks (m1, neck, m2) together with
    the phi network.  Later phases optimise the global network, the phi
    network and the H² network.  The phase of an epoch is looked up in
    ``config['phase_boundaries']``, a sequence of ``(start, end)`` epoch
    pairs where ``start`` is inclusive and ``end`` exclusive.

    Checkpoints contain the network weights, the config and the loss
    history.  Optimizer and scheduler state are not stored, so a resumed run
    restarts the learning-rate schedule of the phase it resumes in.
    """

    def __init__(
        self,
        config: Dict,
        snapshots: Dict[str, Dict[str, torch.Tensor]],
        device: torch.device,
        checkpoint_dir: str,
        output_dir: str
    ):
        """Store the configuration and build all networks.

        Args:
            config: Training configuration; the keys read here are
                ``epochs``, ``batch_size``, ``lr``, ``weight_decay``,
                ``grad_clip`` and ``phase_boundaries`` (network sizes are
                read in ``_initialize_networks``).
            snapshots: Snapshot data handed unchanged to ``CompositeLoss``.
            device: Device on which networks and samples live.
            checkpoint_dir: Directory for checkpoint and history files.
            output_dir: Directory for exported results.
        """
        self.config = config
        self.snapshots = snapshots
        self.device = device
        self.checkpoint_dir = Path(checkpoint_dir)
        self.output_dir = Path(output_dir)

        self.epochs_total = config['epochs']
        self.batch_size = config['batch_size']
        self.lr = config['lr']
        self.weight_decay = config['weight_decay']
        self.grad_clip = config['grad_clip']
        self.phase_boundaries = config['phase_boundaries']

        self.networks = {}
        self.current_phase = 1
        self.start_epoch = 0

        # One list per logged quantity; every epoch appends one entry to each.
        self.history = {
            'epoch': [], 'phase': [], 'total_loss': [],
            'snapshot_loss': [], 'volume_loss': [], 'asymptotic_loss': [],
            'ricci_loss': [], 'torsion_loss': [], 'h2_loss': []
        }

        self._initialize_networks()

    def _initialize_networks(self):
        """Initialize networks for all phases.

        The global network is left as ``None`` here; it is created lazily by
        ``_ensure_global_network`` because it wraps the region networks.
        """
        print("Initializing networks...")

        self.networks['m1'] = RegionMetricNetwork(
            n_fourier=self.config['n_fourier'],
            hidden_dims=self.config['hidden_dims']
        ).to(self.device)

        self.networks['neck'] = RegionMetricNetwork(
            n_fourier=self.config['n_fourier'],
            hidden_dims=self.config['hidden_dims']
        ).to(self.device)

        self.networks['m2'] = RegionMetricNetwork(
            n_fourier=self.config['n_fourier'],
            hidden_dims=self.config['hidden_dims']
        ).to(self.device)

        self.networks['global'] = None

        self.networks['phi'] = PhiNetwork(
            n_fourier=self.config['n_fourier'],
            hidden_dims=self.config['phi_hidden_dims']
        ).to(self.device)

        self.networks['h2'] = HarmonicFormsNetwork(
            n_forms=self.config['h2_n_forms'],
            n_fourier=self.config['h2_n_fourier'],
            hidden_dim=self.config['h2_hidden_dim']
        ).to(self.device)

        print(f"  Region networks: {sum(p.numel() for p in self.networks['m1'].parameters()):,} params each")
        print(f"  Phi network: {sum(p.numel() for p in self.networks['phi'].parameters()):,} params")
        print(f"  H² network: {sum(p.numel() for p in self.networks['h2'].parameters()):,} params")

    def _ensure_global_network(self):
        """Create the global network on first use and return it."""
        if self.networks['global'] is None:
            self.networks['global'] = GlobalMetricNetwork(
                region_networks={
                    'm1': self.networks['m1'],
                    'neck': self.networks['neck'],
                    'm2': self.networks['m2']
                },
                tau=self.config['gift_params']['tau']
            ).to(self.device)
        return self.networks['global']

    def _phase_for_epoch(self, epoch: int) -> int:
        """Return the 1-based phase whose [start, end) range contains ``epoch``.

        Falls back to phase 1 when no range matches.
        """
        for p, (start, end) in enumerate(self.phase_boundaries, 1):
            if start <= epoch < end:
                return p
        return 1

    def _initialize_phase(self, phase: int):
        """Initialize optimizers and loss for given phase.

        Creates a fresh AdamW optimizer, a cosine-annealing scheduler spanning
        the phase's epoch range, and a ``CompositeLoss`` for the phase.
        """
        print(f"\n{'='*60}")
        print(f"Starting Phase {phase}")
        print(f"{'='*60}")

        self.current_phase = phase

        if phase == 1:
            params = (
                list(self.networks['m1'].parameters()) +
                list(self.networks['neck'].parameters()) +
                list(self.networks['m2'].parameters()) +
                list(self.networks['phi'].parameters())
            )
            lr = self.lr
            self.active_network = None

        else:
            # Build the global network for any later phase, not only phase 2:
            # a run resumed directly into phase 3 must not hit a None network.
            global_network = self._ensure_global_network()
            params = (
                list(global_network.parameters()) +
                list(self.networks['phi'].parameters()) +
                list(self.networks['h2'].parameters())
            )
            # Phase 2 trains at half the base rate, later phases at a tenth.
            lr = self.lr * (0.5 if phase == 2 else 0.1)
            self.active_network = global_network

        self.optimizer = optim.AdamW(
            params, lr=lr, weight_decay=self.weight_decay, betas=(0.9, 0.999)
        )

        phase_start, phase_end = self.phase_boundaries[phase - 1]
        phase_epochs = phase_end - phase_start

        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=phase_epochs, eta_min=lr * 0.1
        )

        self.criterion = CompositeLoss(
            snapshots=self.snapshots,
            phi_network=self.networks['phi'],
            h2_network=self.networks['h2'] if phase >= 2 else None,
            phase=phase
        ).to(self.device)

        print(f"Learning rate: {lr:.2e}")
        print(f"Phase duration: {phase_epochs} epochs")

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train for one epoch (a single sampled batch and optimizer step).

        Switches phase first if ``epoch`` belongs to a different phase than
        the one currently initialised.

        Returns:
            The loss terms as Python floats, plus ``'lr'`` and ``'phase'``.
        """
        phase = self._phase_for_epoch(epoch)

        if phase != self.current_phase:
            self._initialize_phase(phase)

        coords = sample_training_coordinates(self.batch_size, phase, self.device)
        self.optimizer.zero_grad()

        if phase == 1:
            # Evaluate every region network on the same batch and average
            # each loss term over the three regions.
            losses_total = None
            for region, network in [
                ('m1', self.networks['m1']),
                ('neck', self.networks['neck']),
                ('m2', self.networks['m2'])
            ]:
                losses = self.criterion(network, coords)
                if losses_total is None:
                    losses_total = {k: v / 3.0 for k, v in losses.items()}
                else:
                    for k in losses.keys():
                        losses_total[k] += losses[k] / 3.0
            loss = losses_total['total']
        else:
            losses = self.criterion(self.active_network, coords)
            losses_total = losses
            loss = losses['total']

        loss.backward()

        if self.grad_clip > 0:
            # The optimizer is built with a single parameter group.
            torch.nn.utils.clip_grad_norm_(
                self.optimizer.param_groups[0]['params'], self.grad_clip
            )

        self.optimizer.step()
        self.scheduler.step()

        loss_dict = {k: v.item() for k, v in losses_total.items()}
        loss_dict['lr'] = self.optimizer.param_groups[0]['lr']
        loss_dict['phase'] = phase

        return loss_dict

    def train(self):
        """Run full training loop from ``start_epoch`` to ``epochs_total``.

        Logs every ``config['log_frequency']`` epochs, checkpoints every
        ``config['checkpoint_frequency']`` epochs and writes a final
        checkpoint at the end.
        """
        print(f"\n{'='*60}")
        print("Starting Curriculum Training")
        print(f"{'='*60}")
        print(f"Total epochs: {self.epochs_total}")
        print(f"Starting from epoch: {self.start_epoch}")
        print(f"Batch size: {self.batch_size}")
        print(f"Device: {self.device}")

        # Initialise the phase the first epoch actually belongs to, so a
        # resumed run does not set up (and then discard) phase 1.
        self._initialize_phase(self._phase_for_epoch(self.start_epoch))

        for epoch in tqdm(range(self.start_epoch, self.epochs_total), desc="Training"):
            loss_dict = self.train_epoch(epoch)

            self.history['epoch'].append(epoch)
            self.history['phase'].append(loss_dict['phase'])
            self.history['total_loss'].append(loss_dict['total'])
            self.history['snapshot_loss'].append(loss_dict.get('snapshot', 0))
            self.history['volume_loss'].append(loss_dict.get('volume', 0))
            self.history['asymptotic_loss'].append(loss_dict.get('asymptotic', 0))
            self.history['ricci_loss'].append(loss_dict.get('ricci', 0))
            self.history['torsion_loss'].append(loss_dict.get('torsion', 0))
            self.history['h2_loss'].append(loss_dict.get('h2', 0))

            if epoch % self.config['log_frequency'] == 0 or epoch == self.epochs_total - 1:
                print(f"\nEpoch {epoch:5d} | Phase {loss_dict['phase']} | "
                      f"Loss: {loss_dict['total']:.6f} | "
                      f"Snap: {loss_dict.get('snapshot', 0):.4f} | "
                      f"Vol: {loss_dict.get('volume', 0):.4f} | "
                      f"Tor: {loss_dict.get('torsion', 0):.2e} | "
                      f"LR: {loss_dict['lr']:.2e}")

            if (epoch + 1) % self.config['checkpoint_frequency'] == 0:
                self.save_checkpoint(epoch)

        print("\nTraining complete!")
        self.save_checkpoint(self.epochs_total - 1, final=True)

    def save_checkpoint(self, epoch: int, final: bool = False):
        """Save network weights, config and history to the checkpoint directory.

        Writes ``checkpoint_<suffix>.pt`` and ``history_<suffix>.json`` where
        the suffix is ``final`` or ``epoch_<epoch>``.

        Args:
            epoch: Epoch number stored in the checkpoint; resuming continues
                at ``epoch + 1``.
            final: Use the ``final`` suffix instead of the epoch number.
        """
        suffix = 'final' if final else f'epoch_{epoch}'

        checkpoint = {
            'epoch': epoch,
            'phase': self.current_phase,
            'config': self.config,
            'history': self.history
        }

        # Networks that have not been created yet (None) are left out.
        for name, network in self.networks.items():
            if network is not None:
                checkpoint[f'{name}_state'] = network.state_dict()

        checkpoint_path = self.checkpoint_dir / f'checkpoint_{suffix}.pt'
        torch.save(checkpoint, checkpoint_path)

        history_path = self.checkpoint_dir / f'history_{suffix}.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)

        print(f"  Checkpoint saved: {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: Path):
        """Load checkpoint from file and prepare to resume after its epoch.

        Restores network weights, the history, ``start_epoch`` (checkpoint
        epoch + 1) and ``current_phase``.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # The global network only exists after phase 2 has started.  Build it
        # before restoring weights, otherwise its saved state would be
        # silently skipped and phase 2/3 progress lost on resume.
        if 'global_state' in checkpoint:
            self._ensure_global_network()

        for name, network in self.networks.items():
            if network is not None and f'{name}_state' in checkpoint:
                network.load_state_dict(checkpoint[f'{name}_state'])

        self.history = checkpoint['history']
        self.start_epoch = checkpoint['epoch'] + 1
        self.current_phase = checkpoint['phase']

        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        print(f"Resuming from epoch {self.start_epoch}")

    def find_latest_checkpoint(self):
        """Return the per-epoch checkpoint with the highest epoch, or None.

        Only ``checkpoint_epoch_<N>.pt`` files are considered; the ``final``
        checkpoint is ignored.
        """
        checkpoints = list(self.checkpoint_dir.glob('checkpoint_epoch_*.pt'))
        if not checkpoints:
            return None

        # Compare numerically: as strings, 'epoch_9' would sort above 'epoch_100'.
        latest = max(checkpoints, key=lambda p: int(p.stem.split('_')[-1]))
        return latest

print("CurriculumTrainer defined")

## 8. H³ Harmonic Forms Extraction

In [None]:
def sample_uniform_grid(
    n_points: int = 8192,
    t_range: Tuple[float, float] = (-6.0, 6.0),
    device: torch.device = device
) -> torch.Tensor:
    """Build ``n_points`` sample coordinates with evenly spaced ``t`` values.

    The ``t`` column is an evenly spaced grid over ``t_range`` visited in
    random order; the angle column is uniform on [0, 2*pi) and the remaining
    five columns are standard-normal draws.

    Returns:
        Tensor of shape ``(n_points, 7)`` laid out as ``[t, theta, 5 base]``.
    """
    t_low, t_high = t_range[0], t_range[1]
    t_grid = torch.linspace(t_low, t_high, n_points, device=device)
    # Shuffle so the grid order is not correlated with the random columns.
    shuffle = torch.randperm(n_points)
    t_column = t_grid[shuffle]

    theta_column = torch.rand(n_points, device=device) * 2 * np.pi
    base_columns = torch.randn(n_points, 5, device=device)

    return torch.cat(
        [t_column.unsqueeze(1), theta_column.unsqueeze(1), base_columns],
        dim=1,
    )


def compute_laplacian_3forms_approximate(
    metric_network: nn.Module,
    coords: torch.Tensor,
    eps: float = 1e-6
) -> torch.Tensor:
    """Build a diagonal per-point stand-in for the Laplacian on 3-forms.

    For every sample point the result is a 35x35 matrix whose diagonal
    entries all equal ``trace(g^-1) / vol + eps`` and whose off-diagonal
    entries are zero, where ``g`` is the metric predicted by
    ``metric_network`` and ``vol`` comes from ``compute_volume_form``.

    Args:
        metric_network: Module mapping ``coords`` to a batch of metrics.
        coords: Sample points; the first dimension is the batch.
        eps: Constant added to every diagonal entry.

    Returns:
        Tensor of shape ``(n_points, 35, 35)`` on the device of ``coords``.
    """
    # The network is only evaluated here, so no autograd graph is needed.
    with torch.no_grad():
        metric = metric_network(coords)

    inverse_metric = torch.inverse(metric)
    volume = compute_volume_form(metric)
    inverse_trace = inverse_metric.diagonal(dim1=1, dim2=2).sum(dim=1)
    diagonal_value = inverse_trace / volume + eps

    n_components = 35
    operator = torch.zeros(
        coords.shape[0], n_components, n_components, device=coords.device
    )
    # Write the same per-point value to all diagonal slots in one assignment.
    diag_index = torch.arange(n_components, device=coords.device)
    operator[:, diag_index, diag_index] = diagonal_value.unsqueeze(-1)

    return operator


def extract_harmonic_forms_local(
    metric_network: nn.Module,
    coords: torch.Tensor,
    n_forms: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract harmonic 3-forms using local eigenvalue decomposition.

    The per-point approximate Laplacians are averaged into one symmetric
    matrix, whose lowest eigenpairs are returned.  That matrix is only as
    large as the local Laplacian (35x35 as built by
    ``compute_laplacian_3forms_approximate``), so at most that many forms can
    be returned even if ``n_forms`` asks for more.

    Args:
        metric_network: Module mapping ``coords`` to a batch of metrics.
        coords: Sample points; the first dimension is the batch.
        n_forms: Number of forms requested; must be positive.

    Returns:
        ``(eigvals, forms)`` where ``eigvals`` has ``k`` entries and
        ``forms`` has shape ``(n_points, k, dim)``: the same global
        eigenvectors repeated for every sample point.  ``k`` is
        ``min(n_forms, dim)``.

    Raises:
        ValueError: If ``n_forms`` is not positive.
    """
    if n_forms < 1:
        raise ValueError(f"n_forms must be positive, got {n_forms}")

    print(f"Extracting {n_forms} harmonic 3-forms...")
    print(f"  Sample points: {coords.shape[0]}")

    n_points = coords.shape[0]
    device = coords.device

    print("  Computing local Laplacians...")
    laplacian_local = compute_laplacian_3forms_approximate(metric_network, coords)

    print("  Averaging to global Laplacian...")
    laplacian_global = laplacian_local.mean(dim=0)
    # Symmetrise so the Hermitian eigensolvers below are applicable.
    laplacian_global = 0.5 * (laplacian_global + laplacian_global.T)

    laplacian_np = laplacian_global.cpu().numpy()
    dim = laplacian_np.shape[0]

    if n_forms > dim:
        print(f"  Warning: requested {n_forms} forms but the Laplacian is "
              f"{dim}x{dim}; returning {dim}")

    print("  Computing eigendecomposition...")
    if n_forms < dim:
        # eigsh requires k < dim, which this branch guarantees.
        try:
            eigvals, eigvecs = eigsh(laplacian_np, k=n_forms, which='SM')
        except Exception as e:
            print(f"  Warning: eigsh failed ({e}), using full decomposition")
            eigvals, eigvecs = np.linalg.eigh(laplacian_np)
            eigvals = eigvals[:n_forms]
            eigvecs = eigvecs[:, :n_forms]
    else:
        # eigh returns eigenvalues in ascending order; keep the lowest ones.
        eigvals, eigvecs = np.linalg.eigh(laplacian_np)
        eigvals = eigvals[:n_forms]
        eigvecs = eigvecs[:, :n_forms]

    eigvals = torch.from_numpy(eigvals).float().to(device)
    eigvecs_global = torch.from_numpy(eigvecs).float().to(device)

    # (dim, k) -> (n_points, dim, k) -> (n_points, k, dim)
    eigvecs_expanded = eigvecs_global.unsqueeze(0).expand(n_points, -1, -1)
    eigvecs_expanded = eigvecs_expanded.transpose(1, 2)

    # Report the number actually found, which may be below the request.
    print(f"  Found {eigvals.numel()} forms with eigenvalues in [{eigvals.min():.2e}, {eigvals.max():.2e}]")
    print(f"  Number below 1e-6: {(eigvals < 1e-6).sum().item()}")

    return eigvals, eigvecs_expanded


def extract_h3_forms_complete(
    metric_network: nn.Module,
    n_forms: int = 77,
    n_sample_points: int = 8192,
    device: torch.device = device
) -> Dict:
    """Complete H³ extraction pipeline.

    Moves the network to ``device`` in eval mode, samples the manifold with
    ``sample_uniform_grid`` and runs ``extract_harmonic_forms_local``.

    Args:
        metric_network: Trained metric network to analyse.
        n_forms: Number of forms requested from the eigen-solver.
        n_sample_points: Number of sample points drawn.
        device: Device used for sampling and evaluation.

    Returns:
        Dict with NumPy arrays ``eigenvalues``, ``forms`` and ``coords``,
        plus ``n_forms`` (number of forms actually returned, which can be
        lower than requested) and ``n_harmonic`` (eigenvalues below 1e-6).
    """
    print("\n" + "="*60)
    print("H³ Harmonic Forms Extraction")
    print("="*60)

    metric_network = metric_network.to(device)
    metric_network.eval()

    print("\nSampling manifold...")
    coords = sample_uniform_grid(n_sample_points, device=device)

    eigvals, forms = extract_harmonic_forms_local(metric_network, coords, n_forms=n_forms)

    # Count what was actually produced instead of echoing the request.
    n_found = int(eigvals.numel())
    n_harmonic = (eigvals < 1e-6).sum().item()

    print("\nFinal statistics:")
    if n_found == n_forms:
        print(f"  Number of forms: {n_found}")
    else:
        print(f"  Number of forms: {n_found} (requested {n_forms})")
    print(f"  Eigenvalue range: [{eigvals.min():.2e}, {eigvals.max():.2e}]")
    print(f"  Harmonic forms (λ < 1e-6): {n_harmonic}")
    print(f"  Forms shape: {forms.shape}")

    return {
        'eigenvalues': eigvals.cpu().numpy(),
        'forms': forms.cpu().numpy(),
        'coords': coords.cpu().numpy(),
        'n_forms': n_found,
        'n_harmonic': n_harmonic
    }

print("H³ extraction functions defined")

## 9. Training Execution

In [None]:
# Build the trainer from the notebook-level configuration and paths.
trainer = CurriculumTrainer(
    config=CONFIG,
    snapshots=snapshots,
    device=device,
    checkpoint_dir=CHECKPOINT_DIR,
    output_dir=OUTPUT_DIR
)

# With auto-resume enabled, look for the newest per-epoch checkpoint and ask
# the user whether to continue from it.
if CONFIG['auto_resume']:
    latest_checkpoint = trainer.find_latest_checkpoint()
    if latest_checkpoint is None:
        print("No checkpoint found, starting fresh training")
    else:
        print(f"\nFound checkpoint: {latest_checkpoint}")
        response = input("Resume from this checkpoint? (y/n): ")
        # Anything other than 'y'/'Y' means: keep the freshly built networks.
        if response.lower() != 'y':
            print("Starting fresh training")
        else:
            trainer.load_checkpoint(latest_checkpoint)

print("\nTrainer initialized and ready")

In [None]:
# Run training
def _save_recovery_checkpoint(active_trainer):
    """Checkpoint the trainer at the last epoch recorded in its history.

    The checkpoint's epoch number decides where a resumed run continues, so
    it must be the epoch actually reached rather than a value derived from
    the phase number.  Nothing is saved if no epoch has been recorded.
    """
    recorded_epochs = active_trainer.history['epoch']
    if not recorded_epochs:
        print("No completed epochs to checkpoint")
        return
    active_trainer.save_checkpoint(recorded_epochs[-1], final=False)


try:
    trainer.train()
except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")
    _save_recovery_checkpoint(trainer)
except Exception as e:
    # Best effort: report the failure and keep whatever progress exists so
    # that the following cells can still run on the partial result.
    print(f"\n\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    _save_recovery_checkpoint(trainer)

## 10. Export Results in Multiple Formats

In [None]:
# Write the trained weights, history, config and a metric sample to OUTPUT_DIR.
print("\n" + "="*60)
print("Exporting Results")
print("="*60)

# All files of one export share this timestamp in their names.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
export_root = Path(OUTPUT_DIR)

# 1. One state-dict file per network that has been created
print("\n1. Exporting PyTorch models (.pt)...")
for name, network in trainer.networks.items():
    if network is None:
        continue
    pt_path = export_root / f'{name}_network_{timestamp}.pt'
    torch.save(network.state_dict(), pt_path)
    print(f"   Saved: {pt_path}")

# 2. Loss history
print("\n2. Exporting training history (.json)...")
history_path = export_root / f'training_history_{timestamp}.json'
with history_path.open('w') as f:
    json.dump(trainer.history, f, indent=2)
print(f"   Saved: {history_path}")

# 3. Configuration used for this run
print("\n3. Exporting configuration (.json)...")
config_path = export_root / f'config_{timestamp}.json'
with config_path.open('w') as f:
    json.dump(CONFIG, f, indent=2)
print(f"   Saved: {config_path}")

# 4. Metric evaluated on 1000 sample points, together with those points
print("\n4. Exporting sample metric evaluations (.npy)...")
sample_coords = sample_uniform_grid(1000, device=device)
# Prefer the global network; fall back to the M2 region network when the
# global one was never created.
export_network = trainer.networks['global']
if export_network is None:
    export_network = trainer.networks['m2']
with torch.no_grad():
    sample_metrics = export_network(sample_coords).cpu().numpy()

metrics_path = export_root / f'sample_metrics_{timestamp}.npy'
np.save(metrics_path, sample_metrics)
print(f"   Saved: {metrics_path}")

coords_path = export_root / f'sample_coords_{timestamp}.npy'
np.save(coords_path, sample_coords.cpu().numpy())
print(f"   Saved: {coords_path}")

print("\nExport complete!")

## 11. H³ Harmonic Forms Extraction and Export

In [None]:
# Extract the H³ forms from the trained metric and export them as NPZ, NPY
# and a JSON summary (skipped entirely when CONFIG['extract_h3'] is false).
if CONFIG['extract_h3']:
    print("\n" + "="*60)
    print("H³ Harmonic Forms Extraction")
    print("="*60)
    
    # Get appropriate network
    if trainer.networks['global'] is not None:
        metric_network = trainer.networks['global']
    else:
        print("Warning: Global network not available, using M2 network")
        metric_network = trainer.networks['m2']
    
    # Extract H³ forms
    h3_results = extract_h3_forms_complete(
        metric_network=metric_network,
        n_forms=CONFIG['h3_n_forms'],
        n_sample_points=CONFIG['h3_n_sample_points'],
        device=device
    )
    
    # Export results
    print("\nExporting H³ results...")
    
    # Export as NumPy (.npz)
    h3_npz_path = Path(OUTPUT_DIR) / f'h3_forms_{timestamp}.npz'
    np.savez(
        h3_npz_path,
        eigenvalues=h3_results['eigenvalues'],
        forms=h3_results['forms'],
        coords=h3_results['coords']
    )
    print(f"  Saved NPZ: {h3_npz_path}")
    
    # Export as separate NumPy files
    np.save(Path(OUTPUT_DIR) / f'h3_eigenvalues_{timestamp}.npy', h3_results['eigenvalues'])
    np.save(Path(OUTPUT_DIR) / f'h3_forms_{timestamp}.npy', h3_results['forms'])
    np.save(Path(OUTPUT_DIR) / f'h3_coords_{timestamp}.npy', h3_results['coords'])
    
    # Export summary as JSON.  'n_forms' is the number of forms actually
    # exported (one per eigenvalue), which can be smaller than the requested
    # CONFIG['h3_n_forms']; the request is kept under its own key.
    h3_summary = {
        'n_forms': int(len(h3_results['eigenvalues'])),
        'n_forms_requested': CONFIG['h3_n_forms'],
        'n_harmonic': int(h3_results['n_harmonic']),
        'eigenvalue_min': float(h3_results['eigenvalues'].min()),
        'eigenvalue_max': float(h3_results['eigenvalues'].max()),
        'eigenvalues_below_1e6': int((h3_results['eigenvalues'] < 1e-6).sum()),
        'timestamp': timestamp
    }
    
    h3_summary_path = Path(OUTPUT_DIR) / f'h3_summary_{timestamp}.json'
    with open(h3_summary_path, 'w') as f:
        json.dump(h3_summary, f, indent=2)
    print(f"  Saved JSON: {h3_summary_path}")
    
    print("\nH³ extraction summary:")
    print(f"  Total forms: {h3_summary['n_forms']}")
    print(f"  Harmonic (λ < 1e-6): {h3_summary['n_harmonic']}")
    print(f"  Eigenvalue range: [{h3_summary['eigenvalue_min']:.2e}, {h3_summary['eigenvalue_max']:.2e}]")
else:
    print("\nH³ extraction skipped (CONFIG['extract_h3'] = False)")

## 12. Training Visualization

In [None]:
# Draw the six recorded loss curves on a 2x3 grid of log-scaled panels.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('K₇ Metric Reconstruction Training History', fontsize=16, fontweight='bold')

epochs = trainer.history['epoch']

# One row per panel, in row-major grid order:
# (history key, y-axis label, panel title, line colour or None for default).
panel_specs = [
    ('total_loss', 'Total Loss', 'Total Loss', None),
    ('snapshot_loss', 'Snapshot Loss', 'Snapshot Anchoring Loss', 'orange'),
    ('volume_loss', 'Volume Loss', 'Volume Normalization Loss', 'green'),
    ('torsion_loss', 'Torsion Loss', 'G₂ Torsion Loss', 'red'),
    ('h2_loss', 'H² Loss', 'H² Orthonormality Loss', 'purple'),
    ('ricci_loss', 'Ricci Loss', 'Ricci Flatness Loss', 'brown'),
]

for ax, (history_key, y_label, panel_title, line_color) in zip(axes.flat, panel_specs):
    line_kwargs = {'linewidth': 2}
    if line_color is not None:
        line_kwargs['color'] = line_color
    ax.plot(epochs, trainer.history[history_key], **line_kwargs)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

plt.tight_layout()
plot_path = Path(OUTPUT_DIR) / f'training_history_{timestamp}.png'
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"\nTraining plot saved: {plot_path}")
plt.show()

## 13. Training Summary and Final Statistics

In [None]:
# Generate final summary report
print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)


def _final_loss(history_key):
    """Return the last recorded value of a history series, or None if empty."""
    series = trainer.history[history_key]
    return float(series[-1]) if series else None


def _format_loss(value):
    """Format a loss for printing; a missing value is shown as 'n/a'."""
    return "n/a" if value is None else f"{value:.6e}"


# Training may have stopped early (interrupt or error), so take the number of
# completed epochs from the recorded history instead of the configured total.
recorded_epochs = trainer.history['epoch']
epochs_completed = recorded_epochs[-1] + 1 if recorded_epochs else 0
training_finished = epochs_completed >= CONFIG['epochs']

summary_report = {
    'training': {
        'total_epochs': CONFIG['epochs'],
        'epochs_completed': epochs_completed,
        'batch_size': CONFIG['batch_size'],
        'final_total_loss': _final_loss('total_loss'),
        'final_snapshot_loss': _final_loss('snapshot_loss'),
        'final_volume_loss': _final_loss('volume_loss'),
        'final_torsion_loss': _final_loss('torsion_loss'),
    },
    'networks': {
        'm1_params': sum(p.numel() for p in trainer.networks['m1'].parameters()),
        'neck_params': sum(p.numel() for p in trainer.networks['neck'].parameters()),
        'm2_params': sum(p.numel() for p in trainer.networks['m2'].parameters()),
        'phi_params': sum(p.numel() for p in trainer.networks['phi'].parameters()),
        'h2_params': sum(p.numel() for p in trainer.networks['h2'].parameters()),
    },
    'outputs': {
        'checkpoint_dir': str(CHECKPOINT_DIR),
        'output_dir': str(OUTPUT_DIR),
        'timestamp': timestamp
    },
    'gift_params': CONFIG['gift_params']
}

if CONFIG['extract_h3']:
    summary_report['h3_extraction'] = h3_summary

# Save summary report
summary_path = Path(OUTPUT_DIR) / f'summary_report_{timestamp}.json'
with open(summary_path, 'w') as f:
    json.dump(summary_report, f, indent=2)

# Print summary
print("\nTraining Statistics:")
print(f"  Total epochs completed: {epochs_completed}")
print(f"  Final total loss: {_format_loss(summary_report['training']['final_total_loss'])}")
print(f"  Final snapshot loss: {_format_loss(summary_report['training']['final_snapshot_loss'])}")
print(f"  Final volume loss: {_format_loss(summary_report['training']['final_volume_loss'])}")
print(f"  Final torsion loss: {_format_loss(summary_report['training']['final_torsion_loss'])}")

print("\nNetwork Parameters:")
for name, count in summary_report['networks'].items():
    print(f"  {name}: {count:,}")

print("\nOutput Locations:")
print(f"  Checkpoints: {CHECKPOINT_DIR}")
print(f"  Results: {OUTPUT_DIR}")

if CONFIG['extract_h3']:
    print("\nH³ Extraction:")
    print(f"  Total forms: {h3_summary['n_forms']}")
    print(f"  Harmonic forms: {h3_summary['n_harmonic']}")

print(f"\nSummary report saved: {summary_path}")
print("\n" + "="*60)
# Only claim success when every configured epoch was actually trained.
if training_finished:
    print("ALL TASKS COMPLETED SUCCESSFULLY")
else:
    print(f"TRAINING INCOMPLETE: {epochs_completed} of {CONFIG['epochs']} epochs")
print("="*60)