# GIFT v2.2 - Variational G2 Metric Learning

**Simplified approach**: Learn the metric g directly (not via phi extraction).

## GIFT v2.2 Constraints

| Constraint | Value | Origin |
|------------|-------|--------|
| det(g) | 65/32 = 2.03125 | From h* = 99 |
| kappa_T | 1/61 | Torsion magnitude |
| g > 0 | Positive definite | Riemannian |
| Smooth | Low gradient energy | Variational |

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import math

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

In [None]:
# GIFT v2.2 target values used by the losses and the final report
DIM = 7                   # number of coordinates / size of the metric
DET_TARGET = 65 / 32      # target det(g) = 2.03125
KAPPA_TARGET = 1 / 61     # target torsion magnitude ~ 0.01639

print("GIFT v2.2 Targets:")
print(
    f"  det(g) = {DET_TARGET} = 65/32",
    f"  kappa_T = {KAPPA_TARGET:.6f} = 1/61",
    sep="\n",
)

In [None]:
class MetricNetwork(nn.Module):
    """Learn a field of symmetric positive-definite metrics g(x).

    Uses the Cholesky parametrization g = L @ L^T, where L is lower
    triangular and its diagonal is ``softplus(raw) + DIAG_FLOOR`` (hence
    strictly positive). This makes g symmetric positive definite by
    construction.

    The output head starts with zero weights and a bias chosen so that
    L = scale * I with scale = det_target ** (1 / (2 * dim)); the freshly
    constructed network therefore returns a constant metric whose
    determinant equals ``det_target``.

    Args:
        hidden: Width of every hidden layer of the MLP.
        n_layers: Number of (Linear, SiLU) pairs in the MLP.
        dim: Number of coordinates and size of the metric. Defaults to the
            module-level ``DIM``.
        det_target: Determinant the network should produce at
            initialisation. Defaults to the module-level ``DET_TARGET``.

    Raises:
        ValueError: If ``det_target`` is so small that the required
            diagonal of L is not above ``DIAG_FLOOR``.
    """

    # Constant added after softplus; keeps the diagonal of L away from zero.
    DIAG_FLOOR = 0.1

    def __init__(self, hidden=256, n_layers=4, dim=None, det_target=None):
        super().__init__()

        # Fall back to the notebook-level constants unless overridden.
        self.dim = DIM if dim is None else dim
        det_target = DET_TARGET if det_target is None else det_target
        n_tril = self.dim * (self.dim + 1) // 2  # entries of lower-triangular L

        # Fourier features: fixed (non-trainable) random projection
        self.B = nn.Parameter(torch.randn(self.dim, 64) * 2.0, requires_grad=False)

        # MLP
        layers = []
        in_dim = self.dim + 128  # coords + fourier (64 sin + 64 cos)
        for _ in range(n_layers):
            layers.extend([nn.Linear(in_dim, hidden), nn.SiLU()])
            in_dim = hidden
        self.backbone = nn.Sequential(*layers)

        # Output: one value per entry of the lower triangle of L.
        # in_dim equals hidden whenever n_layers >= 1.
        self.L_head = nn.Linear(in_dim, n_tril)

        # torch.tril_indices enumerates the lower triangle row by row, so the
        # diagonal entries are NOT the first `dim` outputs; locate them
        # explicitly.
        tril_indices = torch.tril_indices(self.dim, self.dim)
        on_diagonal = tril_indices[0] == tril_indices[1]

        # Initialise at L = scale * I, i.e. det(g) = scale^(2*dim) = det_target.
        scale = det_target ** (1 / (2 * self.dim))
        if scale <= self.DIAG_FLOOR:
            raise ValueError(
                f"det_target={det_target} is too small for a diagonal floor "
                f"of {self.DIAG_FLOOR}"
            )
        nn.init.zeros_(self.L_head.weight)
        bias = torch.zeros(n_tril)
        # Invert softplus(b) + DIAG_FLOOR = scale  =>  b = log(exp(scale - floor) - 1)
        bias[on_diagonal] = math.log(math.expm1(scale - self.DIAG_FLOOR))
        self.L_head.bias = nn.Parameter(bias)

        # Indices for building the L matrix
        self.register_buffer('tril_indices', tril_indices)

    def forward(self, x):
        """Evaluate the metric at the points ``x``.

        Args:
            x: Tensor whose last axis has length ``dim``.

        Returns:
            Tuple ``(g, L)``: the metric ``g = L @ L^T`` and its Cholesky
            factor, each with the leading axes of ``x`` followed by
            ``(dim, dim)``.
        """
        # Fourier features
        proj = 2 * math.pi * x @ self.B
        fourier = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

        # MLP
        h = self.backbone(torch.cat([x, fourier], dim=-1))
        L_flat = self.L_head(h)

        # Scatter the flat outputs into the lower triangle. new_zeros keeps
        # the dtype and device of the head output.
        rows, cols = self.tril_indices
        raw = L_flat.new_zeros((*L_flat.shape[:-1], self.dim, self.dim))
        raw[..., rows, cols] = L_flat

        # Positive diagonal via softplus; recombine with the strictly lower
        # part instead of overwriting the diagonal in place.
        diag = torch.nn.functional.softplus(
            torch.diagonal(raw, dim1=-2, dim2=-1)
        ) + self.DIAG_FLOOR
        L = torch.tril(raw, diagonal=-1) + torch.diag_embed(diag)

        # g = L @ L^T (SPD by construction)
        g = L @ L.transpose(-1, -2)

        return g, L

# Smoke test: build a model, evaluate it on random points and report
# parameter count, output shape, mean determinant and smallest eigenvalue.
model = MetricNetwork().to(device)
x_test = torch.randn(100, 7, device=device)
g_test, L_test = model(x_test)

print("Model parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
print("g shape: {}".format(g_test.shape))
print("det(g) mean: {:.4f}".format(torch.det(g_test).mean().item()))
print("min eigenvalue: {:.4f}".format(torch.linalg.eigvalsh(g_test).min().item()))

In [None]:
def compute_losses(model, x, eps=1e-3, det_target=None, kappa_target=None):
    """Compute all loss components for a batch of points.

    Args:
        model: Callable mapping points of shape ``(batch, dim)`` to a tuple
            ``(g, L)`` with ``g = L @ L^T`` and L lower triangular with a
            positive diagonal.
        x: Sample points, shape ``(batch, dim)``.
        eps: Step of the forward finite difference used for the gradient.
        det_target: Target determinant; defaults to the module-level
            ``DET_TARGET``.
        kappa_target: Target for the torsion proxy; defaults to the
            module-level ``KAPPA_TARGET``.

    Returns:
        Dict with 0-d tensors ``'det'``, ``'smooth'`` and ``'torsion'``
        (differentiable loss terms) and Python floats ``'det_val'``,
        ``'torsion_val'`` and ``'min_eig'`` (diagnostics).
    """
    det_target = DET_TARGET if det_target is None else det_target
    kappa_target = KAPPA_TARGET if kappa_target is None else kappa_target

    g, L = model(x)
    n_points, dim = x.shape

    # 1. Determinant constraint: det(g) = det_target
    # det(g) = det(L)^2 and det(L) = prod(diagonal); work in log space.
    diag = torch.diagonal(L, dim1=-2, dim2=-1)
    log_det_g = 2 * torch.log(diag).sum(dim=-1)
    det_g = torch.exp(log_det_g)

    det_loss = ((det_g - det_target) ** 2).mean()

    # 2. Smoothness: mean squared forward-difference gradient of g.
    # All `dim` shifted copies of the batch go through the model in a single
    # pass; row block i holds x shifted by eps along coordinate i.
    offsets = eps * torch.eye(dim, dtype=x.dtype, device=x.device)
    x_shifted = (x.unsqueeze(0) + offsets.unsqueeze(1)).reshape(dim * n_points, dim)
    g_shifted, _ = model(x_shifted)
    g_shifted = g_shifted.reshape(dim, *g.shape)
    # Divide by eps so the result approximates dg/dx_i and does not scale
    # with the (arbitrary) step size.
    grad_loss = (((g_shifted - g) / eps) ** 2).mean()

    # 3. Torsion proxy: RMS gradient of the metric ~ kappa_T
    # Simplified: we want some non-zero but small variation.
    # The 1e-10 keeps the sqrt differentiable when the gradient vanishes.
    torsion = torch.sqrt(grad_loss + 1e-10)
    torsion_loss = (torsion - kappa_target) ** 2

    # Diagnostic only: keep the eigendecomposition out of the autograd graph.
    with torch.no_grad():
        min_eig = torch.linalg.eigvalsh(g).min().item()

    return {
        'det': det_loss,
        'det_val': det_g.mean().item(),
        'smooth': grad_loss,
        'torsion': torsion_loss,
        'torsion_val': torsion.item(),
        'min_eig': min_eig,
    }

# Evaluate the loss terms once on the smoke-test batch and list them.
losses = compute_losses(model, x_test)
print("Initial losses:")
for k, v in losses.items():
    # Tensor entries are unwrapped to Python floats; the rest already are.
    shown = v.item() if isinstance(v, torch.Tensor) else v
    print(f"  {k}: {shown:.6f}")

In [None]:
def train(model, n_epochs=3000, batch_size=1024, lr=1e-3):
    """Optimise ``model`` with Adam under a three-phase loss-weight schedule.

    Each epoch draws ``batch_size`` fresh points uniformly from
    [-1, 1]^DIM, evaluates ``compute_losses`` and takes one clipped gradient
    step; the learning rate follows a cosine schedule over ``n_epochs``.
    Progress is printed every 500 epochs.

    Args:
        model: Network passed to ``compute_losses``.
        n_epochs: Number of optimisation steps.
        batch_size: Points sampled per step.
        lr: Initial Adam learning rate.

    Returns:
        Dict with per-epoch lists ``'loss'`` (weighted total), ``'det'``
        (mean determinant) and ``'torsion'`` (torsion proxy value).
    """
    # (w_det, w_smooth, w_torsion) for each third of the run
    schedule = (
        (10.0, 0.1, 0.1),  # phase 1: focus on the det constraint
        (5.0, 1.0, 1.0),   # phase 2: balance det and smoothness
        (3.0, 0.5, 3.0),   # phase 3: fine-tune torsion
    )
    end_phase_1 = n_epochs // 3
    end_phase_2 = 2 * n_epochs // 3

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)

    loss_trace, det_trace, torsion_trace = [], [], []

    for epoch in range(n_epochs):
        # Random coordinates in [-1, 1]^7
        x = 2 * torch.rand(batch_size, DIM, device=device) - 1

        losses = compute_losses(model, x)

        if epoch < end_phase_1:
            phase = 0
        elif epoch < end_phase_2:
            phase = 1
        else:
            phase = 2
        w_det, w_smooth, w_torsion = schedule[phase]

        total_loss = (w_det * losses['det'] +
                      w_smooth * losses['smooth'] +
                      w_torsion * losses['torsion'])

        optimizer.zero_grad()
        total_loss.backward()
        # Cap the global gradient norm at 1.0 before the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        loss_value = total_loss.item()
        det_value = losses['det_val']
        torsion_value = losses['torsion_val']
        loss_trace.append(loss_value)
        det_trace.append(det_value)
        torsion_trace.append(torsion_value)

        if epoch % 500 == 0:
            det_err = 100 * abs(det_value - DET_TARGET) / DET_TARGET
            print(f"[{epoch:4d}] L={loss_value:.4f} | "
                  f"det={det_value:.4f} ({det_err:.1f}%) | "
                  f"kappa={torsion_value:.4f} | "
                  f"min_eig={losses['min_eig']:.4f}")

    return {'loss': loss_trace, 'det': det_trace, 'torsion': torsion_trace}

print("Training function ready.")

In [None]:
# Train a fresh network from scratch and keep its training history.
print("=" * 60, "GIFT v2.2 Metric Learning", "=" * 60, sep="\n")
print(f"Target: det(g) = {DET_TARGET}")

model = MetricNetwork().to(device)
history = train(model, n_epochs=3000)

In [None]:
# Training curves: total loss (log scale), mean det(g) and the torsion proxy,
# the latter two with their target drawn as a dashed red line.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

axes[0].semilogy(history['loss'])
axes[0].set(xlabel='Epoch', ylabel='Loss', title='Total Loss')

tracked = (
    (axes[1], 'det', DET_TARGET, f'Target: {DET_TARGET}', 'det(g)', 'Determinant'),
    (axes[2], 'torsion', KAPPA_TARGET, 'Target: 1/61', 'Torsion proxy', 'Torsion'),
)
for panel, key, target, target_label, y_label, heading in tracked:
    panel.plot(history[key])
    panel.axhline(target, color='r', linestyle='--', label=target_label)
    panel.set(xlabel='Epoch', ylabel=y_label, title=heading)
    panel.legend()

for panel in axes:
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Validation: sample fresh points, evaluate the trained metric without
# gradients and report determinant and eigenvalue statistics.
model.eval()
n_val = 5000

with torch.no_grad():
    x = 2 * torch.rand(n_val, DIM, device=device) - 1
    g, L = model(x)

    # det(g) = prod(diag(L)^2) because g = L @ L^T with L triangular
    diag = torch.diagonal(L, dim1=-2, dim2=-1)
    det_g = (diag ** 2).prod(dim=-1)

    # eigvalsh returns eigenvalues in ascending order, so column 0 is the
    # smallest one
    eigs = torch.linalg.eigvalsh(g)

det_np = det_g.cpu().numpy()
min_eig = eigs[:, 0].cpu().numpy()

rule = "=" * 60
print(rule)
print("VALIDATION RESULTS")
print(rule)
print()
print("det(g):")
print(f"  Mean:   {det_np.mean():.6f}")
print(f"  Std:    {det_np.std():.6f}")
print(f"  Target: {DET_TARGET:.6f}")
print(f"  Error:  {100*abs(det_np.mean() - DET_TARGET)/DET_TARGET:.2f}%")
print()
print("Positivity (g > 0):")
print(f"  Min eigenvalue: {min_eig.mean():.6f} +/- {min_eig.std():.6f}")
print(f"  All positive:   {100*(min_eig > 0).mean():.1f}%")
print()
print("Eigenvalue spectrum:")
for i in range(7):
    e = eigs[:, i].cpu().numpy()
    print(f"  lambda_{i}: {e.mean():.4f} +/- {e.std():.4f}")

In [None]:
# Histograms of the validation determinants and smallest eigenvalues.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
hist_style = dict(bins=50, density=True, alpha=0.7)
marker_style = dict(linestyle='--', linewidth=2)

# Left: det(g) with the target (red) and the sample mean (green)
axes[0].hist(det_np, color='blue', **hist_style)
axes[0].axvline(DET_TARGET, color='r', label=f'Target: {DET_TARGET}', **marker_style)
axes[0].axvline(det_np.mean(), color='g', label=f'Mean: {det_np.mean():.4f}', **marker_style)
axes[0].set(xlabel='det(g)', ylabel='Density', title='Metric Determinant Distribution')
axes[0].legend()

# Right: smallest eigenvalue with the positivity boundary at zero
axes[1].hist(min_eig, color='green', **hist_style)
axes[1].axvline(0, color='r', **marker_style)
axes[1].set(
    xlabel='Minimum eigenvalue',
    ylabel='Density',
    title='Positivity Check (all should be > 0)',
)

plt.tight_layout()
plt.show()

In [None]:
# Final summary: determinant error, positivity verdict and overall status.
det_err = 100 * abs(det_np.mean() - DET_TARGET) / DET_TARGET
# Positivity counts as satisfied when more than 99% of samples have a
# positive smallest eigenvalue.
pos_ok = (min_eig > 0).mean() > 0.99

if det_err < 1.0 and pos_ok:
    status = 'PASSED'
else:
    status = 'NEEDS TUNING'

rule = "=" * 60
print(rule)
print("GIFT v2.2 METRIC LEARNING - SUMMARY")
print(rule)
print()
print("Constraints:")
print(f"  det(g) = {det_np.mean():.4f} (target: {DET_TARGET}, error: {det_err:.1f}%)")
print(f"  g > 0:  {'YES' if pos_ok else 'NO'} (min eig = {min_eig.min():.4f})")
print()
print("GIFT v2.2 values used:")
print(f"  det(g) = 65/32 = {DET_TARGET}")
print(f"  kappa_T = 1/61 = {KAPPA_TARGET:.6f}")
print()
print(f"Status: {status}")