# GIFT-Native PINN Training Notebook

Train a Physics-Informed Neural Network (PINN) with built-in GIFT algebraic structure.

**Runnable on Google Colab: the free-tier T4 GPU is enough (A100 requires a paid tier); without a GPU it falls back to CPU, just slower.**

## 1. Setup & Dependencies

In [None]:
# Install dependencies. This is an IPython shell escape ("!"): it only runs
# inside Jupyter/Colab, not as a plain Python script. -q keeps pip quiet.
!pip install -q torch numpy matplotlib tqdm

In [None]:
import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import List, Dict, Optional
from fractions import Fraction

# Select the compute device: CUDA when a GPU is visible, otherwise CPU.
# `device` is a plain string that later cells hand to `.to(...)` and to
# tensor factories such as torch.rand(..., device=device).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using device: " + device)
if device == 'cuda':
    # Report which GPU (index 0) the session was assigned.
    print("GPU: " + torch.cuda.get_device_name(0))

## 2. GIFT Constants (Hard-coded, proven in Lean)

In [None]:
# =============================================================================
# GIFT CONSTANTS
# =============================================================================

B2 = 21           # Second Betti number = C(7,2)
B3 = 77           # Third Betti number
DIM_G2 = 14       # dim(G2)
H_STAR = B2 + B3 + 1  # = 99

# det(g) target kept as an exact rational; the float copy is the value the
# torch losses compare against.
DET_G_TARGET = Fraction(65, 32)  # = 2.03125
DET_G_TARGET_FLOAT = float(DET_G_TARGET)

TORSION_THRESHOLD = 0.0288  # Joyce threshold
PINN_TARGET_TORSION = 0.001  # Our target: 0.0288 / 0.001 = 28.8x below the Joyce threshold

# Fano plane lines: line n is the cyclic triple (n, n+1, n+3) mod 7.
FANO_LINES = [
    (0, 1, 3), (1, 2, 4), (2, 3, 5), (3, 4, 6),
    (4, 5, 0), (5, 6, 1), (6, 0, 2),
]

print("GIFT Constants (Proven in Lean)")
print("=" * 40)
print(f"b2 (Second Betti number): {B2}")
print(f"b3 (Third Betti number): {B3}")
print(f"H* = b2 + b3 + 1 = {H_STAR}")
print(f"dim(G2) = {DIM_G2}")
print(f"det(g) target = 65/32 = {DET_G_TARGET_FLOAT:.6f}")
print(f"Joyce torsion threshold = {TORSION_THRESHOLD}")
# Report the margin computed from the constants rather than a hard-coded claim.
print(f"PINN torsion target = {PINN_TARGET_TORSION} "
      f"({TORSION_THRESHOLD / PINN_TARGET_TORSION:.1f}x margin)")
print()
print("Fano plane lines:")
for i, line in enumerate(FANO_LINES):
    print(f"  Line {i}: {line}")

## 3. Build the Epsilon Tensor and Base 3-Form phi0 from the Fano Plane

In [None]:
def build_epsilon_tensor():
    """Return the Fano structure constants epsilon_ijk as a (7, 7, 7) float32 array.

    For every line (a, b, c) in ``FANO_LINES`` the three cyclic orderings
    are set to +1 and the three orderings with one transposition to -1.
    All other entries stay 0.
    """
    eps = np.zeros((7, 7, 7), dtype=np.float32)
    for a, b, c in FANO_LINES:
        even_orders = ((a, b, c), (b, c, a), (c, a, b))
        for p, q, r in even_orders:
            eps[p, q, r] = 1
        # Swapping the first two slots of an even ordering gives an odd one.
        for p, q, r in even_orders:
            eps[q, p, r] = -1
    return eps

# Build the structure-constant tensor once; later cells read this EPSILON global.
EPSILON = build_epsilon_tensor()
print(f"Epsilon tensor shape: {EPSILON.shape}")
# Each of the 7 lines contributes 6 signed orderings, so 42 entries are non-zero.
print(f"Non-zero entries: {np.count_nonzero(EPSILON)}")

In [None]:
def phi0_standard(normalize=True, epsilon=None):
    """Return the standard G2 3-form as its 35 independent components.

    Components are ``epsilon[i, j, k]`` for i < j < k in lexicographic order.
    This is the same order ``GIFTNativePINN`` uses to unpack them.

    Args:
        normalize: if True, scale the form so that the quadratic metric
            g_ij = (1/6) * sum_{k,l} phi_ikl phi_jkl has det(g) = 65/32.
        epsilon: optional (7, 7, 7) structure-constant array. Defaults to
            the module-level ``EPSILON``.

    Returns:
        float32 numpy array of shape (35,).
    """
    from itertools import combinations

    if epsilon is None:
        epsilon = EPSILON
    # combinations(range(7), 3) yields (i, j, k), i < j < k, in the same order
    # as three nested ascending loops.
    phi0 = np.array(
        [epsilon[i, j, k] for i, j, k in combinations(range(7), 3)],
        dtype=np.float32,
    )
    if normalize:
        # The metric is quadratic in phi. Scaling phi by c scales g by c**2,
        # so det(g) of a 7x7 metric scales by c**14. The unscaled standard form
        # gives g = I, so c = (65/32)**(1/14) yields det(g) = 65/32 exactly.
        scale = (65.0 / 32.0) ** (1.0 / 14.0)
        phi0 = phi0 * scale
    return phi0

# Evaluate the normalised phi0 once for inspection. This binds the
# module-level name `phi0`; the function phi0_standard itself is unaffected.
phi0 = phi0_standard()
print(f"phi0 has {len(phi0)} components (C(7,3) = 35)")
# Components with |value| > 1e-10: one sorted triple per Fano line for the standard form.
print(f"Non-zero: {np.sum(np.abs(phi0) > 1e-10)}")
print(f"Norm: {np.linalg.norm(phi0):.6f}")

## 4. Neural Network Components

In [None]:
class FourierFeatures(nn.Module):
    """Encode coordinates with fixed random sinusoidal features.

    A Gaussian frequency matrix ``B`` of shape (num_frequencies, input_dim)
    is drawn once at construction and multiplied by ``scale``. It is
    registered as a buffer, so it is saved in the state dict but is not a
    trainable parameter.

    ``forward`` maps inputs of shape (..., input_dim) to features of shape
    (..., 2 * num_frequencies): the cosines of 2*pi * x @ B^T followed by
    the sines.
    """

    def __init__(self, input_dim=7, num_frequencies=32, scale=1.0):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.output_dim = num_frequencies * 2
        freqs = scale * torch.randn(num_frequencies, input_dim)
        self.register_buffer('B', freqs)

    def forward(self, x):
        angles = (2 * np.pi) * (x @ self.B.T)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)

In [None]:
def g2_generators():
    """Return a (14, 7, 7) float32 stack of antisymmetric matrices used as "G2" directions.

    Matrices 0-6: for the idx-th entry (i, j, k) of ``FANO_LINES``, the
    rotation in the (i, j) plane (+1 at [i, j], -1 at [j, i]).

    Matrices 7-13: for idx = 0..6 with i = idx, j = idx + 1, k = idx + 3
    (all mod 7), +1/-1 at [i, k]/[k, i] plus +0.5/-0.5 at [j, k]/[k, j].

    Every matrix whose Frobenius norm exceeds 1e-10 is rescaled to unit norm.

    NOTE(review): these matrices are not checked against the g2 condition
    sum_{j,k} phi_ijk X_jk = 0. A single-plane rotation such as matrices
    0-6 does not satisfy it for the standard phi0. Treat this as a heuristic
    14-dimensional set of so(7) directions rather than a basis of g2 —
    confirm intent.
    """
    gens = np.zeros((14, 7, 7), dtype=np.float32)

    # Matrices 0-6: one plane rotation per Fano line, in the plane of its first two points.
    for row, (a, b, _) in enumerate(FANO_LINES):
        gens[row, a, b] = 1
        gens[row, b, a] = -1

    # Matrices 7-13: weighted pair of rotations sharing the index (offset + 3) mod 7.
    for offset in range(7):
        row = offset + 7
        a, b, c = offset, (offset + 1) % 7, (offset + 3) % 7
        gens[row, a, c], gens[row, c, a] = 1, -1
        gens[row, b, c], gens[row, c, b] = 0.5, -0.5

    # Rescale each non-negligible matrix (a view into gens) to unit Frobenius norm.
    for mat in gens:
        size = np.linalg.norm(mat)
        if size > 1e-10:
            mat /= size
    return gens

## 5. GIFT-Native PINN Model

In [None]:
class GIFTNativePINN(nn.Module):
    """
    PINN with GIFT structure built-in.

    - Hard-coded Fano epsilon_ijk (buffer ``epsilon``) and base form phi0
      (buffer ``phi0``, the 35 components returned by ``phi0_standard()``).
    - The MLP outputs 14 "adjoint" coefficients per point. These are mapped
      to a 35-component perturbation through the precomputed (14, 35)
      ``lie_derivatives`` matrix.
    - phi = phi0 + perturbation_scale * delta_phi

    Args:
        num_frequencies: number of random Fourier frequencies; the feature
            width is ``2 * num_frequencies``.
        hidden_dims: widths of the SiLU hidden layers (default [128, 128, 128]).
        perturbation_scale: multiplier applied to delta_phi.
    """
    def __init__(self, num_frequencies=32, hidden_dims=None, perturbation_scale=0.01):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [128, 128, 128]

        self.perturbation_scale = perturbation_scale

        # Register buffers (they follow .to(device) and are saved in state_dict)
        self.register_buffer('epsilon', torch.from_numpy(EPSILON))
        self.register_buffer('phi0', torch.from_numpy(phi0_standard()))

        # Gather index / sign tables that expand the 35 independent components
        # into the full antisymmetric (7, 7, 7) tensor in one vectorised op.
        # They are non-persistent because they are derived data. Keeping them
        # out of state_dict leaves checkpoint keys unchanged.
        phi_index, phi_sign = self._build_phi_scatter()
        self.register_buffer('_phi_index', phi_index, persistent=False)
        self.register_buffer('_phi_sign', phi_sign, persistent=False)

        # Precompute Lie derivatives
        self._precompute_lie_derivatives()

        # Network: Fourier features -> SiLU MLP -> 14 adjoint coefficients
        self.fourier = FourierFeatures(input_dim=7, num_frequencies=num_frequencies)

        layers = []
        in_dim = self.fourier.output_dim
        for h_dim in hidden_dims:
            layers.extend([nn.Linear(in_dim, h_dim), nn.SiLU()])
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, 14))  # 14 G2 adjoint params
        self.mlp = nn.Sequential(*layers)

        self._init_weights()

    @staticmethod
    def _build_phi_scatter():
        """Return (index, sign), each flattened over the 7*7*7 entries of phi.

        For entry (p, q, r), ``index`` is the position of the sorted triple
        {p, q, r} in the 35-component vector (lexicographic i < j < k order).
        ``sign`` is +1 for even orderings and -1 for odd ones. Entries with a
        repeated index get sign 0 and index 0, so the index is ignored.
        """
        index = torch.zeros(7, 7, 7, dtype=torch.long)
        sign = torch.zeros(7, 7, 7)
        idx = 0
        for i in range(7):
            for j in range(i + 1, 7):
                for k in range(j + 1, 7):
                    for p, q, r in ((i, j, k), (j, k, i), (k, i, j)):
                        index[p, q, r] = idx
                        sign[p, q, r] = 1.0
                    for p, q, r in ((j, i, k), (i, k, j), (k, j, i)):
                        index[p, q, r] = idx
                        sign[p, q, r] = -1.0
                    idx += 1
        return index.reshape(-1), sign.reshape(-1)

    def _precompute_lie_derivatives(self):
        """Precompute the (14, 35) matrix mapping adjoint coefficients to delta_phi.

        Row a holds, for each sorted triple i < j < k,
        sum_l X_il eps_ljk + X_jl eps_ilk + X_kl eps_ijl, where
        X = g2_generators()[a] and eps is the unscaled Fano tensor.
        The result is stored as buffer ``lie_derivatives``.
        """
        generators = g2_generators()
        lie_derivs = np.zeros((14, 35), dtype=np.float32)
        for a in range(14):
            X = generators[a]
            idx = 0
            for i in range(7):
                for j in range(i+1, 7):
                    for k in range(j+1, 7):
                        val = sum(X[i,l]*EPSILON[l,j,k] + X[j,l]*EPSILON[i,l,k] + X[k,l]*EPSILON[i,j,l] for l in range(7))
                        lie_derivs[a, idx] = val
                        idx += 1
        self.register_buffer('lie_derivatives', torch.from_numpy(lie_derivs))

    def _init_weights(self):
        """Xavier-uniform weights (gain 0.1) and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.1)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return the 35 components of phi at points x: (N, 7) -> (N, 35)."""
        h = self.fourier(x)
        adjoint = self.mlp(h)  # (N, 14)
        delta_phi = torch.matmul(adjoint, self.lie_derivatives)  # (N, 35)
        return self.phi0.unsqueeze(0) + self.perturbation_scale * delta_phi

    def phi_tensor(self, x):
        """Get the full antisymmetric (N, 7, 7, 7) tensor phi_ijk."""
        components = self.forward(x)
        n = components.shape[0]
        # Gather each entry's component and apply its permutation sign. This is
        # exact: every entry becomes +val, -val or 0.
        phi = components.index_select(1, self._phi_index) * self._phi_sign
        return phi.reshape(n, 7, 7, 7).to(x.dtype)

    def metric(self, x):
        """Return g_ij = (1/6) * sum_{k,l} phi_ikl phi_jkl, shape (N, 7, 7).

        Both factors are contracted over the same two trailing slots, so g is
        symmetric. For the unscaled standard form (entries +/-1 on the Fano
        lines) g is the identity, and scaling phi by c scales g by c**2.
        """
        phi = self.phi_tensor(x)
        return torch.einsum('nikl,njkl->nij', phi, phi) / 6.0

    def det_g(self, x):
        """Return det(g) at each point, shape (N,)."""
        return torch.linalg.det(self.metric(x))

    def get_adjoint_params(self, x):
        """Return the raw 14 MLP outputs per point, shape (N, 14)."""
        return self.mlp(self.fourier(x))

## 6. Loss Function

In [None]:
class GIFTLoss(nn.Module):
    """Weighted sum of determinant, torsion-proxy, sparsity and positivity terms.

    Args:
        det_weight: weight of mean((det g - DET_G_TARGET_FLOAT)^2).
        torsion_weight: weight of the gradient-norm torsion proxy.
        sparse_weight: weight of mean(adjoint^2).
        pd_weight: weight of mean(relu(-eigenvalue)^2) over the metric spectrum.
        torsion_components: number of leading 3-form components that enter
            the torsion proxy, clamped to the model's output width. The
            default of 10 reproduces the original behaviour.

    Raises:
        ValueError: if ``torsion_components`` < 1.
    """
    def __init__(self, det_weight=100.0, torsion_weight=1.0, sparse_weight=0.1, pd_weight=10.0,
                 torsion_components=10):
        super().__init__()
        if torsion_components < 1:
            raise ValueError(f"torsion_components must be >= 1, got {torsion_components}")
        self.det_weight = det_weight
        self.torsion_weight = torsion_weight
        self.sparse_weight = sparse_weight
        self.pd_weight = pd_weight
        self.torsion_components = torsion_components

    def forward(self, model, x, return_components=False):
        """Return the total loss, or (total, dict of per-term losses) if return_components."""
        losses = {}

        # Build the metric once and reuse it for the det and PD terms.
        # For GIFTNativePINN, det(g) here is exactly model.det_g(x).
        g = model.metric(x)

        # 1. Determinant loss
        det_g = torch.linalg.det(g)
        losses['det'] = torch.mean((det_g - DET_G_TARGET_FLOAT) ** 2)

        # 2. Torsion loss (gradient-based proxy): mean squared spatial gradient
        # of the leading phi components. This is a fixed subset for speed,
        # not a random sample. The sum is divided by the count actually used.
        x_grad = x.clone().requires_grad_(True)
        phi = model(x_grad)
        n_comp = min(self.torsion_components, phi.shape[1])
        torsion = 0.0
        for i in range(n_comp):
            grad = torch.autograd.grad(phi[:, i].sum(), x_grad, create_graph=True, retain_graph=True)[0]
            torsion = torsion + (grad ** 2).sum(dim=-1).mean()
        losses['torsion'] = torsion / n_comp

        # 3. Sparsity of the adjoint coefficients
        adjoint = model.get_adjoint_params(x)
        losses['sparse'] = torch.mean(adjoint ** 2)

        # 4. Positive definiteness: penalise negative eigenvalues of the metric
        # (eigvalsh reads only the lower triangle, so g must be symmetric).
        eigvals = torch.linalg.eigvalsh(g)
        losses['pd'] = torch.mean(torch.relu(-eigvals) ** 2)

        total = (self.det_weight * losses['det'] +
                 self.torsion_weight * losses['torsion'] +
                 self.sparse_weight * losses['sparse'] +
                 self.pd_weight * losses['pd'])

        return (total, losses) if return_components else total

## 7. Create Model

In [None]:
# Instantiate the PINN with the notebook's default architecture and move it
# to the device chosen in the setup cell.
model = GIFTNativePINN(num_frequencies=32, hidden_dims=[128, 128, 128], perturbation_scale=0.01)
model = model.to(device)

# Count all parameters (registered buffers are not parameters and are excluded).
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")
print("Architecture: 7D -> Fourier(64) -> MLP -> 14 (G2 adjoint) -> 35 (3-form)")

In [None]:
# Test forward pass: smoke-test the (untrained) model on 100 uniform points
# in [0, 1)^7 without building an autograd graph.
x_test = torch.rand(100, 7, device=device)
with torch.no_grad():
    phi_test = model(x_test)  # (100, 35) 3-form components
    det_test = model.det_g(x_test)  # (100,) determinant of the induced metric

print(f"Input: {x_test.shape}")
print(f"Output (phi): {phi_test.shape}")
print(f"det(g) mean: {det_test.mean().item():.6f} (target: {DET_G_TARGET_FLOAT:.6f})")

## 8. Training

In [None]:
# Config
EPOCHS = 3000
BATCH_SIZE = 512
LR = 1e-3
# Reuse the constant from the GIFT constants cell so the two cannot drift apart.
TARGET_TORSION = PINN_TARGET_TORSION
# NOTE(review): the training loop compares this to the batch MSE of det(g)
# (components['det']), while the summary compares it to |mean det(g) - 65/32|.
# These are different criteria — confirm which one is intended.
TARGET_DET_ERROR = 1e-5

loss_fn = GIFTLoss(det_weight=100.0, torsion_weight=1.0, sparse_weight=0.1, pd_weight=10.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
# Halve the LR once the monitored loss has not improved for more than 200 scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=200)

# Per-epoch training curves and the lowest-torsion weight snapshot.
history = {'loss': [], 'torsion': [], 'det_error': [], 'lr': []}
best_torsion = float('inf')
best_state = None

In [None]:
# Training loop: a fresh uniform batch in [0, 1)^7 each epoch, one AdamW step
# with gradient-norm clipping at 1.0, and the plateau scheduler driven by the
# total loss. Stops early once both the torsion and det criteria are met.
model.train()
pbar = tqdm(range(EPOCHS), desc="Training")

for epoch in pbar:
    x = torch.rand(BATCH_SIZE, 7, device=device)

    optimizer.zero_grad()
    loss, components = loss_fn(model, x, return_components=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step(loss)

    # Pull the scalars off the device once per epoch instead of once per use.
    loss_val = loss.item()
    torsion_val = components['torsion'].item()
    det_val = components['det'].item()

    history['loss'].append(loss_val)
    history['torsion'].append(torsion_val)
    history['det_error'].append(det_val)
    history['lr'].append(optimizer.param_groups[0]['lr'])

    # Keep a CPU copy of the weights with the lowest batch torsion seen so far.
    if torsion_val < best_torsion:
        best_torsion = torsion_val
        best_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}

    pbar.set_postfix({
        'loss': f"{loss_val:.4f}",
        'torsion': f"{torsion_val:.6f}",
        'det_err': f"{det_val:.6f}",
    })

    if torsion_val < TARGET_TORSION and det_val < TARGET_DET_ERROR:
        print(f"\nConverged at epoch {epoch}!")
        break

print(f"\nBest torsion: {best_torsion:.6f}")

## 9. Visualization

In [None]:
# Training curves, all with log-scale y axes: total loss, the torsion proxy
# (with target and Joyce threshold lines), the det(g) error and the learning rate.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

axes[0,0].semilogy(history['loss'])
axes[0,0].set_xlabel('Epoch'); axes[0,0].set_ylabel('Loss'); axes[0,0].set_title('Total Loss')
axes[0,0].grid(True, alpha=0.3)

axes[0,1].semilogy(history['torsion'], label='Torsion')
axes[0,1].axhline(TARGET_TORSION, color='g', linestyle='--', label=f'Target {TARGET_TORSION}')
axes[0,1].axhline(TORSION_THRESHOLD, color='r', linestyle='--', label=f'Joyce {TORSION_THRESHOLD}')
axes[0,1].set_xlabel('Epoch'); axes[0,1].set_ylabel('Torsion'); axes[0,1].set_title('Torsion')
axes[0,1].legend()
axes[0,1].grid(True, alpha=0.3)

# history['det_error'] stores the batch mean of (det(g) - 65/32)^2. The
# label is plain ASCII so it survives re-encoding of this file.
axes[1,0].semilogy(history['det_error'])
axes[1,0].set_xlabel('Epoch'); axes[1,0].set_ylabel('mean (det(g) - 65/32)^2'); axes[1,0].set_title('Det Error')
axes[1,0].grid(True, alpha=0.3)

axes[1,1].semilogy(history['lr'])
axes[1,1].set_xlabel('Epoch'); axes[1,1].set_ylabel('LR'); axes[1,1].set_title('Learning Rate')
axes[1,1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

## 10. Evaluation

In [None]:
# Restore the lowest-torsion snapshot (if training recorded one) and switch to eval mode.
if best_state:
    model.load_state_dict(best_state)
model.eval()

# Evaluate det(g) and the adjoint coefficients on 10,000 fresh uniform points,
# without tracking gradients.
with torch.no_grad():
    x_eval = torch.rand(10000, 7, device=device)
    det_eval = model.det_g(x_eval)
    adjoint_eval = model.get_adjoint_params(x_eval)

det_mean = det_eval.mean().item()
rule = "=" * 50
print(rule)
print("EVALUATION RESULTS")
print(rule)
print(f"det(g) mean: {det_mean:.8f}")
print(f"det(g) target: {DET_G_TARGET_FLOAT:.8f}")
print(f"det(g) error: {abs(det_mean - DET_G_TARGET_FLOAT):.2e}")
print(f"Best torsion: {best_torsion:.6f}")
print(f"Joyce threshold: {TORSION_THRESHOLD}")
if best_torsion > 0:
    print(f"Margin: {TORSION_THRESHOLD / best_torsion:.1f}x")
else:
    print("N/A")

In [None]:
# Histograms of the evaluation-set results from the evaluation cell.
# Left: det(g) over the 10,000 evaluation points against the 65/32 target.
# Right: all adjoint coefficients of every point pooled into one distribution.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].hist(det_eval.cpu().numpy(), bins=50, density=True, alpha=0.7)
axes[0].axvline(DET_G_TARGET_FLOAT, color='r', linestyle='--', label='Target 65/32')
axes[0].set_xlabel('det(g)'); axes[0].set_title('det(g) Distribution'); axes[0].legend()

axes[1].hist(adjoint_eval.cpu().numpy().flatten(), bins=50, density=True, alpha=0.7)
axes[1].set_xlabel('Adjoint params'); axes[1].set_title('G2 Adjoint Distribution')

plt.tight_layout()
plt.savefig('evaluation.png', dpi=150)
plt.show()

## 11. Save Model

In [None]:
# Persist the current weights, the best snapshot, the training history and
# the constructor arguments needed to rebuild GIFTNativePINN. The config is
# read back from the live model, so it always matches the architecture that
# was actually trained.
torch.save({
    'model_state_dict': model.state_dict(),
    'best_state_dict': best_state,
    'history': history,
    'best_torsion': best_torsion,
    'config': {
        'num_frequencies': model.fourier.num_frequencies,
        # Every Linear except the final 14-way output layer is a hidden layer.
        'hidden_dims': [m.out_features for m in model.mlp if isinstance(m, nn.Linear)][:-1],
        'perturbation_scale': model.perturbation_scale,
    },
}, 'gift_pinn_trained.pt')

print("Model saved to: gift_pinn_trained.pt")

## 12. Summary

In [None]:
# Pass/fail summary. The torsion check uses the best batch torsion seen during
# training; the det check uses |mean det(g) - 65/32| on the evaluation set.
det_abs_error = abs(det_eval.mean().item() - DET_G_TARGET_FLOAT)
torsion_ok = best_torsion < TARGET_TORSION
det_ok = det_abs_error < TARGET_DET_ERROR

torsion_mark = 'X' if torsion_ok else ' '
det_mark = 'X' if det_ok else ' '

print("=" * 60)
print("  GIFT-Native PINN Training Summary")
print("=" * 60)
print(f"  [{torsion_mark}] Torsion < {TARGET_TORSION}: {best_torsion:.6f}")
print(f"  [{det_mark}] |det(g) - 65/32| < {TARGET_DET_ERROR}: {det_abs_error:.2e}")
print()
print("SUCCESS: All criteria met!" if (torsion_ok and det_ok) else "Training may need more epochs.")