# GIFT-Native PINN Training Notebook

Train a Physics-Informed Neural Network with built-in GIFT algebraic structure
to learn the G2 metric on K7.

**Key Features:**
- Fano plane structure constants (exact)
- G2 adjoint representation: the network outputs 14 degrees of freedom (dim G2) instead of the 35 components of a general 3-form
- Targets: $\det(g) = 65/32$ and torsion norm $\|T\| < 0.001$

**Runnable on Google Colab with a free T4 GPU.**

## 1. Setup

In [None]:
# Install dependencies (uncomment when running on Colab or in a fresh environment).
# !pip install torch numpy matplotlib tqdm

In [None]:
# Clone the GIFT repository and switch into it (uncomment for Colab).
# The next cells import `gift_core`, which is expected to be importable from
# this checkout's directory.
# !git clone https://github.com/gift-framework/core.git
# %cd core

In [None]:
import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'
print(f"Using device: {device}")
if use_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")

In [None]:
# Import GIFT modules. They live in the gift-framework/core checkout, so the
# notebook must be running from that directory (see the clone cell above).
try:
    from gift_core.nn.gift_native_pinn import (
        GIFTNativePINN,
        GIFTNativeLoss,
        GIFTTrainConfig,
        create_gift_native_pinn,
        train_gift_native_pinn,
        sample_k7_points,
        export_analytical_form,
        phi0_standard,
        FANO_LINES,
        B2, B3, DIM_G2, DET_G_TARGET_FLOAT, H_STAR,
        TORSION_THRESHOLD,
    )
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure you're in the gift-framework/core directory")
    # Re-raise: every later cell depends on these names, so swallowing the
    # error would only turn it into confusing NameErrors further down.
    raise
else:
    print("GIFT modules loaded successfully!")

## 2. GIFT Constants

Display the key GIFT constants that are hard-coded in the architecture.

In [None]:
# Report the GIFT constants imported from gift_core, then list the Fano lines.
constant_lines = [
    "GIFT Constants (Proven in Lean)",
    "=" * 40,
    f"b2 (Second Betti number): {B2}",
    f"b3 (Third Betti number): {B3}",
    f"H* = b2 + b3 + 1 = {H_STAR}",
    f"dim(G2) = {DIM_G2}",
    f"det(g) target = 65/32 = {DET_G_TARGET_FLOAT:.6f}",
    f"Joyce torsion threshold = {TORSION_THRESHOLD}",
]
for text in constant_lines:
    print(text)
print()
print("Fano plane lines:")
for idx, fano_line in enumerate(FANO_LINES):
    print(f"  Line {idx}: {fano_line}")

In [None]:
# Inspect the normalised standard 3-form phi0 returned by phi0_standard:
# component count, how many components are non-zero, and its Euclidean norm.
phi0 = phi0_standard(normalize=True)
nonzero_mask = np.abs(phi0) > 1e-10
phi0_norm = np.linalg.norm(phi0)
print(f"phi0 has {len(phi0)} components (C(7,3) = 35)")
print(f"Non-zero components: {np.sum(nonzero_mask)}")
print(f"phi0 norm: {phi0_norm:.6f}")

## 3. Create Model

Create the GIFT-native PINN with the G2 structure built in.

In [None]:
# Build the GIFT-native PINN and place it on the selected device.
model = create_gift_native_pinn(
    num_frequencies=32,            # number of Fourier features
    hidden_dims=[128, 128, 128],   # MLP hidden-layer widths
    perturbation_scale=0.01,       # scale of delta_phi
).to(device)

# Total number of parameters (trainable and frozen alike).
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")
print("Output dimension: 14 (G2 adjoint) -> 35 (3-form)")

In [None]:
# Smoke-test one forward pass on 100 sampled points (no gradients needed).
x_test = sample_k7_points(100, device)
with torch.no_grad():
    phi_test = model(x_test)
    det_test = model.det_g(x_test)

report = [
    f"Input shape: {x_test.shape}",
    f"Output (phi) shape: {phi_test.shape}",
    f"det(g) mean: {det_test.mean().item():.6f}",
    f"det(g) target: {DET_G_TARGET_FLOAT:.6f}",
]
print("\n".join(report))

## 4. Training Configuration

In [None]:
# Optimisation hyper-parameters, loss weights and early-stopping targets.
config = GIFTTrainConfig(
    epochs=5000,
    batch_size=1024,
    learning_rate=1e-3,
    # Loss weights (relative importance of each physics term)
    det_weight=100.0,    # det(g) = 65/32 constraint
    torsion_weight=1.0,  # torsion minimisation
    topo_weight=10.0,    # topological constraint
    sparse_weight=0.1,   # sparsity of the solution
    pd_weight=10.0,      # positive-definite metric
    # Early stopping: the training loop requires both targets on the same batch
    target_torsion=0.001,
    target_det_error=1e-6,
    device=device,
)

summary_rows = (
    ("Epochs", config.epochs),
    ("Batch size", config.batch_size),
    ("Learning rate", config.learning_rate),
    ("Target torsion", config.target_torsion),
    ("Target det error", config.target_det_error),
)
print("Training Configuration:")
for label, value in summary_rows:
    print(f"  {label}: {value}")

## 5. Training Loop

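Train the PINN with a manual loop using fixed loss weights. Each epoch draws a fresh batch of K7 points and takes one AdamW step with gradient-norm clipping, while `ReduceLROnPlateau` adapts the learning rate. The loop keeps the best checkpoint and stops early once both the torsion and det(g) targets are met.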

In [None]:
# Composite physics loss, weighted exactly as in `config`.
loss_weight_keys = ('det_weight', 'torsion_weight', 'topo_weight', 'sparse_weight', 'pd_weight')
loss_fn = GIFTNativeLoss(**{key: getattr(config, key) for key in loss_weight_keys})

# AdamW with light weight decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=1e-5)

# Halve the learning rate when the monitored loss has not improved for 100
# scheduler steps (the training loop calls scheduler.step once per epoch).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=100)

# Per-epoch metric history, filled in by the training loop.
history = {key: [] for key in ('loss', 'torsion', 'det_error', 'lr')}

In [None]:
# Training loop: each "epoch" is one optimisation step on a freshly sampled
# batch of K7 points.
model.train()
pbar = tqdm(range(config.epochs), desc="Training")

# Checkpoint selection uses the total loss (the objective being optimised,
# which includes the heavily weighted det(g) term) rather than torsion alone.
# Torsion alone ignores det(g), so a low-torsion state with det(g) far from
# 65/32 could otherwise be restored for evaluation. A state that meets both
# early-stopping targets is always kept.
best_loss = float('inf')
best_torsion = float('inf')  # batch torsion of the selected checkpoint
best_epoch = None
best_state = None

for epoch in pbar:
    # Sample batch
    x = sample_k7_points(config.batch_size, device)

    # Forward
    optimizer.zero_grad()
    loss, components = loss_fn(model, x, return_components=True)

    # Backward
    loss.backward()

    # Read each scalar once: every .item() forces a device -> host sync.
    loss_val = loss.item()
    torsion_val = components['torsion'].item()
    det_err_val = components['det'].item()
    converged = (torsion_val < config.target_torsion
                 and det_err_val < config.target_det_error)

    # Snapshot BEFORE optimizer.step() so the saved weights are exactly the
    # ones the metrics above were measured on. NaN losses never compare as
    # smaller, so diverged states are not selected.
    if converged or loss_val < best_loss:
        best_loss = loss_val
        best_torsion = torsion_val
        best_epoch = epoch
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step(loss_val)

    # Record (lr is read after the scheduler step)
    history['loss'].append(loss_val)
    history['torsion'].append(torsion_val)
    history['det_error'].append(det_err_val)
    history['lr'].append(optimizer.param_groups[0]['lr'])

    # Update progress bar
    pbar.set_postfix({
        'loss': f"{loss_val:.4f}",
        'torsion': f"{torsion_val:.6f}",
        'det_err': f"{det_err_val:.6f}",
    })

    # Early stopping
    if converged:
        print(f"\nConverged at epoch {epoch}!")
        break

if best_state is None:
    print("\nNo checkpoint recorded (no training steps were run).")
else:
    print(f"\nBest checkpoint: epoch {best_epoch}, loss {best_loss:.6f}, "
          f"torsion {best_torsion:.6f}")

## 6. Training Visualization

In [None]:
# Four-panel training summary: total loss, torsion, det(g) error and learning rate.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Loss curve
axes[0, 0].semilogy(history['loss'])
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Total Loss')
axes[0, 0].set_title('Training Loss')
axes[0, 0].grid(True, alpha=0.3)

# Torsion
axes[0, 1].semilogy(history['torsion'], label='Torsion norm')
axes[0, 1].axhline(y=config.target_torsion, color='g', linestyle='--', label=f'Target ({config.target_torsion})')
axes[0, 1].axhline(y=TORSION_THRESHOLD, color='r', linestyle='--', label=f'Joyce threshold ({TORSION_THRESHOLD})')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Torsion Norm')
axes[0, 1].set_title('Torsion Convergence')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Det(g) error
axes[1, 0].semilogy(history['det_error'])
axes[1, 0].axhline(y=config.target_det_error, color='g', linestyle='--', label=f'Target ({config.target_det_error})')
axes[1, 0].set_xlabel('Epoch')
# Label fixed from mis-encoded 'Â²' to a real superscript two.
# NOTE(review): assumes loss_fn's 'det' component is the squared deviation — confirm.
axes[1, 0].set_ylabel('|det(g) - 65/32|²')
axes[1, 0].set_title('Determinant Error')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Learning rate
axes[1, 1].semilogy(history['lr'])
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_title('Learning Rate Schedule')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

## 7. Model Evaluation

In [None]:
# Restore the checkpoint selected during training (if any) and switch to eval mode.
if best_state is not None:
    model.load_state_dict(best_state)
model.eval()

# Evaluate on a fresh, independent sample of K7 points.
n_eval = 10000
x_eval = sample_k7_points(n_eval, device)

# Pure forward quantities need no gradients.
with torch.no_grad():
    phi_eval = model(x_eval)
    det_eval = model.det_g(x_eval)
    adjoint_eval = model.get_adjoint_params(x_eval)

# Torsion is computed with grad mode ENABLED. It involves derivatives of phi
# with respect to the coordinates; if torsion_norm obtains them via autograd,
# calling it under torch.no_grad() fails or yields meaningless values.
# NOTE(review): assumes torsion_norm may rely on autograd — confirm in gift_native_pinn.
# detach() drops the autograd graph so the stored tensor is a plain result.
torsion_eval = model.torsion_norm(x_eval).detach()

print(f"Evaluation Results (n={n_eval:,})")
print("=" * 40)
print("det(g):")
print(f"  Mean: {det_eval.mean().item():.8f}")
print(f"  Std:  {det_eval.std().item():.8f}")
print(f"  Target: {DET_G_TARGET_FLOAT:.8f}")
print(f"  Error: {abs(det_eval.mean().item() - DET_G_TARGET_FLOAT):.2e}")
print()
print("Torsion norm:")
print(f"  Mean: {torsion_eval.mean().item():.8f}")
print(f"  Max:  {torsion_eval.max().item():.8f}")
print(f"  Target: < {config.target_torsion}")
print(f"  Joyce threshold: {TORSION_THRESHOLD}")
print()
print("G2 adjoint parameters:")
print(f"  Mean abs: {adjoint_eval.abs().mean().item():.6f}")
print(f"  Max abs:  {adjoint_eval.abs().max().item():.6f}")

In [None]:
# Visualize det(g) distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histogram
axes[0].hist(det_eval.cpu().numpy(), bins=50, density=True, alpha=0.7)
axes[0].axvline(x=DET_G_TARGET_FLOAT, color='r', linestyle='--', label='Target 65/32')
axes[0].set_xlabel('det(g)')
axes[0].set_ylabel('Density')
axes[0].set_title('det(g) Distribution')
axes[0].legend()

# Torsion histogram
axes[1].hist(torsion_eval.cpu().numpy(), bins=50, density=True, alpha=0.7)
axes[1].axvline(x=config.target_torsion, color='g', linestyle='--', label=f'Target {config.target_torsion}')
axes[1].axvline(x=TORSION_THRESHOLD, color='r', linestyle='--', label=f'Joyce {TORSION_THRESHOLD}')
axes[1].set_xlabel('Torsion norm')
axes[1].set_ylabel('Density')
axes[1].set_title('Torsion Distribution')
axes[1].legend()

plt.tight_layout()
plt.savefig('evaluation_histograms.png', dpi=150)
plt.show()

## 8. Export Results

In [None]:
# Save model checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'config': {
        'num_frequencies': model.fourier.num_frequencies,
        'perturbation_scale': model.perturbation_scale,
    },
    'training_history': history,
    'best_torsion': best_torsion,
}, 'gift_pinn_trained.pt')

print("Model saved to: gift_pinn_trained.pt")

In [None]:
# Export analytical form
analytical_data = export_analytical_form(
    model,
    'phi_analytical_coefficients.json',
    grid_resolution=32,
)

print("Analytical coefficients exported to: phi_analytical_coefficients.json")
print(f"  Dominant modes extracted: {len(analytical_data.get('dominant_modes', []))}")

## 9. Summary

Print the final summary and check the success criteria.

In [None]:
# Final report: banner, pass/fail against the targets, and the output files.
rule = "=" * 60
print(rule)
print("  GIFT-Native PINN Training Summary")
print(rule)
print()

# Success criteria are evaluated on the fresh evaluation sample.
# NOTE(review): the det check uses |mean(det) - 65/32|, while early stopping
# compared the loss's 'det' component (labelled as a squared error) against the
# same threshold — the two criteria are not equivalent; confirm which is intended.
torsion_mean = torsion_eval.mean().item()
det_abs_err = abs(det_eval.mean().item() - DET_G_TARGET_FLOAT)
torsion_ok = torsion_mean < config.target_torsion
det_ok = det_abs_err < config.target_det_error

torsion_box = 'X' if torsion_ok else ' '
det_box = 'X' if det_ok else ' '
print("Success Criteria:")
print(f"  [{torsion_box}] Torsion < {config.target_torsion}: {torsion_mean:.6f}")
print(f"  [{det_box}] |det(g) - 65/32| < {config.target_det_error}: {det_abs_err:.2e}")
print()

verdict = ("SUCCESS: All criteria met!" if (torsion_ok and det_ok)
           else "Training may need more epochs or tuning.")
print(verdict)

print()
print("Output files:")
for entry in (
    "gift_pinn_trained.pt: Model checkpoint",
    "phi_analytical_coefficients.json: Extracted coefficients",
    "training_curves.png: Training visualization",
    "evaluation_histograms.png: Evaluation plots",
):
    print(f"  - {entry}")