# GIFT v2.2 Variational G2 Analysis

Post-training analysis of the learned G2 geometry.

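This notebook validates the trained model against the GIFT v2.2 predictions and
visualizes the geometric properties of the learned G2 3-form φ.

Run it after training (`python -m src.training`). It reads `../config/gift_v22.yaml`
and `../outputs/checkpoints/final_model.pt` (falling back to a randomly initialized
model if the checkpoint is missing). It writes the validation report to
`../outputs/metrics/` and the figures and arrays to `../outputs/artifacts/`.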

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import yaml

# Import our modules
from src.model import G2VariationalNet, create_model
from src.constraints import (
    metric_from_phi,
    expand_to_antisymmetric,
    phi_norm_squared,
    standard_g2_phi,
)
from src.validation import Validator, generate_validation_report, GIFT_TARGETS
from src.harmonic import extract_betti_numbers, sample_grid_points, CohomologyAnalyzer

# Setup: choose the compute device and seed the global RNGs.
# Seeding makes any sampling done below through torch's / numpy's global RNGs
# repeatable across Restart & Run All. (Presumably the validator, grid, and
# cohomology sampling use these RNGs — confirm in src/.)
# torch.manual_seed also seeds the CUDA generators.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Load Configuration and Model

In [None]:
# Load the experiment configuration.
# YAML is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's locale default encoding. The default (e.g. cp1252 on Windows)
# would mis-decode any non-ASCII symbols in the file.
config_path = Path('../config/gift_v22.yaml')
with open(config_path, encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Display the GIFT v2.2 targets from the config's `physics` section
print("GIFT v2.2 Targets:")
print("="*50)
physics = config['physics']
print(f"  b2 = {physics['b2']}")
print(f"  b3 = {physics['b3']}")
print(f"  det(g) = {physics['det_g_exact']} = {physics['det_g']}")
print(f"  kappa_T = {physics['kappa_T_exact']} = {physics['kappa_T']:.6f}")
print(f"  h* = {physics['h_star']}")

In [None]:
# Build the network described by the config and move it to the selected device.
model = create_model(config).to(device)

# Load trained weights if the final checkpoint exists. Otherwise keep the random
# initialization so the notebook still runs end to end; every number below then
# describes an untrained model.
checkpoint_path = Path('../outputs/checkpoints/final_model.pt')
if checkpoint_path.exists():
    # NOTE(security): torch.load unpickles the file. Unless weights_only=True
    # is in effect (the default from torch 2.6), that can execute arbitrary
    # code, so only load checkpoints you produced or trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # 'epoch' is informational only; don't fail if a checkpoint omits it.
    print(f"Loaded trained model from epoch {checkpoint.get('epoch', 'unknown')}")
else:
    print("No trained model found - using randomly initialized model")
    print("Run training first with: python -m src.training")

model.eval()
n_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {n_params:,}")

## 2. Validate Against GIFT v2.2 Targets

In [None]:
# Score the model against the GIFT v2.2 targets.
# num_samples controls how many points the Validator evaluates. detailed=True
# presumably requests the per-metric breakdown (see src.validation).
validator = Validator(device=device, num_samples=5000)
results = validator.validate(model, detailed=True)

summary_text = results.summary
print(summary_text)

In [None]:
# Render the validation results as a markdown report and save it under outputs/metrics.
n_model_params = sum(p.numel() for p in model.parameters())
report = generate_validation_report(
    results,
    model_info={
        'Architecture': 'G2VariationalNet',
        'Parameters': n_model_params,
        'Device': str(device),
    },
)

# Write as UTF-8 explicitly. The report may contain non-ASCII math symbols
# (φ, κ, ...), which the locale default encoding (e.g. cp1252) can't encode.
report_path = Path('../outputs/metrics/validation_report.md')
report_path.parent.mkdir(parents=True, exist_ok=True)
report_path.write_text(report, encoding='utf-8')
print(f"Report saved to {report_path}")

## 3. Visualize Metric Properties

In [None]:
# Evaluate the model on 1000 sample points and collect per-point diagnostics of
# the induced metric g: its determinant (target 65/32) and its eigenvalues.
# All eigenvalues must be > 0 for g to be positive definite.

# savefig does not create missing directories, and nothing earlier creates this one.
artifacts_dir = Path('../outputs/artifacts')
artifacts_dir.mkdir(parents=True, exist_ok=True)

with torch.no_grad():
    points = sample_grid_points(1000, device=device)
    output = model(points, return_full=True, return_metric=True)
    phi = output['phi_full']  # also consumed by the ||φ||² cell that follows
    metric = output['metric']

    # Per-point determinant of g
    det_g = torch.det(metric).cpu().numpy()

    # eigvalsh treats each matrix as symmetric and returns eigenvalues in ascending order
    eigenvalues = torch.linalg.eigvalsh(metric).cpu().numpy()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Panel 1: determinant histogram vs. the 65/32 target
ax = axes[0]
ax.hist(det_g, bins=50, density=True, alpha=0.7, color='blue')
ax.axvline(65/32, color='red', linestyle='--', linewidth=2, label=f'Target: 65/32 = {65/32:.4f}')
ax.axvline(det_g.mean(), color='green', linestyle='-', linewidth=2, label=f'Mean: {det_g.mean():.4f}')
ax.set_xlabel('det(g)')
ax.set_ylabel('Density')
ax.set_title('Metric Determinant Distribution')
ax.legend()

# Panel 2: one histogram per eigenvalue index (λ_1 is the smallest)
ax = axes[1]
for i in range(eigenvalues.shape[1]):
    ax.hist(eigenvalues[:, i], bins=30, alpha=0.5, label=f'λ_{i+1}')
ax.axvline(0, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Eigenvalue')
ax.set_ylabel('Count')
ax.set_title('Metric Eigenvalue Distribution')
ax.legend(fontsize=8)

# Panel 3: smallest eigenvalue per point (positivity check)
ax = axes[2]
min_eig = eigenvalues.min(axis=1)
ax.hist(min_eig, bins=50, density=True, alpha=0.7, color='orange')
ax.axvline(0, color='red', linestyle='--', linewidth=2, label='Positivity threshold')
ax.set_xlabel('Minimum Eigenvalue')
ax.set_ylabel('Density')
ax.set_title('Minimum Eigenvalue (Positivity Check)')
ax.legend()

plt.tight_layout()
plt.savefig(artifacts_dir / 'metric_analysis.png', dpi=150)
plt.show()

## 4. Analyze Phi Norm (G2 Identity)

In [None]:
# Check the G2 identity ||φ||²_g = 7 pointwise.
# `phi` is the full 3-form sampled in the metric-analysis cell above.
PHI_NORM_SQ_TARGET = 7.0

with torch.no_grad():
    norm_sq = phi_norm_squared(phi).cpu().numpy()

norm_sq_mean = norm_sq.mean()
abs_err = abs(norm_sq_mean - PHI_NORM_SQ_TARGET)

# savefig does not create missing directories
artifacts_dir = Path('../outputs/artifacts')
artifacts_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(norm_sq, bins=50, density=True, alpha=0.7, color='purple')
ax.axvline(PHI_NORM_SQ_TARGET, color='red', linestyle='--', linewidth=2, label='Target: ||φ||² = 7')
ax.axvline(norm_sq_mean, color='green', linestyle='-', linewidth=2, label=f'Mean: {norm_sq_mean:.4f}')
ax.set_xlabel('||φ||²_g')
ax.set_ylabel('Density')
ax.set_title('G2 Identity: ||φ||²_g = 7')
ax.legend()
plt.tight_layout()
plt.savefig(artifacts_dir / 'phi_norm_analysis.png', dpi=150)
plt.show()

print(f"Phi norm squared: {norm_sq_mean:.6f} ± {norm_sq.std():.6f}")
print("Target: 7.0")
print(f"Error: {abs_err:.6f} ({abs_err / PHI_NORM_SQ_TARGET * 100:.3f}%)")

## 5. Cohomology Analysis

In [None]:
# Estimate effective Betti numbers of the learned geometry and compare them with
# the topological targets (b2 = 21, b3 = 77 for GIFT v2.2).
# The targets come from the loaded config rather than being duplicated here,
# so editing gift_v22.yaml cannot silently desynchronize this check.
analyzer = CohomologyAnalyzer(target_b2=physics['b2'], target_b3=physics['b3'])
coh_results = analyzer.analyze(model, num_samples=2000, device=device)

print("Cohomology Analysis Results:")
print("="*50)
print(f"b2: {coh_results['b2']['effective']} (target: {coh_results['b2']['target']})")
print(f"b3: {coh_results['b3']['effective']} (target: {coh_results['b3']['target']})")
# Expected split 35 + 42 = 77. NOTE(review): 35 = C(7,3), the number of
# independent 3-form components in 7D; presumably that is the "local" part.
# Confirm against CohomologyAnalyzer.
print(f"  Local component: {coh_results['b3']['local_component']} (expected: 35)")
print(f"  Global component: {coh_results['b3']['global_component']} (expected: 42)")
print(f"h*: {coh_results['h_star']['effective']} (target: {coh_results['h_star']['target']})")
print("="*50)
print(f"Cohomology match: {coh_results['summary']['cohomology_match']}")

## 6. Compare with Standard G2 Form

In [None]:
# Reference geometry: the standard G2 3-form, expanded to a full antisymmetric
# tensor (with a batch dimension of 1) and its induced metric.
standard_phi = standard_g2_phi(device=device)
standard_full = expand_to_antisymmetric(standard_phi.unsqueeze(0))
standard_metric = metric_from_phi(standard_full)

print("Standard G2 Form Properties:")
print("="*50)
std_det = torch.det(standard_metric).item()
print(f"det(g_standard) = {std_det:.6f} (should be 1.0)")
std_norm_sq = phi_norm_squared(standard_full).item()
print(f"||φ_standard||² = {std_norm_sq:.6f} (should be 7.0)")
print()

# How far is the learned form from the standard one at the origin? Report the
# Euclidean distance between component vectors, both absolute and relative to
# ||φ_standard||.
with torch.no_grad():
    origin = torch.zeros(1, 7, device=device)
    learned_phi = model(origin, return_full=True)['phi_components'][0]

    deviation = (learned_phi - standard_phi).norm().item()
    reference_norm = standard_phi.norm().item()
    relative_dev = deviation / reference_norm

print(f"Learned φ deviation from standard: {deviation:.6f}")
print(f"Relative deviation: {relative_dev*100:.2f}%")

## 7. Visualize 3-Form Components

In [None]:
# Probe how the learned 3-form varies along the first coordinate axis:
# points (t, 0, 0, 0, 0, 0, 0) for t evenly spaced in [-1, 1].
N_LINE_POINTS = 100
t = torch.linspace(-1, 1, N_LINE_POINTS, device=device)
line_points = torch.zeros(N_LINE_POINTS, 7, device=device)
line_points[:, 0] = t  # vary only the first coordinate

with torch.no_grad():
    output = model(line_points)
    phi_components = output['phi_components'].cpu().numpy()

# savefig does not create missing directories
artifacts_dir = Path('../outputs/artifacts')
artifacts_dir.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
t_np = t.cpu().numpy()

# Wrap the labels in $...$ so matplotlib's mathtext renders the index triples as
# subscripts; outside math mode the braces would be printed literally.
# NOTE(review): assumes phi_components is ordered lexicographically over
# i<j<k (012, 013, 014, 015, 016, 023, ...). Confirm against src.model.
component_labels = [
    r'$\varphi_{012}$', r'$\varphi_{013}$', r'$\varphi_{014}$',
    r'$\varphi_{015}$', r'$\varphi_{016}$', r'$\varphi_{023}$',
]

# One panel per component; axes.flat yields the 2x3 grid row by row.
for i, ax in enumerate(axes.flat):
    ax.plot(t_np, phi_components[:, i], 'b-', linewidth=2)
    ax.set_xlabel('x₁')
    ax.set_ylabel(component_labels[i])
    ax.set_title(f'Component {component_labels[i]}')
    ax.grid(True, alpha=0.3)

plt.suptitle('3-Form Components Along x₁ Axis', fontsize=14)
plt.tight_layout()
plt.savefig(artifacts_dir / 'phi_components.png', dpi=150)
plt.show()

## 8. Save Artifacts

In [None]:
# Evaluate φ and g on 10,000 sample points and persist them as a compressed .npz
# for downstream analysis.
# np.savez_compressed does not create missing directories.
artifacts_dir = Path('../outputs/artifacts')
artifacts_dir.mkdir(parents=True, exist_ok=True)

with torch.no_grad():
    grid_points = sample_grid_points(10000, device=device)
    output = model(grid_points, return_full=True, return_metric=True)
    grid_metric = output['metric']

    artifacts = {
        'points': grid_points.cpu().numpy(),
        'phi_components': output['phi_components'].cpu().numpy(),
        'metric': grid_metric.cpu().numpy(),
        'det_g': torch.det(grid_metric).cpu().numpy(),
    }

# Keys of `artifacts` become the array names inside the archive.
np.savez_compressed(artifacts_dir / 'g2_geometry.npz', **artifacts)

print("Saved artifacts:")
print(f"  Points: {artifacts['points'].shape}")
print(f"  Phi components: {artifacts['phi_components'].shape}")
print(f"  Metric: {artifacts['metric'].shape}")
print(f"  det(g): {artifacts['det_g'].shape}")

## 9. Summary

In [None]:
print("="*60)
print("GIFT v2.2 VARIATIONAL G2 ANALYSIS SUMMARY")
print("="*60)
print()
print("Target Metrics:")
print(f"  det(g) = 65/32 = {65/32:.6f}")
print(f"  κ_T = 1/61 = {1/61:.6f}")
print(f"  b₂ = 21, b₃ = 77")
print(f"  ||φ||²_g = 7")
print()
print("Achieved Metrics:")
print(f"  det(g) = {det_g.mean():.6f} ± {det_g.std():.6f}")
print(f"  ||φ||²_g = {norm_sq.mean():.6f} ± {norm_sq.std():.6f}")
print(f"  b₂_eff = {coh_results['b2']['effective']}")
print(f"  b₃_eff = {coh_results['b3']['effective']}")
print(f"  g positive: {(min_eig > 0).all()}")
print()
print(f"Overall validation: {'PASSED' if results.all_passed else 'FAILED'}")
print("="*60)