# K₇ Metric Reconstruction v1.0 - Complete TCS Implementation

**Torsion Cohomology Solver (TCS) for G₂ Manifolds**

100% self-contained notebook - no external files needed.

## Features

- **Five-phase curriculum learning** (15,000 epochs)
- **Complete harmonic basis extraction**: b₂=21, b₃=77
- **Yukawa tensor computation**: Y_αβγ [21×21×77]
- **Geometric validation**: Ricci-flatness, holonomy tests
- **Calibration constraints**: Associative and coassociative cycles
- **Adaptive loss scheduling**: Automatic weight adjustment

## Quick Start

1. **Runtime** → Change runtime type → **GPU** (T4/A100)
2. **Runtime** → Run all
3. **Download results** before session ends

## Target Metrics

- Torsion closure: < 1×10⁻³
- Torsion coclosure: < 1×10⁻³
- Yukawa deviation: < 10%
- Harmonic bases: Full rank (21 and 77)

**Framework:** GIFT v2.0  
**Version:** 1.0  
**Updated:** 2025-01-18

In [None]:
# ============================================================
# SETUP AND INSTALLATION
# ============================================================

import sys
from pathlib import Path

# Install dependencies. The two `!pip` lines are IPython/Colab shell escapes,
# not Python statements, so this cell only runs inside a notebook kernel.
print('Installing required packages...')
# PyTorch wheels are pulled from the cu118 wheel index
# (presumably the CUDA 11.8 builds - confirm it matches the runtime's driver).
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q tensorly matplotlib seaborn numpy scipy tqdm
print('Installation complete\n')

# Setup directories (Colab local storage)
# WORK_DIR is the root; checkpoints and results live in sub-directories.
# parents=True / exist_ok=True make the cell safe to re-run.
WORK_DIR = Path('/content/K7_v1_0_training')
WORK_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_DIR = WORK_DIR / 'checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

RESULTS_DIR = WORK_DIR / 'results'
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print(f'Working directory: {WORK_DIR}')
print('NOTE: All data stored in /content/ - download before session ends!')
print('='*60)

In [None]:
# ============================================================
# IMPORTS AND DEVICE CONFIGURATION
# ============================================================

import json
import time
import warnings
from typing import Dict, List, Tuple, Optional, Any
from itertools import permutations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')

# Select the compute device once: CUDA when a GPU is visible, otherwise CPU
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'\nDevice: {DEVICE}')

if not torch.cuda.is_available():
    print('WARNING: No GPU detected - training will be very slow!')
    print('Go to Runtime > Change runtime type > GPU')
else:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
    # Enable cuDNN benchmark mode for the GPU run
    torch.backends.cudnn.benchmark = True

print('='*60)

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

# Single source of hyperparameters for the notebook. The dict is dumped to
# WORK_DIR/config.json below, so every value must stay JSON-serialisable.
CONFIG = {
    'version': 'v1.0_complete_tcs',
    # Seed applied to NumPy and PyTorch (CPU and all CUDA devices) below
    'seed': 42,
    
    # GIFT theoretical parameters
    # Passed as a whole to K7Topology, which reads 'epsilon0';
    # 'b2'/'b3' are printed as the target Betti numbers.
    'gift_parameters': {
        'tau': 3.8967452300785634,
        'xi': 0.9817477042468103,
        'epsilon0': 0.125,
        'b2': 21,
        'b3': 77,
    },
    
    # Neural network architecture
    # 'n_fourier' is the number of Fourier frequencies; each network's MLP
    # input width is 2 * n_fourier (sin and cos halves).
    'architecture': {
        'phi_network': {
            'hidden_dims': [384, 384, 256],
            'n_fourier': 32
        },
        'harmonic_h2_network': {
            'hidden_dim': 128,
            'n_fourier': 24,
            'n_forms': 21
        },
        'harmonic_h3_network': {
            'hidden_dim': 128,
            'n_fourier': 24,
            'n_forms': 77
        }
    },
    
    # Training configuration
    'training': {
        'total_epochs': 15000,
        # Number of coordinate points sampled per epoch
        'batch_size': 2048,
        # NOTE(review): 'grad_accumulation' is not read by the training loop
        # visible in this notebook section - confirm whether it is intended.
        'grad_accumulation': 4,
        'lr': 1e-4,
        'weight_decay': 1e-4,
        # Max gradient norm passed to clip_grad_norm_
        'grad_clip': 1.0,
        # Length of the linear LR warmup; cosine annealing covers the rest
        'warmup_epochs': 500,
        
        # Five-phase curriculum
        # Each phase: 'range' is a half-open epoch interval [start, end),
        # 'grid_n' the per-axis grid resolution used for sampling, and
        # 'loss_weights' the per-term multipliers.
        # NOTE(review): the total loss in the visible training loop uses only
        # torsion_closure, torsion_coclosure, volume, gram_h2 and gram_h3;
        # 'boundary' and 'calibration' are not consumed there.
        'curriculum': {
            'phase1_neck_stability': {
                'range': [0, 2000],
                'grid_n': 8,
                'loss_weights': {
                    'torsion_closure': 0.5,
                    'torsion_coclosure': 0.5,
                    'volume': 2.0,
                    'gram_h2': 1.0,
                    'gram_h3': 0.5,
                    'boundary': 0.5,
                    'calibration': 0.0
                }
            },
            'phase2_acyl_matching': {
                'range': [2000, 5000],
                'grid_n': 8,
                'loss_weights': {
                    'torsion_closure': 1.0,
                    'torsion_coclosure': 1.0,
                    'volume': 0.5,
                    'gram_h2': 1.5,
                    'gram_h3': 1.0,
                    'boundary': 1.5,
                    'calibration': 0.0
                }
            },
            'phase3_cohomology_refinement': {
                'range': [5000, 8000],
                'grid_n': 10,
                'loss_weights': {
                    'torsion_closure': 2.0,
                    'torsion_coclosure': 2.0,
                    'volume': 0.2,
                    'gram_h2': 3.0,
                    'gram_h3': 2.0,
                    'boundary': 2.0,
                    'calibration': 0.5
                }
            },
            'phase4_harmonic_extraction': {
                'range': [8000, 10000],
                'grid_n': 10,
                'loss_weights': {
                    'torsion_closure': 3.0,
                    'torsion_coclosure': 3.0,
                    'volume': 0.1,
                    'gram_h2': 5.0,
                    'gram_h3': 3.0,
                    'boundary': 1.5,
                    'calibration': 1.0
                }
            },
            'phase5_calibration_finetune': {
                'range': [10000, 15000],
                'grid_n': 12,
                'loss_weights': {
                    'torsion_closure': 5.0,
                    'torsion_coclosure': 5.0,
                    'volume': 0.05,
                    'gram_h2': 5.0,
                    'gram_h3': 4.0,
                    'boundary': 1.0,
                    'calibration': 3.0
                }
            }
        }
    },
    
    # Checkpointing
    # 'interval' is in epochs; 'keep_best' is handed to CheckpointManager;
    # 'auto_resume' gates loading the latest checkpoint before training.
    'checkpointing': {
        'interval': 500,
        'keep_best': 5,
        'auto_resume': True
    },
    
    # Validation
    # NOTE(review): no code in the visible section reads this sub-dict.
    'validation': {
        'interval': 100,
        'ricci_interval': 500,
        'ricci_points': 1000
    },
    
    # Yukawa computation
    # NOTE(review): only 'n_mc_samples' is read in the visible section.
    'yukawa_computation': {
        'n_mc_samples': 20000,
        'grid_n': 10,
        'tucker_rank': [3, 3, 3],
        'antisymmetry_tolerance': 1e-6
    },
    
    # Holonomy test
    # NOTE(review): no code in the visible section reads this sub-dict.
    'holonomy_test': {
        'n_loops': 10,
        'n_steps_per_loop': 50,
        'preservation_tolerance': 1e-4
    }
}

# Set random seeds for reproducibility
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['seed'])

# Save configuration
# Written next to the checkpoints/results so a run can be traced to its settings.
with open(WORK_DIR / 'config.json', 'w') as f:
    json.dump(CONFIG, f, indent=2)

print('\nConfiguration initialized')
print(f'Total epochs: {CONFIG["training"]["total_epochs"]}')
print(f'Curriculum phases: 5')
print(f'Target: b₂={CONFIG["gift_parameters"]["b2"]}, b₃={CONFIG["gift_parameters"]["b3"]}')
print('='*60)

## Neural Network Architectures

Three specialized networks:
1. **ModularPhiNetwork**: Generates the G₂ structure 3-form φ
2. **HarmonicFormsNetwork (H²)**: Extracts 21 harmonic 2-forms
3. **HarmonicFormsNetwork (H³)**: Extracts 77 harmonic 3-forms

In [None]:
# ============================================================
# NEURAL NETWORK ARCHITECTURES
# ============================================================

class FourierFeatures(nn.Module):
    """Random Fourier feature map ``x -> [sin(2π x B), cos(2π x B)]``.

    The projection matrix ``B`` of shape ``(input_dim, n_frequencies)`` is
    drawn once from a standard normal, multiplied by ``scale`` and stored as
    a buffer (it follows the module across devices but is not trained).
    The output has ``2 * n_frequencies`` features on the last axis.
    """

    def __init__(self, input_dim, n_frequencies, scale=1.0):
        super().__init__()
        projection = scale * torch.randn(input_dim, n_frequencies)
        self.register_buffer('B', projection)

    def forward(self, x):
        phase = (2 * np.pi * x) @ self.B
        sin_part, cos_part = torch.sin(phase), torch.cos(phase)
        return torch.cat((sin_part, cos_part), dim=-1)


class ModularPhiNetwork(nn.Module):
    """Neural network for the G₂ structure 3-form φ on 7 coordinates.

    An MLP on Fourier features of the input predicts the 35 = C(7, 3)
    independent components φ_{ijk} with i < j < k; ``get_phi_tensor``
    expands them into the full antisymmetric ``7 x 7 x 7`` tensor.

    Args:
        hidden_dims: Iterable of hidden-layer widths of the MLP.
        n_fourier: Number of Fourier frequencies; the MLP input width is
            ``2 * n_fourier`` (sin and cos halves).
    """

    def __init__(self, hidden_dims, n_fourier):
        super().__init__()
        self.fourier = FourierFeatures(7, n_fourier, scale=1.0)

        layers = []
        in_dim = n_fourier * 2  # FourierFeatures outputs n_fourier * 2
        for h_dim in hidden_dims:
            layers.extend([nn.Linear(in_dim, h_dim), nn.SiLU()])
            in_dim = h_dim

        layers.append(nn.Linear(in_dim, 35))  # 35 independent components
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 35 independent components for each input row."""
        features = self.fourier(x)
        return self.network(features)

    @staticmethod
    def _antisymmetric_layout():
        """Describe how the 35 components fill the antisymmetric tensor.

        Returns three parallel lists with 35 * 6 = 210 entries: the index of
        the source component, the target ``(a, b, c)`` position, and the
        sign (+1 for even, -1 for odd permutations of ``i < j < k``).
        Components are ordered lexicographically in ``(i, j, k)``.
        """
        from itertools import combinations

        source, targets, signs = [], [], []
        for n, (i, j, k) in enumerate(combinations(range(7), 3)):
            for target, sign in (
                ((i, j, k), 1.0), ((j, k, i), 1.0), ((k, i, j), 1.0),
                ((i, k, j), -1.0), ((j, i, k), -1.0), ((k, j, i), -1.0),
            ):
                source.append(n)
                targets.append(target)
                signs.append(sign)
        return source, targets, signs

    def get_phi_tensor(self, x):
        """Convert to full antisymmetric 3-form tensor.

        Args:
            x: Input coordinates, one row per sample.

        Returns:
            Tensor of shape ``(batch, 7, 7, 7)`` in the dtype and on the
            device of the network output, antisymmetric under exchange of
            any two of its last three indices.
        """
        phi_flat = self.forward(x)
        source, targets, signs = self._antisymmetric_layout()

        device = phi_flat.device
        src = torch.tensor(source, dtype=torch.long, device=device)
        idx_a, idx_b, idx_c = torch.tensor(
            targets, dtype=torch.long, device=device
        ).unbind(dim=1)
        sgn = torch.tensor(signs, dtype=phi_flat.dtype, device=device)

        # One vectorised write instead of 210 separate in-place writes; the
        # buffer inherits the network dtype so non-float32 models are not
        # silently downcast.
        phi = phi_flat.new_zeros(x.shape[0], 7, 7, 7)
        phi[:, idx_a, idx_b, idx_c] = phi_flat[:, src] * sgn
        return phi


class HarmonicFormsNetwork(nn.Module):
    """Neural network for a family of ``n_forms`` p-forms on 7 coordinates.

    Each form has its own Fourier-feature encoder and small MLP; hidden
    widths cycle through ``hidden_dim + {0, 8, 16, 24, 32}`` so the
    sub-networks are not all identical.

    Args:
        p: Degree of the forms (0..7). Each form has C(7, p) independent
            components, i.e. 21 for p=2 and 35 for p=3.
        n_forms: Number of forms (sub-networks).
        hidden_dim: Base hidden width of each sub-network.
        n_fourier: Number of Fourier frequencies per sub-network.

    Raises:
        ValueError: If ``p`` is outside ``0..7``.
    """

    def __init__(self, p, n_forms, hidden_dim, n_fourier):
        super().__init__()
        from math import comb

        if not 0 <= p <= 7:
            raise ValueError(f'form degree p must be in 0..7, got {p}')

        self.p = p
        self.n_forms = n_forms
        # Independent components of an antisymmetric p-form in 7 dimensions
        self.n_components = comb(7, p)

        self.networks = nn.ModuleList()
        for i in range(n_forms):
            # Varied hidden dimensions for diversity
            hidden_var = hidden_dim + (i % 5) * 8
            fourier = FourierFeatures(7, n_fourier, scale=1.0)
            fourier_dim = n_fourier * 2

            net = nn.Sequential(
                nn.Linear(fourier_dim, hidden_var),
                nn.SiLU(),
                nn.Linear(hidden_var, hidden_var),
                nn.SiLU(),
                nn.Linear(hidden_var, self.n_components),
            )
            self.networks.append(nn.Sequential(fourier, net))

    def forward(self, x):
        """Evaluate every form at ``x``.

        Returns:
            Tensor of shape ``(batch, n_forms, n_components)`` in the dtype
            of the sub-network outputs.
        """
        if len(self.networks) == 0:
            return x.new_zeros(x.shape[0], 0, self.n_components)
        # Stacking avoids in-place writes into a float32 buffer
        return torch.stack([network(x) for network in self.networks], dim=1)


print('Neural network architectures defined')
print('  - ModularPhiNetwork: 3-form φ generator')
print('  - HarmonicFormsNetwork: Harmonic basis extractor')
print('='*60)

## K₇ Topology and Sampling

Implements the complete K₇ manifold structure:
- Three regions: M₁ (ACyl), Neck, M₂ (ACyl)
- Associative and coassociative calibration cycles
- Coordinate sampling that mixes regular grid points with uniformly random points

In [None]:
# ============================================================
# K₇ TOPOLOGY AND SAMPLING
# ============================================================

class K7Topology:
    """K₇ manifold topology with three-region structure.

    Coordinates are 7-vectors in ``[0, 2π]``; coordinate 0 is the parameter
    ``t`` used to split the space into M₁, neck and M₂.

    Args:
        gift_params: Mapping of GIFT parameters; must contain ``'epsilon0'``.
    """

    def __init__(self, gift_params):
        self.params = gift_params
        self.epsilon = gift_params['epsilon0']

    @staticmethod
    def _sample_distinct_indices(total, k):
        """Return ``k`` distinct random integers from ``range(total)``.

        Small populations use ``torch.randperm``. Large ones draw with
        replacement and de-duplicate until ``k`` distinct values exist, so
        the cost scales with ``k`` rather than ``total``.
        """
        if k <= 0:
            return torch.empty(0, dtype=torch.long)
        if total <= 4 * k:
            return torch.randperm(total)[:k]

        chosen = torch.unique(torch.randint(total, (k,)))
        while chosen.numel() < k:
            extra = torch.randint(total, (2 * (k - chosen.numel()),))
            chosen = torch.unique(torch.cat([chosen, extra]))
        # torch.unique returns sorted values: pick k of them in random order
        return chosen[torch.randperm(chosen.numel())[:k]]

    def sample_coordinates(self, n_samples, grid_n=10):
        """Sample coordinates with mix of grid and random points.

        Up to half of the samples are distinct nodes of the regular
        ``grid_n ** 7`` grid on ``[0, 2π]^7``; the rest are uniform random
        points. Grid nodes are drawn by index, so the full grid is never
        built in memory.

        Returns:
            Float tensor of shape ``(n_samples, 7)``, grid points first.
        """
        coords_1d = torch.linspace(0, 2*np.pi, grid_n)

        n_grid = min(n_samples // 2, grid_n ** 7)
        if n_grid > 0:
            flat_idx = self._sample_distinct_indices(grid_n ** 7, n_grid)
            # Unravel the flat index into 7 per-axis indices (axis 0 is the
            # most significant digit, matching row-major grid order).
            powers = torch.tensor([grid_n ** e for e in range(6, -1, -1)])
            digits = torch.div(
                flat_idx.unsqueeze(1), powers, rounding_mode='floor'
            ) % grid_n
            samples_grid = coords_1d[digits]
        else:
            samples_grid = torch.empty(0, 7)

        n_random = n_samples - n_grid
        samples_random = torch.rand(n_random, 7) * 2 * np.pi

        return torch.cat([samples_grid, samples_random], dim=0)

    def get_region_weights(self, x, neck_half_width=3 * np.pi / 8):
        """Soft region assignment: M₁, Neck, M₂.

        The neck is centred at ``t = π``. M₁ lies below
        ``π - neck_half_width`` and M₂ above ``π + neck_half_width``, each
        with a sigmoid transition of width 0.3. With a zero half-width the
        two sigmoids sum to exactly 1 and the neck weight vanishes, so a
        positive half-width is needed for a neck to exist.

        Args:
            x: Tensor of shape ``(batch, 7)``; column 0 is ``t``.
            neck_half_width: Half-width of the neck in ``t``.
                NOTE(review): the default 3π/8 was picked so the cycle
                ``t`` values defined in this class fall in their labelled
                regions - confirm against the intended geometry.

        Returns:
            Dict with keys ``'m1'``, ``'neck'``, ``'m2'``; the three
            weights sum to 1 for every sample.
        """
        t = x[:, 0]
        w_m1 = torch.sigmoid((np.pi - neck_half_width - t) / 0.3)
        w_m2 = torch.sigmoid((t - np.pi - neck_half_width) / 0.3)
        w_neck = 1.0 - w_m1 - w_m2
        return {'m1': w_m1, 'neck': w_neck, 'm2': w_m2}

    def define_associative_cycles(self, n_cycles=6):
        """Define associative 3-cycles for calibration.

        Returns at most ``n_cycles`` dicts with keys ``'region'``,
        ``'t_fixed'``, ``'type'`` and ``'indices'`` (six are defined).
        """
        cycles = []
        for region, t_vals in [('M1', [np.pi/4, np.pi/3]),
                                ('neck', [np.pi, 5*np.pi/4]),
                                ('M2', [3*np.pi/2, 7*np.pi/4])]:
            for t in t_vals:
                cycles.append({
                    'region': region,
                    't_fixed': t,
                    'type': 'T3',
                    'indices': [1, 2, 3],
                })
        return cycles[:n_cycles]

    def define_coassociative_cycles(self, n_cycles=6):
        """Define coassociative 4-cycles for calibration.

        Only five cycles are defined (M1 has a single ``t`` value), so the
        result has at most five entries even when ``n_cycles`` is larger.
        """
        cycles = []
        for region, t_vals in [('M1', [np.pi/4]),
                                ('neck', [np.pi, 5*np.pi/4]),
                                ('M2', [3*np.pi/2, 7*np.pi/4])]:
            for t in t_vals:
                cycles.append({
                    'region': region,
                    't_fixed': t,
                    'type': 'T4',
                    'indices': [0, 4, 5, 6],
                })
        return cycles[:n_cycles]

    def sample_on_cycle(self, cycle, n_samples=512):
        """Sample points on a calibration cycle.

        Draws uniform points in ``[0, 2π)^7`` and pins coordinate 0 to
        ``cycle['t_fixed']``.

        NOTE(review): ``cycle['indices']`` is not used, so the six other
        coordinates all vary - confirm this is the intended slice.
        """
        samples = torch.rand(n_samples, 7) * 2 * np.pi
        samples[:, 0] = cycle['t_fixed']
        return samples


# Initialize topology
# Module-level instances used by the sampling calls in later cells.
topology = K7Topology(CONFIG['gift_parameters'])
assoc_cycles = topology.define_associative_cycles(6)
# Six are requested, but define_coassociative_cycles only defines five
# (a single t value for M1), so the count printed below is 5.
coassoc_cycles = topology.define_coassociative_cycles(6)

print('\nK₇ topology initialized')
print(f'  Associative cycles: {len(assoc_cycles)}')
print(f'  Coassociative cycles: {len(coassoc_cycles)}')
print('='*60)

## Loss Functions

Complete TCS loss components:
1. **Torsion constraints**: closure dφ = 0 and coclosure d⋆φ = 0 (⋆ is the Hodge star)
2. **Gram matrices**: Orthonormality for H² and H³
3. **Calibration**: Associative and coassociative conditions
4. **Adaptive scheduling**: Dynamic weight adjustment

In [None]:
# ============================================================
# LOSS FUNCTIONS
# ============================================================

def _permutation_sign(perm):
    """Return +1.0 for an even permutation of 0..n-1 and -1.0 for an odd one."""
    inversions = sum(
        1
        for pos, value in enumerate(perm)
        for later in perm[pos + 1:]
        if value > later
    )
    return -1.0 if inversions % 2 else 1.0


def compute_exterior_derivative(phi, coords):
    """Compute dφ using automatic differentiation.

    For a 3-form with antisymmetric components φ_{ijk}, the 4-form dφ has
    components, for a < b < c < d,

        (dφ)_{abcd} = ∂_a φ_{bcd} - ∂_b φ_{acd} + ∂_c φ_{abd} - ∂_d φ_{abc}.

    The result is the full antisymmetric tensor: each of the 35 independent
    components is written to all 24 orderings of its indices with the
    matching sign, and entries with a repeated index are zero. A closed
    form therefore yields an identically zero tensor even when its
    individual partial derivatives do not vanish.

    Args:
        phi: Tensor of shape ``(batch, 7, 7, 7)`` built from ``coords`` with
            autograd tracking; assumed antisymmetric in its last 3 indices.
        coords: Tensor of shape ``(batch, 7)`` with ``requires_grad=True``.

    Returns:
        Tensor of shape ``(batch, 7, 7, 7, 7)``. The graph is retained
        (``create_graph=True``) so the result can be used inside a loss.
    """
    from itertools import combinations, permutations

    dim = 7
    triples = list(combinations(range(dim), 3))
    triple_index = {triple: n for n, triple in enumerate(triples)}

    # partials[:, n, l] = ∂_l φ_{triples[n]}, one autograd call per component
    partials = torch.stack(
        [
            torch.autograd.grad(
                phi[:, i, j, k].sum(),
                coords,
                create_graph=True,
                retain_graph=True,
            )[0]
            for i, j, k in triples
        ],
        dim=1,
    )
    device = partials.device

    # For each quadruple, term m differentiates along quad[m] the component
    # whose indices are the remaining three; signs alternate + - + -.
    quads = list(combinations(range(dim), 4))
    term_triple = torch.tensor(
        [[triple_index[quad[:m] + quad[m + 1:]] for m in range(4)] for quad in quads],
        device=device,
    )
    term_axis = torch.tensor([list(quad) for quad in quads], device=device)
    alternating = partials.new_tensor([1.0, -1.0, 1.0, -1.0])
    components = (partials[:, term_triple, term_axis] * alternating).sum(dim=-1)

    # Scatter every component to all signed permutations of its indices
    source, targets, signs = [], [], []
    for n, quad in enumerate(quads):
        for perm in permutations(range(4)):
            source.append(n)
            targets.append([quad[m] for m in perm])
            signs.append(_permutation_sign(perm))
    src = torch.tensor(source, device=device)
    idx_a, idx_b, idx_c, idx_d = torch.tensor(targets, device=device).unbind(dim=1)
    sgn = partials.new_tensor(signs)

    dphi = partials.new_zeros(phi.shape[0], dim, dim, dim, dim)
    dphi[:, idx_a, idx_b, idx_c, idx_d] = components[:, src] * sgn
    return dphi


def gram_matrix_loss(harmonic_forms, target_rank):
    """Compute Gram matrix loss for orthonormalization.

    The Gram matrix is ``G[i, j] = mean_batch(sum_c h[b, i, c] * h[b, j, c])``,
    computed with a single einsum.

    Args:
        harmonic_forms: Tensor of shape ``(batch, n_forms, n_components)``.
        target_rank: Accepted for interface compatibility; not used in the
            computation.

    Returns:
        Tuple ``(loss, det_gram, rank)`` where
        ``loss = mean((G - I)**2) + 0.1 * (det(G + 1e-6 I) - 1)**2``,
        ``det_gram`` is that regularised determinant (a tensor) and ``rank``
        is the number of eigenvalues of ``G`` above 1e-4 (a Python int).
    """
    batch_size, n_forms = harmonic_forms.shape[0], harmonic_forms.shape[1]

    gram = torch.einsum('bic,bjc->ij', harmonic_forms, harmonic_forms) / batch_size

    identity = torch.eye(n_forms, device=gram.device, dtype=gram.dtype)

    loss_orthonormality = torch.mean((gram - identity) ** 2)

    # Small diagonal shift keeps the determinant away from an exact zero
    det_gram = torch.det(gram + 1e-6 * identity)
    loss_determinant = (det_gram - 1.0) ** 2

    # The rank is a diagnostic only, so it is computed off the autograd graph
    eigenvalues = torch.linalg.eigvalsh(gram.detach())
    rank = int((eigenvalues > 1e-4).sum().item())

    loss = loss_orthonormality + 0.1 * loss_determinant

    return loss, det_gram, rank


def reconstruct_metric_from_phi(phi):
    """Reconstruct metric g from 3-form φ.

    Computes ``g_ij = (1/6) * sum_{p,q} φ_ipq φ_jpq`` over pairs ``(p, q)``
    with ``p != q`` and neither equal to ``i`` or ``j``, symmetrises the
    result and adds ``1e-4 * I`` as a positive-definiteness regulariser.

    NOTE(review): this quadratic contraction uses the flat index
    contraction (no inverse metric); presumably intended as an
    approximation near the standard G₂ form - confirm.

    Args:
        phi: Tensor of shape ``(batch, 7, 7, 7)``.

    Returns:
        Tensor of shape ``(batch, 7, 7)`` with the dtype and device of
        ``phi``.
    """
    idx = torch.arange(7, device=phi.device)
    i = idx.view(7, 1, 1, 1)
    j = idx.view(1, 7, 1, 1)
    p = idx.view(1, 1, 7, 1)
    q = idx.view(1, 1, 1, 7)
    # allowed[i, j, p, q] marks the (p, q) pairs that enter g_ij
    allowed = (p != i) & (q != i) & (p != j) & (q != j) & (p != q)

    metric = torch.einsum(
        'bipq,bjpq,ijpq->bij', phi, phi, allowed.to(phi.dtype)
    ) / 6.0
    metric = 0.5 * (metric + metric.transpose(-2, -1))

    # Regularize for positive-definiteness
    eye = torch.eye(7, device=phi.device, dtype=phi.dtype).unsqueeze(0)
    metric = metric + 1e-4 * eye

    return metric


class AdaptiveLossScheduler:
    """Adaptive loss weight scheduler.

    Tracks the torsion losses and multiplies a term's weight by
    ``boost_factor`` when its recent values have stagnated (variance below
    ``variance_threshold``). Weights are capped at ``max_weight`` so a
    converged term cannot grow without bound over a long run, and only the
    last ``window`` values of each history are kept.

    Args:
        window: Number of most recent values used for the variance test.
        variance_threshold: Stagnation threshold on the variance.
        boost_factor: Multiplier applied to a stagnating term's weight.
        max_weight: Upper bound for any weight.

    Raises:
        ValueError: If ``window`` is smaller than 2.
    """

    _TRACKED = ('torsion_closure', 'torsion_coclosure')

    def __init__(self, window=100, variance_threshold=1e-4,
                 boost_factor=1.5, max_weight=100.0):
        if window < 2:
            raise ValueError(f'window must be at least 2, got {window}')
        self.window = window
        self.variance_threshold = variance_threshold
        self.boost_factor = boost_factor
        self.max_weight = max_weight
        self.history = {'torsion_closure': [], 'torsion_coclosure': []}
        self.weights = {'torsion_closure': 1.0, 'torsion_coclosure': 1.0}

    def update(self, epoch, losses):
        """Record the losses and, every 100 epochs after 500, boost weights.

        Args:
            epoch: Current epoch number.
            losses: Mapping from loss name to a Python float; names other
                than the tracked torsion terms are ignored.
        """
        for key in self._TRACKED:
            if key in losses:
                series = self.history[key]
                series.append(losses[key])
                # Only the last `window` values are ever read
                del series[:-self.window]

        if epoch % 100 == 0 and epoch > 500:
            for key in self._TRACKED:
                recent = self.history[key]
                if len(recent) < self.window:
                    continue
                variance = torch.tensor(recent).var().item()

                if variance < self.variance_threshold and self.weights[key] < self.max_weight:
                    self.weights[key] = min(
                        self.weights[key] * self.boost_factor, self.max_weight
                    )
                    print(f"  Boosting {key} weight to {self.weights[key]:.3f}")

    def get_weights(self):
        """Return the current weights dict (the live object, not a copy)."""
        return self.weights


adaptive_scheduler = AdaptiveLossScheduler()

print('\nLoss functions defined')
print('  - Torsion constraints (closure + coclosure)')
print('  - Gram matrix orthonormalization')
print('  - Adaptive loss scheduling')
print('='*60)

## Checkpoint Management

Automatic checkpointing with:
- Auto-resume from latest checkpoint
- Keep best N checkpoints by torsion metric
- Full state saving (models, optimizer, scheduler)

In [None]:
# ============================================================
# CHECKPOINT MANAGEMENT
# ============================================================

class CheckpointManager:
    """Manage model checkpointing.

    Checkpoints are written as ``checkpoint_epoch_<N>.pt`` under
    ``save_dir``. The manager tracks the checkpoints saved by this instance,
    ranked by the sum of the two torsion metrics (lower is better), and
    keeps at most ``keep_best`` of them in that ranking.

    Args:
        save_dir: Directory for checkpoint files (created if missing,
            including parents).
        keep_best: Maximum number of tracked checkpoints.
    """

    def __init__(self, save_dir, keep_best=5):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.keep_best = keep_best
        self.checkpoints = []

    def save(self, epoch, models, optimizer, scheduler, metrics):
        """Write a checkpoint for ``epoch`` and prune the tracked list.

        Args:
            epoch: Epoch number stored in the file and used in its name.
            models: Mapping name -> module; each ``state_dict`` is saved.
            optimizer: Optimizer whose ``state_dict`` is saved.
            scheduler: LR scheduler, or a falsy value to store ``None``.
            metrics: Mapping of metric values stored with the checkpoint.

        Returns:
            Path of the written checkpoint file.
        """
        path = self.save_dir / f'checkpoint_epoch_{epoch}.pt'
        temp = self.save_dir / f'checkpoint_epoch_{epoch}.pt.tmp'

        # Write to a temp file first so an interrupted save never leaves a
        # truncated checkpoint under the final name.
        torch.save({
            'epoch': epoch,
            'models': {n: m.state_dict() for n, m in models.items()},
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict() if scheduler else None,
            'metrics': metrics,
            'timestamp': time.time()
        }, temp)
        # replace() overwrites an existing file on every platform
        temp.replace(path)

        # Track by torsion metric
        torsion = metrics.get('torsion_closure', 1.0) + metrics.get('torsion_coclosure', 1.0)
        # Re-saving the same epoch must not leave two entries for one file
        self.checkpoints = [entry for entry in self.checkpoints if entry[2] != path]
        self.checkpoints.append((epoch, torsion, path))
        self.checkpoints.sort(key=lambda x: x[1])

        # Keep only best N
        # The file just written is never deleted, even when it ranks worst;
        # it is merely dropped from the tracked list.
        if len(self.checkpoints) > self.keep_best:
            _, _, old = self.checkpoints.pop()
            if old.exists() and old != path:
                old.unlink()

        return path

    @staticmethod
    def _epoch_of(path):
        """Return the epoch encoded in a checkpoint file name, or -1."""
        import re

        match = re.fullmatch(r'checkpoint_epoch_(\d+)\.pt', path.name)
        return int(match.group(1)) if match else -1

    def _checkpoints_newest_first(self):
        """List checkpoint files ordered by epoch number, highest first.

        Sorting is numeric, so epoch 14999 ranks above epoch 9999. Files
        that match the glob but carry no epoch number sort last.
        """
        return sorted(
            self.save_dir.glob('checkpoint_*.pt'),
            key=lambda p: (self._epoch_of(p), p.name),
            reverse=True,
        )

    def load_latest(self):
        """Load the highest-epoch checkpoint that can be read.

        Unreadable files are reported and skipped so a corrupt latest
        checkpoint falls back to the previous one.

        Returns:
            The loaded checkpoint dict, or ``None`` if none could be loaded.
        """
        for ckpt in self._checkpoints_newest_first():
            try:
                print(f'Loading: {ckpt.name}')
                return torch.load(ckpt, map_location=DEVICE)
            except Exception as e:
                print(f'Failed: {e}')
                continue
        return None


checkpoint_manager = CheckpointManager(CHECKPOINT_DIR, CONFIG['checkpointing']['keep_best'])

print('\nCheckpoint manager initialized')
print(f'  Save directory: {CHECKPOINT_DIR}')
print(f'  Keep best: {CONFIG["checkpointing"]["keep_best"]}')
print('='*60)

## Curriculum Scheduler

Five-phase progressive training:
1. **Phase 1** (0-2k): Neck stability
2. **Phase 2** (2k-5k): ACyl matching
3. **Phase 3** (5k-8k): Cohomology refinement
4. **Phase 4** (8k-10k): Harmonic extraction
5. **Phase 5** (10k-15k): Calibration fine-tuning

In [None]:
# ============================================================
# CURRICULUM SCHEDULER
# ============================================================

class CurriculumScheduler:
    """Five-phase curriculum learning scheduler.

    Looks up, for a given epoch, which curriculum phase is active and
    exposes that phase's grid resolution and loss weights. Phases are read
    from ``config['training']['curriculum']`` and checked in a fixed order.
    """

    def __init__(self, config):
        self.config = config
        self.curriculum = config['training']['curriculum']
        self.phases = [
            'phase1_neck_stability',
            'phase2_acyl_matching',
            'phase3_cohomology_refinement',
            'phase4_harmonic_extraction',
            'phase5_calibration_finetune'
        ]

    def get_current_phase(self, epoch):
        """Return ``(name, settings)`` of the phase whose half-open range
        ``[start, end)`` contains ``epoch``; the last phase is the fallback
        for epochs outside every range."""
        for name in self.phases:
            settings = self.curriculum[name]
            bounds = settings['range']
            if epoch < bounds[0] or epoch >= bounds[1]:
                continue
            return name, settings
        last = self.phases[-1]
        return last, self.curriculum[last]

    def get_grid_resolution(self, epoch):
        """Return the active phase's ``grid_n`` (10 when the key is absent)."""
        return self.get_current_phase(epoch)[1].get('grid_n', 10)

    def get_loss_weights(self, epoch):
        """Return the active phase's ``loss_weights`` ({} when absent)."""
        return self.get_current_phase(epoch)[1].get('loss_weights', {})


curriculum = CurriculumScheduler(CONFIG)

print('\nCurriculum scheduler initialized')
print('  Phase 1 (0-2k): Neck stability')
print('  Phase 2 (2k-5k): ACyl matching')
print('  Phase 3 (5k-8k): Cohomology refinement')
print('  Phase 4 (8k-10k): Harmonic extraction')
print('  Phase 5 (10k-15k): Calibration fine-tuning')
print('='*60)

## Model Initialization

Create all three neural networks and prepare for training.

In [None]:
# ============================================================
# MODEL INITIALIZATION
# ============================================================

print('\nInitializing neural networks...')

# Create networks
# Every architecture hyperparameter, including the number of harmonic forms,
# is read from CONFIG['architecture'] so the saved config.json describes the
# models that are actually built.
phi_network = ModularPhiNetwork(
    CONFIG['architecture']['phi_network']['hidden_dims'],
    CONFIG['architecture']['phi_network']['n_fourier']
).to(DEVICE)

h2_network = HarmonicFormsNetwork(
    p=2,
    n_forms=CONFIG['architecture']['harmonic_h2_network']['n_forms'],
    hidden_dim=CONFIG['architecture']['harmonic_h2_network']['hidden_dim'],
    n_fourier=CONFIG['architecture']['harmonic_h2_network']['n_fourier']
).to(DEVICE)

h3_network = HarmonicFormsNetwork(
    p=3,
    n_forms=CONFIG['architecture']['harmonic_h3_network']['n_forms'],
    hidden_dim=CONFIG['architecture']['harmonic_h3_network']['hidden_dim'],
    n_fourier=CONFIG['architecture']['harmonic_h3_network']['n_fourier']
).to(DEVICE)

# Name -> module mapping; the names are the keys used inside checkpoints
models = {
    'phi_network': phi_network,
    'harmonic_h2': h2_network,
    'harmonic_h3': h3_network
}

# Count parameters
total_params = sum(p.numel() for m in models.values() for p in m.parameters())
phi_params = sum(p.numel() for p in phi_network.parameters())
h2_params = sum(p.numel() for p in h2_network.parameters())
h3_params = sum(p.numel() for p in h3_network.parameters())

print(f'\nParameter counts:')
print(f'  Phi network: {phi_params:,}')
print(f'  H² network ({h2_network.n_forms} forms): {h2_params:,}')
print(f'  H³ network ({h3_network.n_forms} forms): {h3_params:,}')
print(f'  Total: {total_params:,}')
print('='*60)

## Optimizer and Scheduler

AdamW optimizer with:
- Learning rate: 1e-4
- Warmup: 500 epochs
- Cosine annealing to 1e-7

In [None]:
# ============================================================
# OPTIMIZER AND SCHEDULER
# ============================================================

# Optimizer
# A single AdamW instance over the parameters of all models; `params` is
# kept as a module-level list because the training loop also passes it to
# gradient clipping.
params = [p for m in models.values() for p in m.parameters()]
optimizer = AdamW(
    params,
    lr=CONFIG['training']['lr'],
    weight_decay=CONFIG['training']['weight_decay']
)

# Learning rate scheduler with warmup
# Linear ramp from 0.1x to 1.0x of the base LR over the warmup epochs ...
warmup = LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=CONFIG['training']['warmup_epochs']
)

# ... followed by cosine annealing down to 1e-7 over the remaining epochs.
cosine = CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['training']['total_epochs'] - CONFIG['training']['warmup_epochs'],
    eta_min=1e-7
)

# SequentialLR switches from `warmup` to `cosine` at the milestone; the
# training loop steps it once per epoch.
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[CONFIG['training']['warmup_epochs']]
)

print('\nOptimizer and scheduler initialized')
print(f'  Optimizer: AdamW')
print(f'  Base LR: {CONFIG["training"]["lr"]:.0e}')
print(f'  Warmup epochs: {CONFIG["training"]["warmup_epochs"]}')
print(f'  Final LR: 1e-7')
print('='*60)

## Resume from Checkpoint

Automatically resume if checkpoint exists.

In [None]:
# ============================================================
# RESUME FROM CHECKPOINT
# ============================================================

# Epoch at which the training loop starts; stays 0 unless a checkpoint is restored
start_epoch = 0

if not CONFIG['checkpointing']['auto_resume']:
    print('\nAuto-resume disabled - starting fresh training')
else:
    checkpoint = checkpoint_manager.load_latest()
    if not checkpoint:
        print('\nNo checkpoint found - starting fresh training')
    else:
        # Restore every model, then the optimizer and (if stored) the scheduler
        for name, model in models.items():
            model.load_state_dict(checkpoint['models'][name])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if checkpoint.get('scheduler'):
            scheduler.load_state_dict(checkpoint['scheduler'])
        # Continue with the epoch after the one that was saved
        start_epoch = checkpoint['epoch'] + 1
        print(f'\nResumed from epoch {start_epoch}')
        print(f'Previous metrics: {checkpoint["metrics"]}')

print(f'Training range: {start_epoch} to {CONFIG["training"]["total_epochs"]} epochs')
print('='*60)

## Training Loop

Main training with:
- Proper exterior derivative computation via autodiff
- Five-phase curriculum progression
- Adaptive loss weight adjustment
- Automatic checkpointing every 500 epochs

In [None]:
# ============================================================
# MAIN TRAINING LOOP
# ============================================================

print('\n' + '='*60)
print('STARTING TRAINING')
print('='*60)

training_start = time.time()
# Per-epoch metric dicts for this session only; it starts empty even when
# training resumes from a checkpoint.
history = []

# One optimizer step per "epoch": each epoch draws a fresh batch of points.
for epoch in tqdm(range(start_epoch, CONFIG['training']['total_epochs']), desc='Training'):
    epoch_start = time.time()
    
    # Set models to training mode
    for model in models.values():
        model.train()
    
    # Get curriculum parameters
    # (phase_config is only used here for the name lookup)
    phase_name, phase_config = curriculum.get_current_phase(epoch)
    grid_n = curriculum.get_grid_resolution(epoch)
    loss_weights = curriculum.get_loss_weights(epoch)
    
    # Sample coordinates
    # requires_grad is needed so dφ can be obtained by differentiating φ
    # with respect to the coordinates.
    batch_size = CONFIG['training']['batch_size']
    coords = topology.sample_coordinates(batch_size, grid_n=grid_n)
    coords = coords.to(DEVICE)
    coords.requires_grad_(True)
    
    # Forward pass
    phi = phi_network.get_phi_tensor(coords)
    h2 = h2_network(coords)
    h3 = h3_network(coords)
    
    # Compute exterior derivative dφ
    dphi = compute_exterior_derivative(phi, coords)
    
    # Torsion losses
    # NOTE(review): coclosure is a constant-zero placeholder here, so it
    # contributes nothing to the loss or the gradients.
    torsion_closure = torch.mean(dphi ** 2)
    torsion_coclosure = torch.tensor(0.0, device=DEVICE)  # Simplified for speed
    
    # Gram matrix losses
    loss_gram_h2, det_h2, rank_h2 = gram_matrix_loss(h2, target_rank=21)
    loss_gram_h3, det_h3, rank_h3 = gram_matrix_loss(h3, target_rank=77)
    
    # Metric and volume
    # Penalise deviation of det(g) from 1 at every sampled point
    metric = reconstruct_metric_from_phi(phi)
    det_metric = torch.det(metric)
    volume_loss = torch.mean((det_metric - 1.0) ** 2)
    
    # Update adaptive scheduler
    # Floats (not tensors) are recorded so the history holds no graph.
    adaptive_scheduler.update(epoch, {
        'torsion_closure': torsion_closure.item(),
        'torsion_coclosure': torsion_coclosure.item()
    })
    adaptive_weights = adaptive_scheduler.get_weights()
    
    # Total loss
    # Torsion terms are scaled by curriculum weight x adaptive weight; the
    # other terms by the curriculum weight only.
    # NOTE(review): the curriculum's 'boundary' and 'calibration' weights
    # have no corresponding term here.
    total_loss = (
        loss_weights.get('torsion_closure', 1.0) * adaptive_weights['torsion_closure'] * torsion_closure +
        loss_weights.get('torsion_coclosure', 1.0) * adaptive_weights['torsion_coclosure'] * torsion_coclosure +
        loss_weights.get('volume', 0.1) * volume_loss +
        loss_weights.get('gram_h2', 1.0) * loss_gram_h2 +
        loss_weights.get('gram_h3', 1.0) * loss_gram_h3
    )
    
    # Backward pass
    # NOTE(review): CONFIG['training']['grad_accumulation'] is not applied;
    # gradients are zeroed and stepped every epoch.
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(params, CONFIG['training']['grad_clip'])
    optimizer.step()
    scheduler.step()
    
    # Track metrics
    # Plain Python numbers only, so the list can be dumped to JSON later.
    metrics = {
        'loss': total_loss.item(),
        'torsion_closure': torsion_closure.item(),
        'torsion_coclosure': torsion_coclosure.item(),
        'gram_h2': loss_gram_h2.item(),
        'gram_h3': loss_gram_h3.item(),
        'rank_h2': rank_h2,
        'rank_h3': rank_h3,
        'det_h2': det_h2.item(),
        'det_h3': det_h3.item()
    }
    history.append(metrics)
    
    # Logging
    if epoch % 100 == 0:
        current_lr = optimizer.param_groups[0]['lr']
        print(f'\nEpoch {epoch}/{CONFIG["training"]["total_epochs"]} [{phase_name}]')
        print(f'  Loss: {total_loss:.6f}')
        print(f'  Torsion closure: {torsion_closure:.6e}')
        print(f'  Torsion coclosure: {torsion_coclosure:.6e}')
        print(f'  Rank H²: {rank_h2}/21 | det: {det_h2:.6f}')
        print(f'  LR: {current_lr:.2e} | Grid: {grid_n}')
        print(f'  Time: {time.time() - epoch_start:.2f}s')
    
    # Checkpointing
    # (epoch + 1) so checkpoints land after epochs 499, 999, ... with the
    # default interval of 500.
    if (epoch + 1) % CONFIG['checkpointing']['interval'] == 0:
        checkpoint_manager.save(
            epoch=epoch,
            models=models,
            optimizer=optimizer,
            scheduler=scheduler,
            metrics=metrics
        )
        print(f'  Checkpoint saved at epoch {epoch}')

training_time = time.time() - training_start

# NOTE(review): if the loop body never ran (start_epoch >= total_epochs,
# e.g. after resuming a finished run), torsion_closure / rank_h2 / rank_h3
# are undefined and the prints below raise NameError.
print('\n' + '='*60)
print('TRAINING COMPLETED')
print('='*60)
print(f'Total time: {training_time/3600:.2f} hours')
print(f'Final torsion closure: {torsion_closure:.6e}')
print(f'Final rank H²: {rank_h2}/21')
print(f'Final rank H³: {rank_h3}/77')

## Save Final Checkpoint

Save the final trained model.

In [None]:
# ============================================================
# SAVE FINAL CHECKPOINT
# ============================================================

final_checkpoint = checkpoint_manager.save(
    epoch=CONFIG['training']['total_epochs'] - 1,
    models=models,
    optimizer=optimizer,
    scheduler=scheduler,
    metrics=metrics
)

print(f'\nFinal checkpoint saved: {final_checkpoint}')
print('\nIMPORTANT: Download checkpoints before Colab session ends!')
print('='*60)

## Training History

Save and visualize training history.

In [None]:
# ============================================================
# SAVE TRAINING HISTORY
# ============================================================

import matplotlib.pyplot as plt

# Save history to file
# `history` is a list of per-epoch metric dicts holding plain numbers.
history_file = RESULTS_DIR / 'training_history.json'
with open(history_file, 'w') as f:
    json.dump(history, f, indent=2)

print(f'Training history saved: {history_file}')

# Plot key metrics
# The x-axis is the index into `history`, i.e. epochs counted from the start
# of this session, not the absolute epoch number after a resume.
epochs = list(range(len(history)))
torsion_vals = [h['torsion_closure'] for h in history]
rank_h2_vals = [h['rank_h2'] for h in history]
rank_h3_vals = [h['rank_h3'] for h in history]

# 2x2 grid: torsion closure, rank H², rank H³, total loss
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

axes[0, 0].semilogy(epochs, torsion_vals)
axes[0, 0].set_title('Torsion Closure')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss (log scale)')
axes[0, 0].grid(True)

# Dashed red line marks the target rank
axes[0, 1].plot(epochs, rank_h2_vals)
axes[0, 1].axhline(y=21, color='r', linestyle='--', label='Target: 21')
axes[0, 1].set_title('Rank H² (b₂)')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Rank')
axes[0, 1].legend()
axes[0, 1].grid(True)

axes[1, 0].plot(epochs, rank_h3_vals)
axes[1, 0].axhline(y=77, color='r', linestyle='--', label='Target: 77')
axes[1, 0].set_title('Rank H³ (b₃)')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Rank')
axes[1, 0].legend()
axes[1, 0].grid(True)

loss_vals = [h['loss'] for h in history]
axes[1, 1].semilogy(epochs, loss_vals)
axes[1, 1].set_title('Total Loss')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Loss (log scale)')
axes[1, 1].grid(True)

# Save the figure to disk before showing it
plt.tight_layout()
plot_file = RESULTS_DIR / 'training_curves.png'
plt.savefig(plot_file, dpi=150)
print(f'Training curves saved: {plot_file}')
plt.show()

## Yukawa Computation

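Compute the Yukawa coupling tensor Y_αβγ [21×21×77] by Monte Carlo integration over sampled coordinates. This is a simplified estimate: the wedge product of two H² forms and one H³ form is approximated by the product of their pointwise norms.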

In [None]:
# ============================================================
# YUKAWA TENSOR COMPUTATION
# ============================================================

print('\n' + '='*60)
print('YUKAWA COUPLING TENSOR COMPUTATION')
print('='*60)

def compute_yukawa_simplified(h2_net, h3_net, n_samples=20000, batch_size=2048):
    """Estimate a simplified Yukawa tensor by Monte Carlo sampling.

    For every sampled point the entry ``(alpha, beta, gamma)`` receives the
    product of the pointwise norms of the ``alpha``-th and ``beta``-th outputs
    of ``h2_net`` and the ``gamma``-th output of ``h3_net``. This is a
    norm-product stand-in for the wedge product, not an antisymmetrised
    wedge; the result is therefore symmetric in its first two indices and
    non-negative.

    Args:
        h2_net: Callable mapping a coordinate batch to the H² forms.
            Assumed to return a tensor shaped (batch, >=21, components)
            -- TODO confirm against the network definition.
        h3_net: Callable mapping a coordinate batch to the H³ forms.
            Assumed to return a tensor shaped (batch, >=77, components)
            -- TODO confirm against the network definition.
        n_samples: Minimum number of Monte Carlo points. It is rounded up to
            a whole number of batches so that no requested sample is dropped
            and small values still produce at least one batch.
        batch_size: Number of points drawn per call to
            ``topology.sample_coordinates``.

    Returns:
        Tensor of shape (21, 21, 77) on ``DEVICE`` holding the sample mean
        of the norm products.

    Raises:
        ValueError: If ``n_samples`` or ``batch_size`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f'n_samples must be >= 1, got {n_samples}')
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    n_h2, n_h3 = 21, 77
    yukawa = torch.zeros(n_h2, n_h2, n_h3, device=DEVICE)

    # Ceiling division: floor division would yield zero batches (and a NaN
    # result after normalisation) whenever n_samples < batch_size.
    n_batches = -(-n_samples // batch_size)
    n_points = 0

    with torch.no_grad():
        for _ in tqdm(range(n_batches), desc='Yukawa integration'):
            coords = topology.sample_coordinates(batch_size, grid_n=10)
            coords = coords.to(DEVICE)

            # Norms depend only on the form index, so compute each once per
            # batch instead of once per (alpha, beta, gamma) triple.
            h2_norms = torch.norm(h2_net(coords)[:, :n_h2], dim=-1)
            h3_norms = torch.norm(h3_net(coords)[:, :n_h3], dim=-1)

            # Sum over the batch axis of |h2_a| * |h2_b| * |h3_g|.
            yukawa += torch.einsum(
                'na,nb,ng->abg', h2_norms, h2_norms, h3_norms
            )
            # Count what the sampler actually returned rather than assuming
            # it equals batch_size.
            n_points += h2_norms.shape[0]

    return yukawa / n_points

# Run the Monte Carlo estimate with the sample budget from CONFIG.
yukawa_tensor = compute_yukawa_simplified(
    h2_network,
    h3_network,
    n_samples=CONFIG['yukawa_computation']['n_mc_samples'],
)

# Report the tensor shape and the magnitude of the couplings.
print(
    '\nYukawa tensor computed',
    f'  Shape: {yukawa_tensor.shape}',
    f'  Mean coupling: {yukawa_tensor.abs().mean():.6e}',
    f'  Max coupling: {yukawa_tensor.abs().max():.6e}',
    sep='\n',
)

# Persist the tensor next to the other results.
yukawa_file = RESULTS_DIR / 'yukawa_tensor.pt'
torch.save(yukawa_tensor, yukawa_file)
print(f'  Saved to: {yukawa_file}')
print('=' * 60)

## Final Summary

Complete training summary and file locations.

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

# Evaluate every target exactly once so the JSON summary and the console
# report below cannot drift apart. bool() turns a NumPy/torch boolean (which
# `==` produces if a rank is an array scalar) into a builtin bool, the only
# boolean type json.dump can serialize.
targets_achieved = {
    'torsion_closure': bool(torsion_closure.item() < 1e-3),
    'rank_h2': bool(rank_h2 == 21),
    'rank_h3': bool(rank_h3 == 77),
}

# Machine-readable record of the run: configuration, final metrics,
# pass/fail flags and the locations of all artifacts written so far.
summary = {
    'version': CONFIG['version'],
    'training': {
        'total_epochs': CONFIG['training']['total_epochs'],
        'training_time_hours': training_time / 3600,
        'start_epoch': start_epoch,
    },
    'final_metrics': metrics,
    'targets_achieved': targets_achieved,
    'files': {
        'final_checkpoint': str(final_checkpoint),
        'history': str(history_file),
        'yukawa_tensor': str(yukawa_file),
        'plots': str(plot_file),
    },
}

summary_file = RESULTS_DIR / 'training_summary.json'
with open(summary_file, 'w') as f:
    json.dump(summary, f, indent=2)

# Human-readable report mirroring the JSON summary.
print('\n' + '='*60)
print('TRAINING SUMMARY')
print('='*60)
print(f'Version: {CONFIG["version"]}')
print(f'Total epochs: {CONFIG["training"]["total_epochs"]}')
print(f'Training time: {training_time/3600:.2f} hours')
print(f'\nFinal Metrics:')
print(f'  Torsion closure: {torsion_closure:.6e}' +
      f' [{"✓" if targets_achieved["torsion_closure"] else "✗"}]')
print(f'  Rank H²: {rank_h2}/21' +
      f' [{"✓" if targets_achieved["rank_h2"] else "✗"}]')
print(f'  Rank H³: {rank_h3}/77' +
      f' [{"✓" if targets_achieved["rank_h3"] else "✗"}]')
print(f'\nOutput Files:')
print(f'  Checkpoints: {CHECKPOINT_DIR}/')
print(f'  Results: {RESULTS_DIR}/')
print(f'  Summary: {summary_file}')
print('\n' + '='*60)
print('IMPORTANT: Download all files before Colab session ends!')
print('='*60)

## Download Files

Download trained models and results.

In [None]:
# ============================================================
# DOWNLOAD RESULTS
# ============================================================

# Browser downloads are opt-in: uncomment the lines below to fetch each
# artifact through the Colab files helper.

# from google.colab import files

# # Download final checkpoint
# files.download(str(final_checkpoint))

# # Download history and summary
# files.download(str(history_file))
# files.download(str(summary_file))

# # Download Yukawa tensor
# files.download(str(yukawa_file))

# # Download plots
# files.download(str(plot_file))

# Until the code above is enabled, this cell only prints instructions.
print(
    '\nTo download files, uncomment the code above and run this cell.',
    '\nAlternatively, use the Files panel on the left to download manually.',
    sep='\n',
)