# K₇ Metric Reconstruction v1.0 - Standalone

**100% self-contained** - no external files needed.

All outputs are written to `/content/K7_v1_0_training/` (Colab local storage, which is wiped when the session ends).

## Quick Start

1. Runtime → Change runtime type → GPU
2. Runtime → Run all
3. Download results before session ends

**Framework:** GIFT v2.0

In [None]:
# Install dependencies
# NOTE: the `!pip` lines are IPython/Jupyter shell escapes, so this cell only
# runs inside a notebook kernel (e.g. Colab), not as a plain Python script.
# NOTE(review): `sys` is not used in the cells shown here.
import sys
from pathlib import Path

print('Installing packages...')
!pip install -q torch torchvision torchaudio
!pip install -q tensorly
!pip install -q matplotlib seaborn
print('Installation complete')

# Setup directories (local storage only)
# Every artifact lives under WORK_DIR on the runtime's local disk; nothing is
# copied elsewhere, so results must be downloaded before the session ends.
WORK_DIR = Path('/content/K7_v1_0_training')
WORK_DIR.mkdir(parents=True, exist_ok=True)

# Destination for training checkpoints (handed to CheckpointManager later on).
CHECKPOINT_DIR = WORK_DIR / 'checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Destination for result files.
RESULTS_DIR = WORK_DIR / 'results'
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print(f'Working directory: {WORK_DIR}')
print('NOTE: All data in /content/ - download before session ends!')

In [None]:
import json
import time
import warnings
from typing import Dict, List, Tuple, Optional, Any
from itertools import permutations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from tqdm.auto import tqdm

# Silence every Python warning for the rest of the session (this includes
# any warnings PyTorch would emit).
warnings.filterwarnings('ignore')

# Choose the compute device once; downstream code moves tensors to DEVICE.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')
if not torch.cuda.is_available():
    print('WARNING: No GPU - training will be very slow!')
else:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'Memory: {total_gb:.1f} GB')

## Configuration

In [None]:
# Single source of truth for the run; also written to WORK_DIR/config.json below.
# NOTE(review): CurriculumScheduler reads config['training']['curriculum'] and
# training_loop reads config['validation']['interval'], but neither key is
# defined here. Confirm whether they are meant to be added.
CONFIG = {
    'version': 'v1.0_standalone',
    'seed': 42,
    # Framework constants. tau and xi are not read by the code visible here;
    # numerically xi == 5*pi/16 and tau ≈ 3472/891.
    'gift_parameters': {
        'tau': 3.8967452300785634,
        'xi': 0.9817477042468103,
        'epsilon0': 0.125,  # stored by K7Topology as self.epsilon
        'b2': 21,  # same count as the 21 H² forms / Gram target rank 21
        'b3': 77,  # same count as the 77 H³ forms / Gram target rank 77
    },
    # n_fourier = number of random frequencies in FourierFeatures
    # (which outputs 2 * n_fourier features: sin and cos).
    'architecture': {
        'phi_network': {'hidden_dims': [384, 384, 256], 'n_fourier': 32},
        'harmonic_h2_network': {'hidden_dim': 128, 'n_fourier': 24, 'n_forms': 21},
        'harmonic_h3_network': {'hidden_dim': 128, 'n_fourier': 24, 'n_forms': 77}
    },
    'training': {
        'total_epochs': 15000,
        'batch_size': 2048,  # coordinate samples drawn per train_epoch call
        'grad_accumulation': 4,  # micro-batches per optimizer step (GradientAccumulator)
        'lr': 1e-4,
        'weight_decay': 1e-4,
        'grad_clip': 1.0,  # max global gradient norm passed to clip_grad_norm_
        'warmup_epochs': 500,  # LinearLR warmup length before cosine annealing
    },
    'checkpointing': {
        'interval': 500,  # training_loop saves every `interval` epochs
        'keep_best': 5,  # CheckpointManager keeps this many lowest-torsion files
        'auto_resume': True  # not read by the code visible here
    },
}

# Set seeds (numpy, torch CPU and, when present, all CUDA devices).
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['seed'])

# Save config next to the outputs so each run records its settings.
with open(WORK_DIR / 'config.json', 'w') as f:
    json.dump(CONFIG, f, indent=2)

print('Configuration initialized')
print(f'Total epochs: {CONFIG["training"]["total_epochs"]}')

## Complete Implementation

All modules are defined inline (~1450 lines):
- Neural network architectures and K₇ topology sampling
- Checkpoint management
- Loss functions
- Training loop
- Validation
- Yukawa computation

In [None]:
# ============================================================
# COMPLETE K7 v1.0 IMPLEMENTATION - ALL MODULES INLINE
# ============================================================

# ============================================================
# NEURAL NETWORK ARCHITECTURES
# ============================================================

class FourierFeatures(nn.Module):
    """Random Fourier feature embedding.

    A fixed Gaussian matrix ``B`` of shape ``(input_dim, n_frequencies)``
    (scaled by ``scale``) projects the input; the output concatenates
    ``sin(2*pi*x@B)`` and ``cos(2*pi*x@B)`` along the last axis, giving
    ``2 * n_frequencies`` features. ``B`` is registered as a buffer, so it is
    saved in the state dict but never trained.
    """

    def __init__(self, input_dim, n_frequencies, scale=1.0):
        super().__init__()
        frequencies = torch.randn(input_dim, n_frequencies) * scale
        self.register_buffer('B', frequencies)

    def forward(self, x):
        angles = (2 * np.pi * x) @ self.B
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class ModularPhiNetwork(nn.Module):
    """MLP predicting the 35 independent components of a 3-form φ.

    Coordinates ``x`` of shape ``(batch, 7)`` are embedded with
    ``FourierFeatures`` and passed through SiLU-activated linear layers of
    widths ``hidden_dims``. The final layer emits one value per index triple
    ``i < j < k`` (C(7, 3) = 35), in lexicographic order.
    """

    def __init__(self, hidden_dims, n_fourier):
        super().__init__()
        self.fourier = FourierFeatures(7, n_fourier, scale=1.0)

        layers = []
        # FourierFeatures concatenates sin and cos of an (input_dim x n_fourier)
        # projection, so its output width is 2 * n_fourier regardless of the
        # input dimension. Multiplying by B.shape[0] as well made the first
        # Linear layer 7x too wide, and the first forward pass crashed.
        in_dim = 2 * self.fourier.B.shape[1]
        for h_dim in hidden_dims:
            layers.extend([nn.Linear(in_dim, h_dim), nn.SiLU()])
            in_dim = h_dim

        layers.append(nn.Linear(in_dim, 35))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the flat component vector, shape ``(batch, 35)``."""
        return self.network(self.fourier(x))

    def get_phi_tensor(self, x):
        """Expand the 35 components into an antisymmetric ``(batch, 7, 7, 7)`` tensor.

        Component ``idx`` is the ``idx``-th triple ``(i, j, k)`` with
        ``i < j < k`` in lexicographic order. It is written to all six
        permutations with the permutation's sign. Entries with a repeated
        index remain zero.
        """
        from itertools import combinations

        phi_flat = self.forward(x)
        # new_zeros keeps the dtype/device of the network output.
        phi = phi_flat.new_zeros(x.shape[0], 7, 7, 7)

        for idx, (i, j, k) in enumerate(combinations(range(7), 3)):
            val = phi_flat[:, idx]
            # Even permutations of (i, j, k) get +val ...
            phi[:, i, j, k] = val
            phi[:, j, k, i] = val
            phi[:, k, i, j] = val
            # ... odd permutations get -val.
            phi[:, i, k, j] = -val
            phi[:, j, i, k] = -val
            phi[:, k, j, i] = -val

        return phi


class HarmonicFormsNetwork(nn.Module):
    """Bank of independent MLPs, one per candidate harmonic p-form.

    Each sub-network maps coordinates ``(batch, 7)`` to the C(7, p)
    components of one p-form. ``forward`` stacks the outputs into
    ``(batch, n_forms, n_components)``.
    """

    def __init__(self, p, n_forms, hidden_dim, n_fourier):
        super().__init__()
        import math

        self.p = p
        self.n_forms = n_forms
        # Independent components of a p-form in 7 dimensions: C(7, p).
        # This gives 21 for p=2 and 35 for p=3, as before, and is also
        # correct for other degrees.
        self.n_components = math.comb(7, p)

        self.networks = nn.ModuleList()
        for i in range(n_forms):
            # Hidden widths cycle through hidden_dim + {0, 8, 16, 24, 32}.
            hidden_var = hidden_dim + (i % 5) * 8
            fourier = FourierFeatures(7, n_fourier, scale=1.0)
            # FourierFeatures outputs sin and cos of an (input_dim x n_fourier)
            # projection: 2 * n_fourier features. The previous 7 * n_fourier * 2
            # made the first Linear layer 7x too wide, and forward crashed.
            fourier_dim = 2 * n_fourier
            net = nn.Sequential(
                nn.Linear(fourier_dim, hidden_var),
                nn.SiLU(),
                nn.Linear(hidden_var, hidden_var),
                nn.SiLU(),
                nn.Linear(hidden_var, self.n_components),
            )
            self.networks.append(nn.Sequential(fourier, net))

    def forward(self, x):
        """Evaluate every form network on ``x`` and stack along dim 1."""
        if len(self.networks) == 0:
            # torch.stack rejects an empty list; keep the (batch, 0, C) shape.
            return x.new_zeros(x.shape[0], 0, self.n_components)
        return torch.stack([network(x) for network in self.networks], dim=1)


class K7Topology:
    """Coordinate sampling, region weights and calibration cycles for the K₇ model.

    Points are 7-dimensional with components in ``[0, 2π]``. ``x[:, 0]`` is the
    ``t`` coordinate used to split the M₁ / neck / M₂ regions.
    """

    def __init__(self, gift_params):
        self.params = gift_params
        self.epsilon = gift_params['epsilon0']

    @staticmethod
    def _sample_distinct_indices(total, n):
        """Return ``n`` distinct integers drawn uniformly from ``[0, total)``.

        Avoids materialising ``randperm(total)`` when ``total`` is huge
        (a 7-D lattice has ``grid_n ** 7`` points).
        """
        if n <= 0:
            return torch.empty(0, dtype=torch.long)
        if total <= max(2 * n, 1_000_000):
            return torch.randperm(total)[:n]
        chosen = torch.empty(0, dtype=torch.long)
        while chosen.numel() < n:
            draws = torch.randint(total, (2 * (n - chosen.numel()),))
            chosen = torch.unique(torch.cat([chosen, draws]))
        # torch.unique returns sorted values; shuffle before truncating so the
        # kept subset is not biased toward small indices.
        return chosen[torch.randperm(chosen.numel())[:n]]

    def sample_coordinates(self, n_samples, grid_n=10):
        """Sample ``n_samples`` points: part from a regular lattice, the rest uniform.

        ``min(n_samples // 2, grid_n ** 7)`` distinct points come from the lattice
        built from ``linspace(0, 2π, grid_n)`` in each coordinate. The remainder
        are uniform in ``[0, 2π)^7``.

        Returns:
            CPU tensor of shape ``(n_samples, 7)``, lattice points first.
        """
        coords_1d = torch.linspace(0, 2 * np.pi, grid_n)
        n_total = grid_n ** 7
        n_grid = min(n_samples // 2, n_total)

        if n_grid > 0:
            # Pick lattice points by flat index and decode the base-grid_n
            # digits. This avoids building the full grid_n**7 x 7 lattice on
            # every call (~280 MB at grid_n=10, growing as grid_n**7).
            flat = self._sample_distinct_indices(n_total, n_grid)
            digits = []
            for _ in range(7):
                digits.append(flat % grid_n)
                flat = flat // grid_n
            # Most-significant digit first, matching meshgrid(indexing='ij') order.
            samples_grid = coords_1d[torch.stack(digits[::-1], dim=-1)]
        else:
            samples_grid = torch.empty(0, 7)

        n_random = n_samples - n_grid
        samples_random = torch.rand(n_random, 7) * 2 * np.pi

        return torch.cat([samples_grid, samples_random], dim=0)

    def get_region_weights(self, x, neck_half_width=0.0):
        """Soft assignment of each point to M₁ / neck / M₂ using ``t = x[:, 0]``.

        The M₁ weight switches off around ``t = π - neck_half_width`` and the M₂
        weight switches on around ``t = π + neck_half_width`` (sigmoid width 0.3).
        The neck gets the remainder ``1 - w_m1 - w_m2``.

        NOTE(review): with the default ``neck_half_width=0`` both sigmoids are
        centred at π and are exact complements (σ(a) + σ(-a) = 1). So ``w_neck``
        is zero up to rounding everywhere, and terms weighted by it (such as
        boundary_smoothness_loss) vanish. The default keeps existing behaviour;
        pass a positive half-width to get a genuine neck region.

        Returns:
            Dict with per-point weight tensors under ``'m1'``, ``'neck'``, ``'m2'``.
        """
        t = x[:, 0]
        w_m1 = torch.sigmoid((np.pi - neck_half_width - t) / 0.3)
        w_m2 = torch.sigmoid((t - np.pi - neck_half_width) / 0.3)
        w_neck = 1.0 - w_m1 - w_m2
        return {'m1': w_m1, 'neck': w_neck, 'm2': w_m2}

    def define_associative_cycles(self, n_cycles=6):
        """Return up to ``n_cycles`` associative 3-cycle descriptors (6 are defined).

        Two fixed ``t`` values per region. Every cycle uses coordinate indices
        ``[1, 2, 3]``.
        """
        cycles = []
        for region, t_vals in [('M1', [np.pi/4, np.pi/3]),
                                ('neck', [np.pi, 5*np.pi/4]),
                                ('M2', [3*np.pi/2, 7*np.pi/4])]:
            for t in t_vals:
                cycles.append({
                    'region': region,
                    't_fixed': t,
                    'type': 'T3',
                    'indices': [1, 2, 3],
                })
        return cycles[:n_cycles]

    def define_coassociative_cycles(self, n_cycles=6):
        """Return up to ``n_cycles`` coassociative 4-cycle descriptors.

        Only 5 are defined (one ``t`` value for M1, two for neck and M2), so at
        most 5 are returned. Every cycle uses coordinate indices ``[0, 4, 5, 6]``.
        """
        cycles = []
        for region, t_vals in [('M1', [np.pi/4]),
                                ('neck', [np.pi, 5*np.pi/4]),
                                ('M2', [3*np.pi/2, 7*np.pi/4])]:
            for t in t_vals:
                cycles.append({
                    'region': region,
                    't_fixed': t,
                    'type': 'T4',
                    'indices': [0, 4, 5, 6],
                })
        return cycles[:n_cycles]

    def sample_on_cycle(self, cycle, n_samples=512):
        """Uniform points in ``[0, 2π)^7`` with coordinate 0 pinned to ``cycle['t_fixed']``."""
        samples = torch.rand(n_samples, 7) * 2 * np.pi
        samples[:, 0] = cycle['t_fixed']
        return samples



# ============================================================
# CHECKPOINT MANAGEMENT
# ============================================================

class CheckpointManager:
    """Atomic checkpoint writer that keeps the ``keep_best`` lowest-torsion checkpoints.

    Checkpoints are ranked by ``torsion_closure + torsion_coclosure`` taken from
    the metrics passed to :meth:`save` (each missing entry counts as 1.0).
    Only checkpoints written by this instance are tracked for pruning.
    """

    def __init__(self, save_dir, keep_best=5):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.keep_best = keep_best
        # (epoch, torsion score, path), kept sorted best-first after each save.
        self.checkpoints = []

    @staticmethod
    def _epoch_of(path):
        """Epoch number encoded in ``checkpoint_epoch_<N>.pt``, or -1 if unparsable."""
        try:
            return int(path.stem.rsplit('_', 1)[-1])
        except ValueError:
            return -1

    def save(self, epoch, models, optimizer, scheduler, metrics):
        """Write a checkpoint atomically, then prune to ``keep_best`` tracked files.

        Args:
            epoch: epoch number used in the file name.
            models: mapping name -> module; each ``state_dict()`` is stored.
            optimizer: optimizer whose ``state_dict()`` is stored.
            scheduler: LR scheduler or None.
            metrics: dict stored verbatim and used for torsion ranking.

        Returns:
            Path of the written checkpoint.
        """
        path = self.save_dir / f'checkpoint_epoch_{epoch}.pt'
        temp = self.save_dir / f'checkpoint_epoch_{epoch}.pt.tmp'
        torch.save({
            'epoch': epoch,
            'models': {n: m.state_dict() for n, m in models.items()},
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict() if scheduler else None,
            'metrics': metrics,
            'timestamp': time.time()
        }, temp)
        # Path.replace overwrites an existing target on every OS; Path.rename
        # raises on Windows if the target exists.
        temp.replace(path)

        torsion = metrics.get('torsion_closure', 1.0) + metrics.get('torsion_coclosure', 1.0)
        # Re-saving an epoch overwrites its file. Drop the stale entry so pruning
        # never deletes a file that is still tracked under its new score.
        self.checkpoints = [c for c in self.checkpoints if c[2] != path]
        self.checkpoints.append((epoch, torsion, path))
        self.checkpoints.sort(key=lambda c: c[1])

        if len(self.checkpoints) > self.keep_best:
            _, _, worst = self.checkpoints.pop()
            # The file just written is never deleted (it is the newest state);
            # if it ranks worst it simply stops being tracked.
            if worst.exists() and worst != path:
                worst.unlink()
        return path

    def load_latest(self, map_location=None):
        """Load the highest-epoch readable checkpoint, falling back to older ones.

        Args:
            map_location: forwarded to ``torch.load``; ``None`` means the global ``DEVICE``.

        Returns:
            The checkpoint dict, or ``None`` if no checkpoint could be read.
        """
        if map_location is None:
            map_location = DEVICE
        # Sort numerically by epoch: a plain string sort ranks
        # 'checkpoint_epoch_999.pt' above 'checkpoint_epoch_1000.pt'.
        ckpts = sorted(self.save_dir.glob('checkpoint_*.pt'), key=self._epoch_of, reverse=True)
        for ckpt in ckpts:
            try:
                print(f'Loading: {ckpt.name}')
                return torch.load(ckpt, map_location=map_location)
            except Exception as e:
                # Best effort: a truncated or corrupt file must not block
                # resuming from an older checkpoint.
                print(f'Failed: {e}')
        return None

# Shared manager for checkpoint I/O: writes into CHECKPOINT_DIR and keeps the
# CONFIG['checkpointing']['keep_best'] lowest-torsion checkpoints it has saved.
checkpoint_manager = CheckpointManager(CHECKPOINT_DIR, CONFIG['checkpointing']['keep_best'])
print('Checkpoint manager initialized')


# ============================================================
# LOSSES MODULE
# ============================================================

def torsion_closure_loss(dphi: torch.Tensor) -> torch.Tensor:
    """Penalty for the closure condition dφ = 0.

    Args:
        dphi: derivative tensor of φ (training passes ``[batch, 7, 7, 7, 7]``);
            any shape is accepted.

    Returns:
        Mean of the squared entries, as a 0-dim tensor.
    """
    squared = dphi.pow(2)
    return squared.mean()


def torsion_coclosure_loss(dstar_phi: torch.Tensor) -> torch.Tensor:
    """Penalty for the coclosure condition d*φ = 0.

    Args:
        dstar_phi: co-derivative tensor (training passes ``[batch, 7, 7]``);
            any shape is accepted.

    Returns:
        Mean of the squared entries, as a 0-dim tensor.
    """
    squared = dstar_phi.pow(2)
    return squared.mean()


def volume_loss(metric: torch.Tensor, target_det: float = 1.0) -> torch.Tensor:
    """Volume penalty: mean squared deviation of det(g) from ``target_det``.

    Args:
        metric: batch of square matrices, e.g. ``[batch, 7, 7]``.
        target_det: desired determinant.

    Returns:
        0-dim tensor ``mean((det(metric) - target_det) ** 2)``.
    """
    deviation = torch.det(metric) - target_det
    return deviation.square().mean()


def gram_matrix_loss(harmonic_forms: torch.Tensor, target_rank: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Gram-matrix orthonormalization loss for a basis of harmonic forms.

    The Gram matrix is ``G_ij = mean_b sum_c h[b, i, c] * h[b, j, c]``. The loss is
    ``mean((G - I)**2) + 0.1 * (det(G + 1e-6 I) - 1)**2``.

    Args:
        harmonic_forms: [batch, n_forms, n_components] harmonic basis samples
        target_rank: Expected rank (21 for b₂, 77 for b₃). Informational only:
            the rank is reported but not penalised.

    Returns:
        loss: Total Gram loss (0-dim tensor, differentiable)
        det_gram: det(G + 1e-6 I) (0-dim tensor, differentiable)
        rank: Number of eigenvalues of G above 1e-4 (Python int)
    """
    batch_size, n_forms, n_components = harmonic_forms.shape

    # One contraction instead of n_forms**2 Python-level reductions
    # (5929 autograd nodes per call for the 77 H³ forms). The result also
    # keeps the input dtype instead of being cast into a float32 buffer.
    gram = torch.einsum('bic,bjc->ij', harmonic_forms, harmonic_forms) / batch_size

    identity = torch.eye(n_forms, device=gram.device, dtype=gram.dtype)

    loss_orthonormality = torch.mean((gram - identity) ** 2)

    det_gram = torch.det(gram + 1e-6 * identity)
    loss_determinant = (det_gram - 1.0) ** 2

    # Rank is only a diagnostic, so detach to keep eigvalsh out of the autograd graph.
    eigenvalues = torch.linalg.eigvalsh(gram.detach())
    rank = (eigenvalues > 1e-4).sum().item()

    loss = loss_orthonormality + 0.1 * loss_determinant

    return loss, det_gram, rank


def boundary_smoothness_loss(phi: torch.Tensor, region_weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Boundary smoothness between M₁, Neck, and M₂ regions.

    Penalises ``phi**2`` where two regions overlap:
    ``mean(w_m1 * w_neck * phi**2) + mean(w_neck * w_m2 * phi**2)``, with the
    weights applied sample-by-sample.

    Args:
        phi: [batch, ...] field (training passes the [batch, 7, 7, 7] 3-form)
        region_weights: dict with per-sample weights of shape [batch] under
            'm1', 'neck' and 'm2'

    Returns:
        Scalar loss value
    """
    def _per_sample(w: torch.Tensor) -> torch.Tensor:
        # [batch] -> [batch, 1, ..., 1], so each weight multiplies only its own
        # sample. An extra unsqueeze here used to broadcast to
        # [batch, batch, 7, 7, 7], pairing every weight with every sample's phi
        # and costing O(batch²) memory.
        return w.reshape(-1, *([1] * (phi.dim() - 1)))

    w_m1 = _per_sample(region_weights['m1'])
    w_neck = _per_sample(region_weights['neck'])
    w_m2 = _per_sample(region_weights['m2'])

    phi_sq = phi ** 2
    transition_m1_neck = torch.mean(w_m1 * w_neck * phi_sq)
    transition_neck_m2 = torch.mean(w_neck * w_m2 * phi_sq)

    return transition_m1_neck + transition_neck_m2


def calibration_associative_loss(
    phi: torch.Tensor,
    cycles: List[Dict],
    topology,
    n_samples: int = 512
) -> torch.Tensor:
    """
    Calibration constraint for associative 3-cycles: φ|_Σ ≈ vol_Σ = 1.

    For each cycle with index triple ``(i, j, k)``, the absolute batch mean of
    ``phi[:, i, j, k]`` is compared with a unit volume. The squared errors are
    averaged over cycles.

    NOTE(review): ``phi`` is the field on the training batch, not evaluated on
    points of each cycle. Points used to be drawn from
    ``topology.sample_on_cycle`` here but were never used; evaluating φ on them
    would need the network. ``topology`` and ``n_samples`` are therefore unused
    and are kept only for interface compatibility.

    Args:
        phi: [batch, 7, 7, 7] 3-form
        cycles: list of cycle dicts, each with a 3-element ``'indices'``
        topology: unused (interface compatibility)
        n_samples: unused (interface compatibility)

    Returns:
        0-dim tensor; 0 when ``cycles`` is empty.

    Raises:
        ValueError: if a cycle's ``'indices'`` does not have exactly 3 entries.
    """
    if not cycles:
        return phi.new_zeros(())

    volume_sigma = 1.0
    total_loss = phi.new_zeros(())
    for cycle in cycles:
        indices = cycle['indices']
        if len(indices) != 3:
            # Previously this left a per-sample zeros vector in place, which
            # turned the "scalar" loss into a vector.
            raise ValueError(f'associative cycle needs 3 indices, got {indices!r}')
        i, j, k = indices
        phi_on_cycle = torch.abs(phi[:, i, j, k].mean())
        total_loss = total_loss + (phi_on_cycle - volume_sigma) ** 2

    return total_loss / len(cycles)


def calibration_coassociative_loss(
    star_phi: torch.Tensor,
    cycles: List[Dict],
    topology,
    n_samples: int = 512
) -> torch.Tensor:
    """
    Calibration constraint for coassociative 4-cycles: *φ|_Ω ≈ vol_Ω = 1.

    For each cycle with index quadruple ``(i, j, k, l)``, the absolute batch
    mean of ``star_phi[:, i, j, k, l]`` is compared with a unit volume. The
    squared errors are averaged over cycles.

    NOTE(review): ``star_phi`` is the batch field, not evaluated on cycle
    points, and train_epoch currently passes an all-zero placeholder (so this
    term is constant 1 there). ``topology`` and ``n_samples`` are unused and are
    kept only for interface compatibility.

    Args:
        star_phi: [batch, 7, 7, 7, 7] Hodge dual 4-form
        cycles: list of cycle dicts, each with a 4-element ``'indices'``
        topology: unused (interface compatibility)
        n_samples: unused (interface compatibility)

    Returns:
        0-dim tensor; 0 when ``cycles`` is empty.

    Raises:
        ValueError: if a cycle's ``'indices'`` does not have exactly 4 entries.
    """
    if not cycles:
        return star_phi.new_zeros(())

    volume_omega = 1.0
    total_loss = star_phi.new_zeros(())
    for cycle in cycles:
        indices = cycle['indices']
        if len(indices) != 4:
            # Previously this left a per-sample zeros vector in place, which
            # turned the "scalar" loss into a vector.
            raise ValueError(f'coassociative cycle needs 4 indices, got {indices!r}')
        i, j, k, l = indices
        star_phi_on_cycle = torch.abs(star_phi[:, i, j, k, l].mean())
        total_loss = total_loss + (star_phi_on_cycle - volume_omega) ** 2

    return total_loss / len(cycles)


class AdaptiveLossScheduler:
    """
    Boosts torsion loss weights when their recent history stops moving.

    Each ``update`` call records the current torsion values. On epochs that are
    multiples of ``check_interval`` and greater than 500, any torsion term
    whose last 100 recorded values have variance below ``plateau_threshold``
    has its weight multiplied by 1.5.

    NOTE(review): boosts compound without an upper bound, and the threshold is
    absolute. Small-magnitude losses that are simply converging therefore also
    count as a plateau, and weights can grow by 1.5x at every check. Confirm
    this is intended.
    """
    def __init__(self, check_interval: int = 100, plateau_threshold: float = 1e-4):
        self.check_interval = check_interval
        self.plateau_threshold = plateau_threshold
        self.history = {name: [] for name in ('torsion_closure', 'torsion_coclosure')}
        self.weights = {name: 1.0 for name in ('torsion_closure', 'torsion_coclosure')}

    def update(self, epoch: int, losses: Dict[str, float]):
        """
        Record the torsion values from ``losses`` and boost plateaued weights.
        """
        tracked = ('torsion_closure', 'torsion_coclosure')
        for name in tracked:
            if name in losses:
                self.history[name].append(losses[name])

        if not (epoch % self.check_interval == 0 and epoch > 500):
            return

        for name in tracked:
            series = self.history[name]
            if len(series) < 100:
                continue
            variance = torch.tensor(series[-100:]).var().item()
            if variance < self.plateau_threshold:
                self.weights[name] *= 1.5
                print(f"Epoch {epoch}: Boosting {name} weight to {self.weights[name]:.3f}")

    def get_weights(self) -> Dict[str, float]:
        """Return the live weight dict (not a copy)."""
        return self.weights


class CompositeLoss(nn.Module):
    """
    Weighted sum of every geometric constraint used during training.

    Combines torsion closure/coclosure, volume, Gram-matrix (H², H³),
    boundary-smoothness and optional calibration terms. The two torsion terms
    are further scaled by an AdaptiveLossScheduler, which is updated on every
    forward call.
    """
    def __init__(self, topology, assoc_cycles, coassoc_cycles):
        super().__init__()
        self.topology = topology
        self.assoc_cycles = assoc_cycles
        self.coassoc_cycles = coassoc_cycles
        self.adaptive_scheduler = AdaptiveLossScheduler()

    def forward(
        self,
        phi: torch.Tensor,
        dphi: torch.Tensor,
        dstar_phi: torch.Tensor,
        star_phi: torch.Tensor,
        metric: torch.Tensor,
        harmonic_h2: torch.Tensor,
        harmonic_h3: torch.Tensor,
        region_weights: Dict[str, torch.Tensor],
        loss_weights: Dict[str, float],
        epoch: int = 0
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Return ``(total_loss, components)``.

        Missing ``loss_weights`` entries default to 1.0 for the torsion, Gram
        and boundary terms, 0.1 for volume and 0.0 for calibration. Calibration
        losses are only evaluated when their weight is positive. ``components``
        maps each term to a plain Python number and also carries the
        ``det_gram_*`` / ``rank_*`` diagnostics.
        """
        closure = torsion_closure_loss(dphi)
        coclosure = torsion_coclosure_loss(dstar_phi)
        components = {
            'torsion_closure': closure,
            'torsion_coclosure': coclosure,
            'volume': volume_loss(metric),
        }

        gram_specs = (
            ('gram_h2', 'det_gram_h2', 'rank_h2', harmonic_h2, 21),
            ('gram_h3', 'det_gram_h3', 'rank_h3', harmonic_h3, 77),
        )
        for loss_key, det_key, rank_key, forms, expected_rank in gram_specs:
            gram_loss, gram_det, gram_rank = gram_matrix_loss(forms, target_rank=expected_rank)
            components[loss_key] = gram_loss
            components[det_key] = gram_det.item()
            components[rank_key] = gram_rank

        components['boundary'] = boundary_smoothness_loss(phi, region_weights)

        if loss_weights.get('calibration', 0.0) > 0:
            assoc = calibration_associative_loss(phi, self.assoc_cycles, self.topology)
            coassoc = calibration_coassociative_loss(star_phi, self.coassoc_cycles, self.topology)
            components['calibration_assoc'] = assoc
            components['calibration_coassoc'] = coassoc
            components['calibration'] = (assoc + coassoc) / 2.0
        else:
            components['calibration'] = torch.tensor(0.0, device=phi.device)

        self.adaptive_scheduler.update(epoch, {
            'torsion_closure': closure.item(),
            'torsion_coclosure': coclosure.item()
        })
        boost = self.adaptive_scheduler.get_weights()
        weight = loss_weights.get

        total_loss = (
            weight('torsion_closure', 1.0) * boost['torsion_closure'] * closure
            + weight('torsion_coclosure', 1.0) * boost['torsion_coclosure'] * coclosure
            + weight('volume', 0.1) * components['volume']
            + weight('gram_h2', 1.0) * components['gram_h2']
            + weight('gram_h3', 1.0) * components['gram_h3']
            + weight('boundary', 1.0) * components['boundary']
            + weight('calibration', 0.0) * components['calibration']
        )

        as_numbers = {}
        for name, value in components.items():
            as_numbers[name] = value.item() if isinstance(value, torch.Tensor) else value

        return total_loss, as_numbers


# ============================================================
# TRAINING MODULE
# ============================================================

class CurriculumScheduler:
    """
    Five-phase curriculum scheduler for progressive training.

    Reads ``config['training']['curriculum']``: a dict keyed by phase name whose
    values hold ``'range'`` (``[start, end)`` epochs) and, optionally,
    ``'grid_n'``, ``'loss_weights'`` and ``'region_weights'``. Phases are
    scanned in the fixed order of ``self.phases``; epochs outside every range
    fall back to the last phase.

    A missing curriculum, or missing phase entries, is tolerated: absent phases
    are skipped and resolve to an empty config, so the getters return their
    defaults.
    """
    def __init__(self, config: Dict):
        self.config = config
        # The CONFIG in this notebook defines no 'curriculum' section; default
        # to an empty one rather than raising KeyError at construction.
        self.curriculum = config['training'].get('curriculum', {})
        self.phases = [
            'phase1_neck_stability',
            'phase2_acyl_matching',
            'phase3_cohomology_refinement',
            'phase4_harmonic_extraction',
            'phase5_calibration_finetune'
        ]

    def get_current_phase(self, epoch: int) -> Tuple[str, Dict]:
        """
        Determine the current training phase for ``epoch``.

        Returns:
            phase_name: Name of the matching phase (last phase if none matches)
            phase_config: That phase's config dict ({} if it is not defined)
        """
        for phase_name in self.phases:
            phase_config = self.curriculum.get(phase_name)
            if phase_config is None:
                continue
            epoch_range = phase_config['range']
            if epoch_range[0] <= epoch < epoch_range[1]:
                return phase_name, phase_config

        last_phase = self.phases[-1]
        return last_phase, self.curriculum.get(last_phase, {})

    def get_grid_resolution(self, epoch: int) -> int:
        """
        Grid resolution for the current epoch (default 10).
        """
        _, phase_config = self.get_current_phase(epoch)
        return phase_config.get('grid_n', 10)

    def get_loss_weights(self, epoch: int) -> Dict[str, float]:
        """
        Loss component weights for the current epoch (default {}).
        """
        _, phase_config = self.get_current_phase(epoch)
        return phase_config.get('loss_weights', {})

    def get_region_weights(self, epoch: int) -> Dict[str, float]:
        """
        Region emphasis weights for the current epoch (default near-uniform).
        """
        _, phase_config = self.get_current_phase(epoch)
        return phase_config.get('region_weights', {'m1': 0.33, 'neck': 0.34, 'm2': 0.33})


def create_optimizer(models: Dict[str, nn.Module], config: Dict) -> AdamW:
    """
    Build one AdamW optimizer over the parameters of every model.

    Parameters are collected in the iteration order of ``models``.
    ``lr`` and ``weight_decay`` come from ``config['training']``; betas and eps
    are fixed at (0.9, 0.999) and 1e-8.

    Returns:
        The configured AdamW instance.
    """
    training_cfg = config['training']
    all_params = [param for module in models.values() for param in module.parameters()]
    return AdamW(
        all_params,
        lr=training_cfg['lr'],
        weight_decay=training_cfg['weight_decay'],
        betas=(0.9, 0.999),
        eps=1e-8,
    )


def create_scheduler(optimizer, config: Dict, start_epoch: int = 0):
    """
    Create a learning-rate scheduler: linear warmup, then cosine annealing.

    Warmup ramps from 0.1x to 1x the base LR over ``warmup_epochs``; cosine
    annealing then decays toward 1e-7 over the remaining epochs. With
    ``warmup_epochs <= 0`` a plain CosineAnnealingLR is returned.

    Args:
        optimizer: PyTorch optimizer
        config: Training configuration (reads ``training.warmup_epochs`` and
            ``training.total_epochs``)
        start_epoch: Number of scheduler steps to fast-forward (for resume)

    Returns:
        scheduler: Learning rate scheduler
    """
    warmup_epochs = config['training']['warmup_epochs']
    total_epochs = config['training']['total_epochs']
    # CosineAnnealingLR divides by T_max. Keep it >= 1 even when warmup covers
    # the whole run; otherwise the first post-warmup step raises ZeroDivisionError.
    cosine_epochs = max(1, total_epochs - warmup_epochs)

    if warmup_epochs > 0:
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=warmup_epochs
        )
        main_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=cosine_epochs,
            eta_min=1e-7
        )
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[warmup_epochs]
        )
    else:
        # No warmup requested: plain cosine decay from the base LR.
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=cosine_epochs,
            eta_min=1e-7
        )

    # Fast-forward for resume. No optimizer.step() happens here, so silence
    # PyTorch's "lr_scheduler.step() before optimizer.step()" warning.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        for _ in range(start_epoch):
            scheduler.step()

    return scheduler


class GradientAccumulator:
    """
    Counts micro-batches and signals when an optimizer step is due.
    """
    def __init__(self, accumulation_steps: int):
        self.accumulation_steps = accumulation_steps
        self.current_step = 0

    def should_update(self) -> bool:
        """
        Advance the counter. Return True (and reset it) once
        ``accumulation_steps`` calls have accumulated.
        """
        self.current_step += 1
        ready = self.current_step >= self.accumulation_steps
        if ready:
            self.current_step = 0
        return ready

    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """
        Divide ``loss`` by the number of accumulation steps.
        """
        return loss / self.accumulation_steps


def _get_grad_accumulator(optimizer: torch.optim.Optimizer, accumulation_steps: int) -> GradientAccumulator:
    """Return the GradientAccumulator attached to ``optimizer``, creating it on first use.

    The counter must survive across train_epoch calls. A fresh accumulator per
    call never reaches ``accumulation_steps`` > 1, so the optimizer would never step.
    """
    accum = getattr(optimizer, '_k7_grad_accumulator', None)
    if accum is None or accum.accumulation_steps != accumulation_steps:
        accum = GradientAccumulator(accumulation_steps)
        optimizer._k7_grad_accumulator = accum
    return accum


def _compute_dphi_simple(phi: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """Raw partial derivatives ``out[:, i, j, k, l] = ∂φ_ijk / ∂x_l`` for distinct i, j, k, l.

    Entries with any repeated index are 0. This is the plain gradient, not the
    antisymmetrised exterior derivative.

    Assumes ``phi`` is antisymmetric, as built by ``get_phi_tensor`` (each
    permuted entry is ± the same component). Under that assumption only the 35
    triples ``i < j < k`` need an autograd pass instead of all 210 ordered
    triples; permutations reuse the gradient times the permutation sign.
    """
    from itertools import combinations

    dphi = phi.new_zeros(phi.shape[0], 7, 7, 7, 7)
    for i, j, k in combinations(range(7), 3):
        grad = torch.autograd.grad(
            phi[:, i, j, k].sum(),
            coords,
            create_graph=True,
            retain_graph=True
        )[0]
        others = torch.tensor([l for l in range(7) if l not in (i, j, k)], device=coords.device)
        grad_others = grad[:, others]
        signed_perms = (
            ((i, j, k), 1.0), ((j, k, i), 1.0), ((k, i, j), 1.0),
            ((i, k, j), -1.0), ((j, i, k), -1.0), ((k, j, i), -1.0),
        )
        for (a, b, c), sign in signed_perms:
            dphi[:, a, b, c, others] = sign * grad_others
    return dphi


def train_epoch(
    models: Dict[str, nn.Module],
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    topology: Any,
    curriculum: CurriculumScheduler,
    config: Dict,
    epoch: int,
    metrics_tracker: Any,
    device: torch.device
) -> Dict[str, float]:
    """
    Execute one training epoch (one micro-batch of ``batch_size`` points).

    The loss is scaled by ``1 / grad_accumulation`` and backpropagated. The
    accumulator is attached to ``optimizer``, so it persists across epochs:
    every ``grad_accumulation`` epochs, gradients are clipped, the optimizer
    steps and gradients are zeroed.

    NOTE(review): ``dstar_phi`` and ``star_phi`` are zero placeholders (no Hodge
    star is implemented here). The coclosure loss is therefore identically 0
    and the coassociative calibration term is constant.

    Args:
        models: Dictionary containing phi_network, harmonic_h2, harmonic_h3
        optimizer: Optimizer
        loss_fn: Composite loss function
        topology: K7Topology instance
        curriculum: Curriculum scheduler
        config: Training configuration
        epoch: Current epoch number
        metrics_tracker: Object with an ``update(epoch, **metrics)`` method
        device: Torch device

    Returns:
        epoch_metrics: ``'loss'`` (unscaled total) plus the loss components
    """
    for model in models.values():
        model.train()

    batch_size = config['training']['batch_size']
    grad_accum = _get_grad_accumulator(optimizer, config['training']['grad_accumulation'])

    grid_n = curriculum.get_grid_resolution(epoch)
    loss_weights = curriculum.get_loss_weights(epoch)

    coords = topology.sample_coordinates(batch_size, grid_n=grid_n)
    coords = coords.to(device)
    coords.requires_grad_(True)
    n_points = coords.shape[0]

    phi_network = models['phi_network']
    harmonic_h2_network = models['harmonic_h2']
    harmonic_h3_network = models['harmonic_h3']

    phi = phi_network.get_phi_tensor(coords)
    dphi_simple = _compute_dphi_simple(phi, coords)

    # Placeholders until a Hodge star is implemented (see NOTE above).
    dstar_phi_simple = phi.new_zeros(n_points, 7, 7)
    star_phi = phi.new_zeros(n_points, 7, 7, 7, 7)

    metric = reconstruct_metric_from_phi(phi)

    harmonic_h2 = harmonic_h2_network(coords)
    harmonic_h3 = harmonic_h3_network(coords)

    region_weights = topology.get_region_weights(coords)

    total_loss, components = loss_fn(
        phi=phi,
        dphi=dphi_simple,
        dstar_phi=dstar_phi_simple,
        star_phi=star_phi,
        metric=metric,
        harmonic_h2=harmonic_h2,
        harmonic_h3=harmonic_h3,
        region_weights=region_weights,
        loss_weights=loss_weights,
        epoch=epoch
    )

    scaled_loss = grad_accum.scale_loss(total_loss)
    scaled_loss.backward()

    if grad_accum.should_update():
        torch.nn.utils.clip_grad_norm_(
            [p for model in models.values() for p in model.parameters()],
            config['training']['grad_clip']
        )
        optimizer.step()
        optimizer.zero_grad()

    epoch_metrics = {
        'loss': total_loss.item(),
        **components
    }

    metrics_tracker.update(epoch, **epoch_metrics)

    return epoch_metrics


def reconstruct_metric_from_phi(phi: torch.Tensor) -> torch.Tensor:
    """
    Reconstruct a metric from a 3-form (simplified version for training).

    Computes ``g_ij = (1/6) Σ' φ_ipq φ_jpq``, where the primed sum runs over
    ``p ≠ q`` with ``p, q ∉ {i, j}``. The result is symmetrised and ``1e-4`` is
    added to the diagonal.

    Args:
        phi: [batch, 7, 7, 7] tensor

    Returns:
        [batch, 7, 7] symmetric tensor with phi's dtype and device
    """
    idx = torch.arange(7, device=phi.device)
    i = idx.view(7, 1, 1, 1)
    j = idx.view(1, 7, 1, 1)
    p = idx.view(1, 1, 7, 1)
    q = idx.view(1, 1, 1, 7)
    # mask[i, j, p, q] encodes the index condition of the sum, so a single
    # einsum replaces a 7^4 Python loop of ~1470 autograd-tracked additions.
    mask = ((p != i) & (q != i) & (p != j) & (q != j) & (p != q)).to(phi.dtype)

    metric = torch.einsum('ijpq,bipq,bjpq->bij', mask, phi, phi)
    metric = metric / 6.0
    metric = 0.5 * (metric + metric.transpose(-2, -1))

    eye = torch.eye(7, device=phi.device, dtype=phi.dtype).unsqueeze(0)
    return metric + 1e-4 * eye


def training_loop(
    models: Dict[str, nn.Module],
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    loss_fn: nn.Module,
    topology: Any,
    curriculum: CurriculumScheduler,
    checkpoint_manager: Any,
    metrics_tracker: Any,
    config: Dict,
    start_epoch: int = 0,
    device: torch.device = torch.device('cpu')
) -> Dict[str, Any]:
    """
    Main training loop with periodic logging and checkpointing.

    Runs ``train_epoch`` for epochs ``start_epoch .. total_epochs - 1`` and steps
    the LR scheduler once per epoch. It prints a summary every 100 epochs,
    saves a checkpoint every ``checkpointing.interval`` epochs, and writes a
    final checkpoint at the end.

    Args:
        models: Dictionary of neural networks
        optimizer: Optimizer
        scheduler: Learning rate scheduler
        loss_fn: Composite loss function
        topology: K7Topology instance
        curriculum: Curriculum scheduler
        checkpoint_manager: Checkpoint management object
        metrics_tracker: Metrics tracking object
        config: Training configuration
        start_epoch: Starting epoch (for resume)
        device: Torch device

    Returns:
        final_results: Dictionary containing final metrics and paths
    """
    total_epochs = config['training']['total_epochs']
    checkpoint_interval = config['checkpointing']['interval']
    # No validation runs inside this loop, so config['validation'] is not read.
    # The CONFIG in this notebook has no such section.

    print(f"Starting training from epoch {start_epoch} to {total_epochs}")
    print(f"Device: {device}")

    training_start_time = time.time()
    # Defined up front so the final checkpoint still works when the loop body
    # never runs (start_epoch >= total_epochs, e.g. resuming a finished run).
    epoch_metrics: Dict[str, Any] = {}

    for epoch in tqdm(range(start_epoch, total_epochs), desc="Training"):
        epoch_start = time.time()

        phase_name, _ = curriculum.get_current_phase(epoch)

        epoch_metrics = train_epoch(
            models=models,
            optimizer=optimizer,
            loss_fn=loss_fn,
            topology=topology,
            curriculum=curriculum,
            config=config,
            epoch=epoch,
            metrics_tracker=metrics_tracker,
            device=device
        )

        scheduler.step()

        if epoch % 100 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"\nEpoch {epoch}/{total_epochs} [{phase_name}]")
            print(f"  Loss: {epoch_metrics['loss']:.6f}")
            print(f"  Torsion closure: {epoch_metrics['torsion_closure']:.6e}")
            print(f"  Torsion coclosure: {epoch_metrics['torsion_coclosure']:.6e}")
            print(f"  Rank H²: {epoch_metrics['rank_h2']}/21")
            print(f"  Rank H³: {epoch_metrics['rank_h3']}/77")
            print(f"  LR: {current_lr:.2e}")
            print(f"  Time: {time.time() - epoch_start:.2f}s")

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_manager.save(
                epoch=epoch,
                models=models,
                optimizer=optimizer,
                scheduler=scheduler,
                metrics=epoch_metrics
            )
            print(f"Checkpoint saved at epoch {epoch}")

    training_time = time.time() - training_start_time

    final_checkpoint = checkpoint_manager.save(
        epoch=total_epochs,
        models=models,
        optimizer=optimizer,
        scheduler=scheduler,
        metrics=epoch_metrics
    )

    print(f"\nTraining completed in {training_time/3600:.2f} hours")
    print(f"Final checkpoint: {final_checkpoint}")

    final_results = {
        'total_epochs': total_epochs,
        'training_time_hours': training_time / 3600,
        'final_metrics': epoch_metrics,
        'checkpoint_path': str(final_checkpoint)
    }

    return final_results


# ============================================================
# VALIDATION MODULE
# ============================================================

class RicciValidator:
    """
    Validator for the Ricci-flatness condition.

    For torsion-free G₂ manifolds Ricci-flatness is automatic; this class is
    intended as a numerical consistency check.

    NOTE: ``compute_christoffel_symbols`` and ``compute_ricci_tensor`` are
    placeholders that currently return zero tensors, so every reported
    ``||Ric||`` is 0.0. Until a real curvature computation replaces them, the
    history only shows that the validation pipeline runs. It is not evidence
    of Ricci-flatness.
    """
    def __init__(self, n_test_points: int = 1000):
        self.n_test_points = n_test_points
        # Created lazily by ``initialize_test_points`` (directly or via ``validate``).
        self.test_points: Optional[torch.Tensor] = None
        # (epoch, ricci_norm) pairs appended by ``validate``.
        self.ricci_history: List[Tuple[int, float]] = []

    def initialize_test_points(self, device: torch.device):
        """
        Draw ``n_test_points`` fixed points uniformly from [0, 2π)^7 on
        ``device`` so that every validation uses the same sample.
        """
        self.test_points = torch.rand(self.n_test_points, 7, device=device) * 2 * np.pi
        self.test_points.requires_grad_(True)

    def compute_christoffel_symbols(
        self,
        metric: torch.Tensor,
        x: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute Christoffel symbols Γ^k_ij from metric tensor.

        Target formula: Γ^k_ij = (1/2) g^kl (∂_i g_jl + ∂_j g_il - ∂_l g_ij)

        Placeholder: no derivatives are taken yet; a zero tensor is returned.

        Args:
            metric: [batch, 7, 7] metric tensor
            x: [batch, 7] coordinates (currently unused)

        Returns:
            christoffel: [batch, 7, 7, 7] zeros on metric's device/dtype
        """
        return torch.zeros(metric.shape[0], 7, 7, 7, device=metric.device, dtype=metric.dtype)

    def compute_ricci_tensor(
        self,
        metric: torch.Tensor,
        x: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute Ricci tensor R_ij from metric.

        Placeholder: returns a zero tensor, so validation norms are always 0.

        Args:
            metric: [batch, 7, 7] metric tensor
            x: [batch, 7] coordinates (currently unused)

        Returns:
            ricci: [batch, 7, 7] zeros on metric's device/dtype
        """
        return torch.zeros(metric.shape[0], 7, 7, device=metric.device, dtype=metric.dtype)

    @staticmethod
    def _resolve_device(metric_fn: callable, device=None) -> torch.device:
        """Pick the device for test points: explicit arg, else metric_fn's parameters, else CPU."""
        if device is not None:
            return torch.device(device)
        # Plain callables (e.g. lambdas) have no .parameters(); only nn.Modules do.
        parameters = getattr(metric_fn, 'parameters', None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device('cpu')

    def validate(
        self,
        metric_fn: callable,
        epoch: int,
        check_interval: int = 500,
        device: Optional[torch.device] = None
    ) -> Optional[float]:
        """
        Validate Ricci-flatness at specified intervals.

        Args:
            metric_fn: Callable mapping [N, 7] coordinates to [N, 7, 7] metrics
                (an nn.Module or any plain function/lambda)
            epoch: Current training epoch
            check_interval: Validation runs only when epoch % check_interval == 0
            device: Device for the test points when they are first created.
                If None, it is taken from metric_fn's parameters when
                available, otherwise CPU.

        Returns:
            ricci_norm: Frobenius norm of Ricci tensor, or None if skipped
        """
        if epoch % check_interval != 0:
            return None

        if self.test_points is None:
            self.initialize_test_points(self._resolve_device(metric_fn, device))

        with torch.no_grad():
            metric = metric_fn(self.test_points)
            ricci = self.compute_ricci_tensor(metric, self.test_points)
            ricci_norm = torch.norm(ricci).item()

        self.ricci_history.append((epoch, ricci_norm))

        print(f"Ricci validation at epoch {epoch}: ||Ric|| = {ricci_norm:.6e}")

        return ricci_norm

    def get_history(self) -> List[Tuple[int, float]]:
        """
        Get complete Ricci validation history as (epoch, norm) pairs.
        """
        return self.ricci_history


class HolonomyTester:
    """
    Test for G₂ holonomy via parallel transport.

    φ is parallel-transported around small closed loops with the Levi-Civita
    connection of the supplied metric and then compared with its starting
    value. A parallel 3-form (∇φ = 0, the torsion-free G₂ condition) returns
    to itself. A large discrepancy means the loop's holonomy does not
    preserve φ.

    Christoffel symbols come from central finite differences of
    ``metric_fn``. Neither network therefore needs to be differentiable with
    respect to the coordinates (``reconstruct_metric_wrapper`` evaluates
    under ``no_grad``).
    """
    def __init__(
        self,
        n_loops: int = 10,
        n_steps_per_loop: int = 50,
        fd_step: float = 5e-3,
        coord_period: Optional[float] = 2 * np.pi
    ):
        self.n_loops = n_loops
        self.n_steps_per_loop = n_steps_per_loop
        # Step h of the central difference (g(x + h e_l) - g(x - h e_l)) / 2h.
        self.fd_step = fd_step
        # Coordinate period: loop increments are unwrapped modulo this value so a
        # loop that crosses the 0 / 2π seam is integrated correctly (None = off).
        self.coord_period = coord_period

    def generate_closed_loops(self, device: torch.device) -> List[torch.Tensor]:
        """
        Generate circles in the (x⁰, x¹) plane around random centres in K₇.

        Each loop has n_steps_per_loop + 1 points (the last coincides with the
        first) and radius drawn uniformly in [0.1, 0.4). Coordinates are
        wrapped into [0, 2π).

        Returns:
            loops: List of [n_steps + 1, 7] coordinate paths
        """
        loops = []
        two_pi = 2 * np.pi
        n_steps = self.n_steps_per_loop

        for _ in range(self.n_loops):
            center = torch.rand(7, device=device) * two_pi
            radius = 0.1 + torch.rand(1, device=device).item() * 0.3

            t = torch.arange(n_steps + 1, device=device, dtype=center.dtype) * (two_pi / n_steps)
            path = center.unsqueeze(0).repeat(n_steps + 1, 1)
            path[:, 0] = path[:, 0] + radius * torch.cos(t)
            path[:, 1] = path[:, 1] + radius * torch.sin(t)

            # remainder (unlike fmod) also maps negative coordinates into [0, 2π)
            loops.append(torch.remainder(path, two_pi))

        return loops

    def _christoffel_symbols(self, metric_fn: callable, points: torch.Tensor) -> torch.Tensor:
        """Return Γ[p, k, i, j] = Γ^k_ij at each of the [P, 7] points via central differences."""
        n_pts, dim = points.shape
        h = self.fd_step
        shifts = h * torch.eye(dim, device=points.device, dtype=points.dtype)
        plus = (points.unsqueeze(1) + shifts.unsqueeze(0)).reshape(-1, dim)
        minus = (points.unsqueeze(1) - shifts.unsqueeze(0)).reshape(-1, dim)

        # One batched metric evaluation: the points, then +h e_l, then -h e_l shifts.
        g_all = metric_fn(torch.cat([points, plus, minus], dim=0))
        g = g_all[:n_pts]
        g_plus = g_all[n_pts:n_pts * (dim + 1)].reshape(n_pts, dim, dim, dim)
        g_minus = g_all[n_pts * (dim + 1):].reshape(n_pts, dim, dim, dim)
        dg = (g_plus - g_minus) / (2 * h)  # dg[p, m, a, b] = ∂_m g_ab

        # first_kind[p, i, j, l] = ∂_i g_jl + ∂_j g_il - ∂_l g_ij
        first_kind = dg + dg.permute(0, 2, 1, 3) - dg.permute(0, 2, 3, 1)
        return 0.5 * torch.einsum('pkl,pijl->pkij', torch.linalg.inv(g), first_kind)

    def parallel_transport_phi(
        self,
        phi_fn: callable,
        metric_fn: callable,
        loop: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parallel transport φ around a closed loop.

        The transport equation for a covariant 3-form along the path is
        dφ_abc = M_ad φ_dbc + M_bd φ_adc + M_cd φ_abd, with
        M_ad = Γ^d_ia dx^i. Γ is averaged over each segment's endpoints and
        the segment is integrated exactly with E = exp(M):
        φ_abc ← E_ad E_be E_cf φ_def.

        Args:
            phi_fn: Function computing φ [N, 7, 7, 7] from coordinates [N, 7]
            metric_fn: Function computing metric [N, 7, 7] from coordinates
            loop: [n_steps + 1, 7] closed path

        Returns:
            phi_initial: [1, 7, 7, 7] φ at starting point
            phi_final: [1, 7, 7, 7] φ(start) after transport around the loop
        """
        with torch.no_grad():
            phi_initial = phi_fn(loop[0:1])
            gamma = self._christoffel_symbols(metric_fn, loop)

            steps = loop[1:] - loop[:-1]
            if self.coord_period is not None:
                half = 0.5 * self.coord_period
                steps = torch.remainder(steps + half, self.coord_period) - half

            phi = phi_initial[0]
            for s in range(steps.shape[0]):
                gamma_mid = 0.5 * (gamma[s] + gamma[s + 1])
                generator = torch.einsum('dia,i->ad', gamma_mid, steps[s])
                step_map = torch.linalg.matrix_exp(generator).to(phi.dtype)
                phi = torch.einsum('ad,be,cf,def->abc', step_map, step_map, step_map, phi)

            phi_final = phi.unsqueeze(0)

        return phi_initial, phi_final

    def test_holonomy_preservation(
        self,
        phi_network: torch.nn.Module,
        metric_fn: callable,
        device: torch.device,
        tolerance: float = 1e-4
    ) -> Dict[str, Any]:
        """
        Test if parallel transport preserves φ (G₂ holonomy condition).

        Args:
            phi_network: Network exposing get_phi_tensor(coords)
            metric_fn: Function computing metric
            device: Torch device for the generated loops
            tolerance: Acceptable preservation error (Frobenius norm)

        Returns:
            results: Dictionary with test results. With zero loops the
                errors are reported as 0.0 and the test passes vacuously.
        """
        loops = self.generate_closed_loops(device)

        preservation_errors = []
        for loop in loops:
            phi_initial, phi_final = self.parallel_transport_phi(
                phi_network.get_phi_tensor,
                metric_fn,
                loop
            )
            preservation_errors.append(torch.norm(phi_final - phi_initial).item())

        mean_error = float(np.mean(preservation_errors)) if preservation_errors else 0.0
        max_error = float(np.max(preservation_errors)) if preservation_errors else 0.0
        passed = max_error < tolerance

        results = {
            'n_loops_tested': self.n_loops,
            'mean_preservation_error': mean_error,
            'max_preservation_error': max_error,
            'tolerance': tolerance,
            'test_passed': passed,
            'individual_errors': [float(e) for e in preservation_errors]
        }

        status = "PASSED" if passed else "FAILED"
        print(f"\nHolonomy Test {status}")
        print(f"  Loops tested: {self.n_loops}")
        print(f"  Mean error: {mean_error:.6e}")
        print(f"  Max error: {max_error:.6e}")
        print(f"  Tolerance: {tolerance:.6e}")

        return results


class GeometricValidator:
    """
    Geometric validation suite.

    ``validate_all`` runs the periodic Ricci check during training, and
    ``final_validation`` runs the holonomy test after training and collects
    the Ricci history.
    """
    def __init__(self, config: Dict):
        self.config = config
        holonomy_cfg = config.get('holonomy_test', {})
        self.ricci_validator = RicciValidator(
            n_test_points=config['validation']['ricci_points']
        )
        self.holonomy_tester = HolonomyTester(
            n_loops=holonomy_cfg.get('n_loops', 10),
            n_steps_per_loop=holonomy_cfg.get('n_steps_per_loop', 50)
        )

    def validate_all(
        self,
        models: Dict[str, torch.nn.Module],
        epoch: int,
        device: torch.device
    ) -> Dict[str, Any]:
        """
        Run all validation checks.

        Args:
            models: Dictionary of model components (needs 'phi_network')
            epoch: Current epoch
            device: Torch device on which the Ricci test points are created

        Returns:
            validation_results: {'ricci_norm': float} when the Ricci check ran,
                otherwise an empty dict
        """
        results = {}
        check_interval = self.config['validation']['ricci_interval']

        # The metric_fn below is a lambda, which has no .parameters() from
        # which RicciValidator could infer the device, so seed its test points
        # on the target device before the first real check.
        if epoch % check_interval == 0 and self.ricci_validator.test_points is None:
            self.ricci_validator.initialize_test_points(device)

        ricci_norm = self.ricci_validator.validate(
            metric_fn=lambda x: reconstruct_metric_wrapper(models['phi_network'], x),
            epoch=epoch,
            check_interval=check_interval
        )

        if ricci_norm is not None:
            results['ricci_norm'] = ricci_norm

        return results

    def final_validation(
        self,
        models: Dict[str, torch.nn.Module],
        device: torch.device
    ) -> Dict[str, Any]:
        """
        Run complete validation suite after training completion.

        Args:
            models: Dictionary of trained models
            device: Torch device

        Returns:
            final_report: holonomy results, Ricci history and last Ricci norm
        """
        print("\n" + "="*60)
        print("FINAL GEOMETRIC VALIDATION")
        print("="*60)

        holonomy_results = self.holonomy_tester.test_holonomy_preservation(
            phi_network=models['phi_network'],
            metric_fn=lambda x: reconstruct_metric_wrapper(models['phi_network'], x),
            device=device,
            tolerance=self.config.get('holonomy_test', {}).get('preservation_tolerance', 1e-4)
        )

        ricci_history = self.ricci_validator.get_history()

        final_report = {
            'holonomy_test': holonomy_results,
            'ricci_history': [(int(e), float(r)) for e, r in ricci_history],
            'ricci_final': ricci_history[-1][1] if ricci_history else None
        }

        return final_report

    def save_validation_report(self, report: Dict, path: str):
        """
        Save validation report to a UTF-8 JSON file.
        """
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)
        print(f"Validation report saved to {path}")


def reconstruct_metric_wrapper(phi_network: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the metric at coordinates ``x`` from the φ produced by ``phi_network``.

    Runs entirely under ``torch.no_grad()``, so the returned metric carries
    no autograd graph back to the network or the coordinates.
    """
    with torch.no_grad():
        return reconstruct_metric_from_phi_simple(phi_network.get_phi_tensor(x))


def reconstruct_metric_from_phi_simple(phi: torch.Tensor) -> torch.Tensor:
    """
    Simplified metric reconstruction (matches training version).

    Computes g_ij = (1/6) Σ_{p,q} φ_ipq φ_jpq. The sum only includes index
    tuples with p ∉ {i, j}, q ∉ {i, j} and p ≠ q. The result is then
    symmetrized and regularized with 1e-4·I.

    The masked sum is a single einsum instead of 2401 Python-level batch ops.

    Args:
        phi: [batch, 7, 7, 7] 3-form components (need not be antisymmetric)

    Returns:
        metric: [batch, 7, 7], on phi's device and in phi's dtype
    """
    idx = torch.arange(7, device=phi.device)
    i = idx.view(7, 1, 1, 1)
    j = idx.view(1, 7, 1, 1)
    p = idx.view(1, 1, 7, 1)
    q = idx.view(1, 1, 1, 7)
    # mask[i, j, p, q] selects exactly the (p, q) terms summed for entry (i, j)
    mask = (p != i) & (q != i) & (p != j) & (q != j) & (p != q)

    metric = torch.einsum('ijpq,bipq,bjpq->bij', mask.to(phi.dtype), phi, phi) / 6.0
    metric = 0.5 * (metric + metric.transpose(-2, -1))

    eye = torch.eye(7, device=phi.device, dtype=phi.dtype).unsqueeze(0)
    return metric + 1e-4 * eye


# ============================================================
# YUKAWA MODULE
# ============================================================

# Index/sign tables for the (2, 2, 3) wedge product, cached per device string.
_WEDGE_223_TABLES: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = {}


def _permutation_parity(seq: Tuple[int, ...]) -> int:
    """Return +1 if ``seq`` (distinct ints) is an even permutation of its sorted order, else -1."""
    inversions = sum(
        1 for a in range(len(seq)) for b in range(a + 1, len(seq)) if seq[a] > seq[b]
    )
    return -1 if inversions % 2 else 1


def _wedge_223_tables(device: torch.device):
    """Build once per device the 210 (A, B, C, sign) terms with dx^A∧dx^B∧dx^C = sign·dx⁰∧…∧dx⁶."""
    key = str(device)
    tables = _WEDGE_223_TABLES.get(key)
    if tables is None:
        from itertools import combinations

        pairs = list(combinations(range(7), 2))
        pair_pos = {pair: n for n, pair in enumerate(pairs)}
        triple_pos = {triple: n for n, triple in enumerate(combinations(range(7), 3))}

        idx_a, idx_b, idx_c, signs = [], [], [], []
        for pair_a in pairs:
            for pair_b in pairs:
                if set(pair_a) & set(pair_b):
                    continue  # overlapping indices wedge to zero
                triple = tuple(sorted(set(range(7)) - set(pair_a) - set(pair_b)))
                idx_a.append(pair_pos[pair_a])
                idx_b.append(pair_pos[pair_b])
                idx_c.append(triple_pos[triple])
                signs.append(_permutation_parity(pair_a + pair_b + triple))

        tables = (
            torch.tensor(idx_a, device=device),
            torch.tensor(idx_b, device=device),
            torch.tensor(idx_c, device=device),
            torch.tensor(signs, dtype=torch.float32, device=device),
        )
        _WEDGE_223_TABLES[key] = tables
    return tables


def compute_wedge_product_h2_h2_h3(
    h2_alpha: torch.Tensor,
    h2_beta: torch.Tensor,
    h3_gamma: torch.Tensor
) -> torch.Tensor:
    """
    Compute wedge product h₂^α ∧ h₂^β ∧ h₃^γ → 7-form.

    Returns the coefficient c in c·dx⁰∧…∧dx⁶:

        c = Σ sign(A, B, C) · α_A · β_B · γ_C

    The sum runs over the 210 ways to split {0, …, 6} into disjoint index
    sets |A| = |B| = 2 and |C| = 3. Because 2-forms have even degree, the
    result is symmetric under α ↔ β.

    NOTE(review): assumes the packed components follow
    itertools.combinations(range(7), p) order (lexicographic i<j for the 21
    two-form components, i<j<k for the 35 three-form components). Confirm
    against the harmonic network's output layout.

    Args:
        h2_alpha: [..., 21] components of 2-form
        h2_beta: [..., 21] components of 2-form
        h3_gamma: [..., 35] components of 3-form
        Leading dimensions broadcast. The intermediate holds 210 values per
        broadcast element.

    Returns:
        wedge_7form: [...] scalar coefficient per (broadcast) sample
    """
    idx_a, idx_b, idx_c, signs = _wedge_223_tables(h2_alpha.device)
    terms = h2_alpha[..., idx_a] * h2_beta[..., idx_b] * h3_gamma[..., idx_c]
    return (terms * signs.to(terms.dtype)).sum(dim=-1)


def compute_yukawa_monte_carlo(
    harmonic_h2_network: torch.nn.Module,
    harmonic_h3_network: torch.nn.Module,
    topology: Any,
    n_samples: int = 20000,
    device: torch.device = torch.device('cpu')
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate the Yukawa tensor by Monte Carlo averaging.

    Target: Y_αβγ = ∫_K₇ h₂^α ∧ h₂^β ∧ h₃^γ √det(g) d⁷x

    The returned value is the sample *average* of the wedge coefficient over
    points from ``topology.sample_coordinates``. It includes no domain-volume
    factor and no √det(g) weight. Multiply by the domain volume to get an
    integral estimate.

    Samples are drawn in equal batches of ``min(2048, n_samples)``. The batch
    count is rounded up, so at least ``n_samples`` points are used (the
    previous floor division dropped the remainder and produced NaN for
    n_samples < 2048).

    Args:
        harmonic_h2_network: Network generating 21 harmonic 2-forms
        harmonic_h3_network: Network generating 77 harmonic 3-forms
        topology: K7Topology instance
        n_samples: Minimum number of Monte Carlo samples (must be positive)
        device: Torch device

    Returns:
        yukawa_tensor: [21, 21, 77] Yukawa couplings (sample average)
        uncertainty: [21, 21, 77] standard error of the batch means
            (zero when only one batch is drawn)

    Raises:
        ValueError: if n_samples is not positive
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    print(f"Computing Yukawa tensor via Monte Carlo ({n_samples} samples)...")

    yukawa = torch.zeros(21, 21, 77, device=device)
    yukawa_sq = torch.zeros(21, 21, 77, device=device)

    batch_size = min(2048, n_samples)
    n_batches = -(-n_samples // batch_size)  # ceiling division

    with torch.no_grad():
        for _ in range(n_batches):
            coords = topology.sample_coordinates(batch_size, grid_n=10)
            coords = coords.to(device)

            h2_forms = harmonic_h2_network(coords)
            h3_forms = harmonic_h3_network(coords)

            for alpha in range(21):
                h2_alpha = h2_forms[:, alpha, :].unsqueeze(1)  # [B, 1, n_components]
                for beta in range(21):
                    h2_beta = h2_forms[:, beta, :].unsqueeze(1)
                    # Broadcast against all 77 H³ forms at once -> [B, 77]
                    wedge = compute_wedge_product_h2_h2_h3(h2_alpha, h2_beta, h3_forms)
                    batch_mean = wedge.mean(dim=0)

                    yukawa[alpha, beta] += batch_mean
                    yukawa_sq[alpha, beta] += batch_mean ** 2

    yukawa = yukawa / n_batches
    yukawa_sq = yukawa_sq / n_batches

    variance = yukawa_sq - yukawa ** 2
    uncertainty = torch.sqrt(torch.abs(variance) / n_batches)

    print("Monte Carlo integration complete")

    return yukawa, uncertainty


def compute_yukawa_grid(
    harmonic_h2_network: torch.nn.Module,
    harmonic_h3_network: torch.nn.Module,
    grid_n: int = 10,
    device: torch.device = torch.device('cpu')
) -> torch.Tensor:
    """
    Compute Yukawa tensor using structured grid integration over [0, 2π)^7.

    Each axis has ``grid_n`` points spaced 2π/grid_n, without the duplicate
    periodic endpoint 2π. That is the rectangle rule whose cell volume is
    exactly (2π/grid_n)^7. The old ``linspace(0, 2π, grid_n)`` double-counted
    the seam while still weighting by that cell volume.

    Coordinates are decoded batch by batch from flat indices, so the full
    grid_n^7 × 7 coordinate array is never materialized. The point order
    matches ``meshgrid(indexing='ij')`` flattened.

    Args:
        harmonic_h2_network: Network generating harmonic 2-forms
        harmonic_h3_network: Network generating harmonic 3-forms
        grid_n: Grid resolution per dimension (must be positive)
        device: Torch device

    Returns:
        yukawa_tensor: [21, 21, 77] Yukawa couplings (integral over [0, 2π)^7)

    Raises:
        ValueError: if grid_n is not positive
    """
    if grid_n <= 0:
        raise ValueError(f"grid_n must be positive, got {grid_n}")

    print(f"Computing Yukawa tensor via grid integration (n={grid_n})...")

    spacing = 2 * np.pi / grid_n
    n_points = grid_n ** 7
    # Place value of each axis in the flat index; the last axis varies fastest.
    place_values = grid_n ** torch.arange(6, -1, -1, device=device)

    yukawa = torch.zeros(21, 21, 77, device=device)
    batch_size = 4096

    with torch.no_grad():
        for start in range(0, n_points, batch_size):
            flat_idx = torch.arange(start, min(start + batch_size, n_points), device=device)
            digits = (flat_idx.unsqueeze(1) // place_values) % grid_n  # [B, 7]
            batch_coords = digits.to(torch.float32) * spacing

            h2_forms = harmonic_h2_network(batch_coords)
            h3_forms = harmonic_h3_network(batch_coords)

            for alpha in range(21):
                h2_alpha = h2_forms[:, alpha, :].unsqueeze(1)  # [B, 1, n_components]
                for beta in range(21):
                    h2_beta = h2_forms[:, beta, :].unsqueeze(1)
                    # Broadcast against all 77 H³ forms at once -> [B, 77]
                    wedge = compute_wedge_product_h2_h2_h3(h2_alpha, h2_beta, h3_forms)
                    yukawa[alpha, beta] += wedge.sum(dim=0)

    volume_element = spacing ** 7
    yukawa = yukawa * volume_element

    print("Grid integration complete")

    return yukawa


def compute_yukawa_dual_method(
    harmonic_h2_network: torch.nn.Module,
    harmonic_h3_network: torch.nn.Module,
    topology: Any,
    n_mc_samples: int = 20000,
    grid_n: int = 10,
    device: torch.device = torch.device('cpu'),
    domain_volume: float = (2 * np.pi) ** 7
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute Yukawa tensor using dual integration methods for cross-validation.

    ``compute_yukawa_monte_carlo`` returns a sample average (wedge.mean()).
    ``compute_yukawa_grid`` returns a sum times the cell volume, i.e. an
    integral over [0, 2π)^7. The MC estimate is multiplied by
    ``domain_volume`` so that both are integrals before they are averaged and
    compared. Previously they differed by a factor of (2π)^7.

    NOTE(review): assumes topology.sample_coordinates samples uniformly over
    the same [0, 2π)^7 domain as the grid. Confirm against K7Topology.

    Args:
        harmonic_h2_network: Network for 2-forms
        harmonic_h3_network: Network for 3-forms
        topology: K7Topology instance
        n_mc_samples: Monte Carlo sample count
        grid_n: Grid resolution
        device: Torch device
        domain_volume: Volume of the MC sampling domain

    Returns:
        yukawa_final: [21, 21, 77] averaged Yukawa tensor
        uncertainty: [21, 21, 77] uncertainty estimate (MC error and
            method disagreement added in quadrature)
    """
    print("\nComputing Yukawa couplings with dual integration method")
    print("-" * 60)

    yukawa_mc_mean, uncertainty_mc_mean = compute_yukawa_monte_carlo(
        harmonic_h2_network, harmonic_h3_network, topology,
        n_samples=n_mc_samples, device=device
    )
    # Convert the MC domain average into an integral, matching the grid output.
    yukawa_mc = yukawa_mc_mean * domain_volume
    uncertainty_mc = uncertainty_mc_mean * domain_volume

    yukawa_grid = compute_yukawa_grid(
        harmonic_h2_network, harmonic_h3_network,
        grid_n=grid_n, device=device
    )

    yukawa_final = (yukawa_mc + yukawa_grid) / 2.0

    method_disagreement = torch.abs(yukawa_mc - yukawa_grid)
    total_uncertainty = torch.sqrt(uncertainty_mc**2 + method_disagreement**2)

    print(f"\nIntegration comparison:")
    print(f"  Mean MC value: {yukawa_mc.abs().mean():.6e}")
    print(f"  Mean grid value: {yukawa_grid.abs().mean():.6e}")
    print(f"  Mean disagreement: {method_disagreement.mean():.6e}")
    print(f"  Relative disagreement: {(method_disagreement / (yukawa_final.abs() + 1e-10)).mean():.2%}")

    return yukawa_final, total_uncertainty


def verify_yukawa_antisymmetry(yukawa: torch.Tensor, tolerance: float = 1e-6) -> Dict[str, Any]:
    """
    Verify the exchange symmetry of the Yukawa tensor in its two H² indices.

    2-forms have even degree, so h₂^α ∧ h₂^β = +h₂^β ∧ h₂^α, and therefore
    Y_αβγ = +Y_βαγ. The tensor is symmetric in (α, β), not antisymmetric.
    ``test_passed`` is based on the symmetry residual |Y_αβγ - Y_βαγ|. The
    antisymmetry residual |Y_αβγ + Y_βαγ| is still reported under the
    original keys so existing reports keep their fields.

    Args:
        yukawa: [21, 21, 77] Yukawa tensor (any [n, n, m] tensor works)
        tolerance: Acceptable maximum symmetry violation

    Returns:
        verification_results: antisymmetry and symmetry residuals, the
            tolerance, and whether the symmetry check passed
    """
    swapped = yukawa.transpose(0, 1)
    antisymmetry_error = torch.abs(yukawa + swapped)
    symmetry_error = torch.abs(yukawa - swapped)

    mean_antisymmetry = antisymmetry_error.mean().item()
    max_antisymmetry = antisymmetry_error.max().item()
    mean_symmetry = symmetry_error.mean().item()
    max_symmetry = symmetry_error.max().item()

    passed = max_symmetry < tolerance

    results = {
        'mean_antisymmetry_error': float(mean_antisymmetry),
        'max_antisymmetry_error': float(max_antisymmetry),
        'mean_symmetry_error': float(mean_symmetry),
        'max_symmetry_error': float(max_symmetry),
        'tolerance': tolerance,
        'test_passed': passed
    }

    status = "PASSED" if passed else "WARNING"
    print(f"\nExchange-symmetry test (Y_abc = Y_bac) {status}")
    print(f"  Mean symmetry error: {mean_symmetry:.6e}")
    print(f"  Max symmetry error: {max_symmetry:.6e}")
    print(f"  Max antisymmetry error: {max_antisymmetry:.6e}")

    return results


def tucker_decomposition(yukawa: torch.Tensor, rank: Tuple[int, int, int] = (3, 3, 3)) -> Dict[str, Any]:
    """
    Perform Tucker decomposition to extract generational structure.

    Y ≈ core ×₁ U₁ ×₂ U₂ ×₃ U₃

    If tensorly is unavailable, fall back to the singular values of the
    mode-3 unfolding (rows index the first two axes, columns the last).

    Args:
        yukawa: [21, 21, 77] Yukawa tensor (any 3-way tensor is accepted)
        rank: Tucker rank (3, 3, 3) for three generations

    Returns:
        decomposition: core tensor, factor matrices and relative
            reconstruction error (absolute residual if the tensor is all
            zeros), or the leading (≤10) singular values in the fallback case
    """
    print(f"\nPerforming Tucker decomposition with rank {rank}...")

    # detach() so tensors that still carry autograd history can be converted
    yukawa_np = yukawa.detach().cpu().numpy()

    try:
        import tensorly as tl
        from tensorly.decomposition import tucker
    except ImportError:
        tl = None

    if tl is None:
        print("Warning: tensorly not available, performing SVD-based approximation")

        yukawa_matrix = yukawa_np.reshape(-1, yukawa_np.shape[-1])
        singular_values = np.linalg.svd(yukawa_matrix, compute_uv=False)

        return {
            'singular_values': singular_values[:10].tolist(),
            'note': 'tensorly not available, SVD approximation used'
        }

    core, factors = tucker(yukawa_np, rank=rank)
    U1, U2, U3 = factors

    print("Tucker decomposition successful")
    print(f"  Core tensor shape: {core.shape}")
    print(f"  Factor U1 shape: {U1.shape}")
    print(f"  Factor U2 shape: {U2.shape}")
    print(f"  Factor U3 shape: {U3.shape}")

    reconstruction = tl.tucker_to_tensor((core, factors))
    tensor_norm = np.linalg.norm(yukawa_np)
    residual_norm = np.linalg.norm(reconstruction - yukawa_np)
    # Relative error; avoid 0/0 for an all-zero tensor
    reconstruction_error = residual_norm / tensor_norm if tensor_norm > 0 else residual_norm
    print(f"  Reconstruction error: {reconstruction_error:.6e}")

    return {
        'core': core.tolist(),
        'U1': U1.tolist(),
        'U2': U2.tolist(),
        'U3': U3.tolist(),
        'rank': list(rank),
        'reconstruction_error': float(reconstruction_error)
    }


def extract_mass_ratios(yukawa: torch.Tensor, tucker_decomp: Dict) -> Dict[str, float]:
    """
    Extract fermion mass ratios from Yukawa tensor.

    Takes the three largest |Y_iii| diagonal entries, with the diagonal
    length set by the smallest tensor dimension. Their successive ratios are
    reported as top/charm and charm/up and compared with GIFT predictions.

    Args:
        yukawa: [21, 21, 77] Yukawa tensor (any 3-way tensor is accepted)
        tucker_decomp: Tucker decomposition results (not used yet)

    Returns:
        mass_ratios: computed and GIFT ratios, plus relative deviations for
            every ratio that could be computed
    """
    print("\nExtracting mass ratios from Yukawa tensor...")

    yukawa_np = yukawa.detach().cpu().numpy()

    n_diag = min(yukawa_np.shape)
    diagonal = np.array([yukawa_np[i, i, i] for i in range(n_diag)])
    top_3 = np.sort(np.abs(diagonal))[-3:]

    # Sorted ascending, so top_3[0] > 0 guarantees both divisors are non-zero
    if len(top_3) == 3 and top_3[0] > 0:
        ratio_top_charm = float(top_3[2] / top_3[1])
        ratio_charm_up = float(top_3[1] / top_3[0])
    else:
        ratio_top_charm = 0.0
        ratio_charm_up = 0.0

    gift_predictions = {
        'top_charm': 57.5,
        'charm_up': 20.0,
        'tau_muon': 16.8,
    }

    deviations = {}
    if ratio_top_charm > 0:
        deviations['top_charm'] = abs(ratio_top_charm - gift_predictions['top_charm']) / gift_predictions['top_charm']
    if ratio_charm_up > 0:
        deviations['charm_up'] = abs(ratio_charm_up - gift_predictions['charm_up']) / gift_predictions['charm_up']

    mass_ratios = {
        'computed_top_charm': ratio_top_charm,
        'computed_charm_up': ratio_charm_up,
        'gift_top_charm': gift_predictions['top_charm'],
        'gift_charm_up': gift_predictions['charm_up'],
        'deviations': deviations
    }

    print(f"  Top/Charm ratio: {ratio_top_charm:.2f} (GIFT: {gift_predictions['top_charm']:.2f})")
    if ratio_top_charm > 0:
        print(f"  Deviation: {deviations.get('top_charm', 0)*100:.1f}%")

    return mass_ratios


def compute_and_analyze_yukawa(
    models: Dict[str, torch.nn.Module],
    topology: Any,
    config: Dict,
    device: torch.device
) -> Tuple[Dict[str, Any], torch.Tensor, torch.Tensor]:
    """
    Complete Yukawa computation and analysis pipeline.

    Args:
        models: Dictionary of trained networks ('harmonic_h2', 'harmonic_h3')
        topology: K7Topology instance
        config: Configuration dictionary (optional 'yukawa_computation' section)
        device: Torch device

    Returns:
        yukawa_results: summary statistics and analysis results
        yukawa_tensor: [21, 21, 77] Yukawa tensor
        uncertainty: [21, 21, 77] uncertainty estimate
    """
    print("\n" + "="*60)
    print("YUKAWA COUPLING TENSOR COMPUTATION")
    print("="*60)

    yukawa_config = config.get('yukawa_computation', {})

    yukawa_tensor, uncertainty = compute_yukawa_dual_method(
        harmonic_h2_network=models['harmonic_h2'],
        harmonic_h3_network=models['harmonic_h3'],
        topology=topology,
        n_mc_samples=yukawa_config.get('n_mc_samples', 20000),
        grid_n=yukawa_config.get('grid_n', 10),
        device=device
    )

    antisymmetry_check = verify_yukawa_antisymmetry(
        yukawa_tensor,
        tolerance=yukawa_config.get('antisymmetry_tolerance', 1e-6)
    )

    tucker_rank = tuple(yukawa_config.get('tucker_rank', [3, 3, 3]))
    tucker_results = tucker_decomposition(yukawa_tensor, rank=tucker_rank)

    mass_ratios = extract_mass_ratios(yukawa_tensor, tucker_results)

    yukawa_results = {
        'yukawa_tensor_shape': list(yukawa_tensor.shape),
        'mean_coupling': float(yukawa_tensor.abs().mean()),
        'max_coupling': float(yukawa_tensor.abs().max()),
        'mean_uncertainty': float(uncertainty.mean()),
        'antisymmetry_check': antisymmetry_check,
        'tucker_decomposition': tucker_results,
        'mass_ratios': mass_ratios
    }

    return yukawa_results, yukawa_tensor, uncertainty


# Initialize topology (K7Topology is defined in an earlier cell)
topology = K7Topology(CONFIG['gift_parameters'])
print('Topology initialized')

# The former hard-coded "Total lines: ~1663" print was stale and has been dropped.
print('All modules loaded successfully')

## Training Execution

In [None]:
# ============================================================
# COMPLETE TRAINING EXECUTION WITH PROPER TORSION CALCULATION
# ============================================================

print('='*60)
print('K7 METRIC RECONSTRUCTION v1.0 - FULL TRAINING')
print('='*60)

# Initialize models
print('\nInitializing neural networks...')

# Model classes are defined in an earlier cell; instantiate them here
phi_net = ModularPhiNetwork(
    CONFIG['architecture']['phi_network']['hidden_dims'],
    CONFIG['architecture']['phi_network']['n_fourier']
).to(DEVICE)

h2_net = HarmonicFormsNetwork(
    p=2, n_forms=21,
    hidden_dim=CONFIG['architecture']['harmonic_h2_network']['hidden_dim'],
    n_fourier=CONFIG['architecture']['harmonic_h2_network']['n_fourier']
).to(DEVICE)

h3_net = HarmonicFormsNetwork(
    p=3, n_forms=77,
    hidden_dim=CONFIG['architecture']['harmonic_h3_network']['hidden_dim'],
    n_fourier=CONFIG['architecture']['harmonic_h3_network']['n_fourier']
).to(DEVICE)

models = {'phi_network': phi_net, 'harmonic_h2': h2_net, 'harmonic_h3': h3_net}
total_params = sum(p.numel() for m in models.values() for p in m.parameters())
print(f'Total parameters: {total_params:,}')

# Optimizer
params = [p for m in models.values() for p in m.parameters()]
optimizer = AdamW(params, lr=CONFIG['training']['lr'], weight_decay=CONFIG['training']['weight_decay'])

# Scheduler: linear warmup, then cosine annealing over the remaining epochs.
# T_max is derived from total_epochs instead of a hard-coded 14500.
WARMUP_EPOCHS = 500
warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=WARMUP_EPOCHS)
cosine = CosineAnnealingLR(
    optimizer,
    T_max=max(1, CONFIG['training']['total_epochs'] - WARMUP_EPOCHS),
    eta_min=1e-7
)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[WARMUP_EPOCHS])

# Resume from checkpoint
checkpoint = checkpoint_manager.load_latest()
start_epoch = 0
if checkpoint:
    for name, model in models.items():
        model.load_state_dict(checkpoint['models'][name])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if checkpoint.get('scheduler'):
        scheduler.load_state_dict(checkpoint['scheduler'])
    start_epoch = checkpoint['epoch'] + 1
    print(f'Resumed from epoch {start_epoch}')
else:
    print('Starting fresh training')

print(f'Training range: {start_epoch} to {CONFIG["training"]["total_epochs"]} epochs')

# ============================================================
# TRAINING LOOP WITH PROPER LOSS FUNCTIONS
# ============================================================

print('\nStarting training loop with proper torsion calculation...')

# Initialize curriculum scheduler
curriculum = CurriculumScheduler(CONFIG)

# Simplified loss weights (no calibration for initial training)
base_loss_weights = {
    'torsion_closure': 1.0,
    'torsion_coclosure': 1.0,
    'volume': 0.1,
    'gram_h2': 0.5,
    'gram_h3': 0.3,
    'boundary': 0.1,
    'calibration': 0.0
}

# Independent components: φ_{ijk} with i<j<k (35) and (dφ)_{abcd} with a<b<c<d (35)
PHI_INDEX_TRIPLES = [
    (i, j, k) for i in range(7) for j in range(i + 1, 7) for k in range(j + 1, 7)
]
DPHI_INDEX_QUADS = [
    (a, b, c, d)
    for a in range(7) for b in range(a + 1, 7)
    for c in range(b + 1, 7) for d in range(c + 1, 7)
]

for epoch in tqdm(range(start_epoch, CONFIG['training']['total_epochs']), desc='Training'):
    # Training mode
    for model in models.values():
        model.train()

    # Sample coordinates
    batch_size = CONFIG['training']['batch_size']
    coords = topology.sample_coordinates(batch_size)
    coords = coords.to(DEVICE)
    coords.requires_grad_(True)

    # Forward pass
    phi = phi_net.get_phi_tensor(coords)
    h2 = h2_net(coords)
    h3 = h3_net(coords)

    # ============================================================
    # TORSION CALCULATION - Exterior derivative dφ
    # ============================================================
    # (dφ)_{abcd} = ∂_a φ_{bcd} - ∂_b φ_{acd} + ∂_c φ_{abd} - ∂_d φ_{abc}
    # First collect ∂_l φ_{ijk} for every stored component i<j<k.
    phi_grads = {}
    for (i, j, k) in PHI_INDEX_TRIPLES:
        phi_grads[(i, j, k)] = torch.autograd.grad(
            phi[:, i, j, k].sum(),
            coords,
            create_graph=True,
            retain_graph=True
        )[0]  # [batch, 7]

    dphi = torch.stack([
        phi_grads[(b, c, d)][:, a]
        - phi_grads[(a, c, d)][:, b]
        + phi_grads[(a, b, d)][:, c]
        - phi_grads[(a, b, c)][:, d]
        for (a, b, c, d) in DPHI_INDEX_QUADS
    ], dim=1)  # [batch, 35]

    # Torsion closure loss: mean squared independent component of dφ
    torsion_closure = torch.mean(dphi ** 2)

    # Co-closure d*φ needs the metric and Hodge star: placeholder, identically 0
    torsion_coclosure = torch.zeros((), device=DEVICE)

    # ============================================================
    # GRAM MATRIX LOSSES - Proper orthonormalization
    # ============================================================
    # <h_i, h_j> = batch mean of the pointwise component dot product

    # H² Gram matrix (21 harmonic 2-forms)
    gram_h2 = torch.einsum('bip,bjp->ij', h2, h2) / h2.shape[0]
    identity_h2 = torch.eye(21, device=DEVICE)
    loss_gram_h2 = ((gram_h2 - identity_h2) ** 2).mean()

    # H³ Gram matrix (77 harmonic 3-forms)
    gram_h3 = torch.einsum('bip,bjp->ij', h3, h3) / h3.shape[0]
    identity_h3 = torch.eye(77, device=DEVICE)
    loss_gram_h3 = ((gram_h3 - identity_h3) ** 2).mean()

    # ============================================================
    # TOTAL LOSS - Weighted combination
    # ============================================================

    total_loss = (
        base_loss_weights['torsion_closure'] * torsion_closure +
        base_loss_weights['torsion_coclosure'] * torsion_coclosure +
        base_loss_weights['gram_h2'] * loss_gram_h2 +
        base_loss_weights['gram_h3'] * loss_gram_h3
    )

    # ============================================================
    # BACKWARD PASS
    # ============================================================

    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(params, CONFIG['training']['grad_clip'])
    optimizer.step()
    scheduler.step()

    # ============================================================
    # LOGGING
    # ============================================================

    if epoch % 100 == 0:
        current_lr = optimizer.param_groups[0]['lr']
        # detach: ranks are diagnostics only, no graph needed
        rank_h2 = (torch.linalg.eigvalsh(gram_h2.detach()) > 1e-4).sum().item()
        rank_h3 = (torch.linalg.eigvalsh(gram_h3.detach()) > 1e-4).sum().item()

        print(f'\nEpoch {epoch}/{CONFIG["training"]["total_epochs"]}')
        print(f'  Loss: {total_loss.item():.6f}')
        print(f'  Torsion closure: {torsion_closure.item():.6e}')
        print(f'  Torsion coclosure: {torsion_coclosure.item():.6e}')
        print(f'  Gram H2: {loss_gram_h2.item():.6f} | Rank: {rank_h2}/21')
        print(f'  Gram H3: {loss_gram_h3.item():.6f} | Rank: {rank_h3}/77')
        print(f'  LR: {current_lr:.2e}')

    # ============================================================
    # CHECKPOINTING
    # ============================================================

    if (epoch + 1) % CONFIG['checkpointing']['interval'] == 0:
        checkpoint_manager.save(
            epoch=epoch,
            models=models,
            optimizer=optimizer,
            scheduler=scheduler,
            metrics={
                'loss': total_loss.item(),
                'torsion_closure': torsion_closure.item(),
                'torsion_coclosure': torsion_coclosure.item(),
                'gram_h2': loss_gram_h2.item(),
                'gram_h3': loss_gram_h3.item()
            }
        )
        print(f'  Checkpoint saved at epoch {epoch}')

# ============================================================
# FINAL CHECKPOINT
# ============================================================

print('\nTraining complete!')
checkpoint_manager.save(
    epoch=CONFIG['training']['total_epochs'] - 1,
    models=models,
    optimizer=optimizer,
    scheduler=scheduler,
    metrics={'final': True}
)
print('Final checkpoint saved - download before session ends!')


## Download Results

In [None]:
# Download checkpoint (the download helper only exists on Colab)
try:
    from google.colab import files
except ImportError:  # outside Colab: listing still works, downloading does not
    files = None

# Set to True to download the most recently written checkpoint
DOWNLOAD_LATEST = False

# List available checkpoints, oldest to newest by modification time
# (robust to checkpoint names whose epoch numbers are not zero-padded)
ckpts = sorted(CHECKPOINT_DIR.glob('checkpoint_*.pt'), key=lambda p: p.stat().st_mtime)
print(f'Available checkpoints: {len(ckpts)}')
for ckpt in ckpts[-5:]:
    size_mb = ckpt.stat().st_size / 1e6
    print(f'  {ckpt.name} ({size_mb:.1f} MB)')

if DOWNLOAD_LATEST and ckpts:
    if files is None:
        print(f'google.colab not available - copy {ckpts[-1]} manually')
    else:
        files.download(str(ckpts[-1]))