# K7 Metric Training with G₂ Holonomy - Google Colab


In [None]:
# Check GPU
import torch
# Report the installed PyTorch build and whether it can see a CUDA device.
print(f'PyTorch: {torch.__version__}')
cuda_ok = torch.cuda.is_available()
print(f'CUDA available: {cuda_ok}')
if not cuda_ok:
    # Nothing to query; point the user at the Colab runtime setting instead.
    print(' NO GPU! Go to Runtime → Change runtime type → GPU')
else:
    # Device 0 is the only device inspected here.
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'VRAM: {gpu_props.total_memory/1e9:.1f} GB')

In [None]:
!pip install -q matplotlib scipy pandas

In [None]:
import torch
import torch.nn as nn
import numpy as np

class CompactG2Network(nn.Module):
    '''Compact network for Colab - optimized for T4 memory.

    Maps a batch of 7-dimensional coordinates to one symmetric 7x7 matrix
    per sample. Coordinates are first lifted to random Fourier features
    (cos/sin of a fixed random projection), then passed through an MLP that
    emits the 28 entries of the upper triangle.

    Args:
        hidden_dims: Widths of the hidden layers, in order. Any iterable of
            ints is accepted; the default is an immutable tuple so it cannot
            be shared and mutated across instances.
        num_freq: Number of random frequencies; the MLP input width is
            ``2 * num_freq`` (one cosine and one sine per frequency).
    '''

    def __init__(self, hidden_dims=(256, 256, 128), num_freq=32):
        super().__init__()
        # Fixed random projection: a buffer, so it follows .to()/.double()
        # and is saved in the state dict, but is never updated by the optimizer.
        self.register_buffer('B', torch.randn(7, num_freq) * 2.0)
        layers = []
        prev = 2 * num_freq
        for h in hidden_dims:
            layers += [nn.Linear(prev, h), nn.SiLU(), nn.LayerNorm(h)]
            prev = h
        layers.append(nn.Linear(prev, 28))  # Upper tri of 7x7
        self.mlp = nn.Sequential(*layers)
        # Shrink the output layer so the raw outputs start close to zero.
        with torch.no_grad():
            self.mlp[-1].weight.mul_(0.01)
            self.mlp[-1].bias.zero_()

    def forward(self, coords):
        '''Return the (batch, 7, 7) symmetric matrix for ``coords`` of shape (batch, 7).

        Diagonal entries are ``softplus(raw) + 0.1`` and off-diagonal entries
        are ``0.1 * raw``; the identity is added to the assembled matrix.
        The result uses the dtype and device of the MLP output, so a model
        converted with ``.double()`` returns float64 rather than float32.
        '''
        phase = 2*np.pi * coords @ self.B
        features = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
        upper = self.mlp(features)

        # triu_indices enumerates (i, j) with j >= i in row-major order,
        # which is the order in which the 28 outputs are consumed.
        rows, cols = torch.triu_indices(7, 7, device=upper.device)
        entries = torch.where(
            rows == cols,
            torch.nn.functional.softplus(upper) + 0.1,
            upper * 0.1,
        )

        metric = upper.new_zeros(coords.shape[0], 7, 7)
        metric[:, rows, cols] = entries
        # Mirror into the lower triangle (diagonal is rewritten with equal values).
        metric[:, cols, rows] = entries
        return metric + torch.eye(7, device=upper.device, dtype=upper.dtype).unsqueeze(0)

print('✓ Network defined')

In [None]:
def ricci_loss_fast(metric, coords):
    '''Fast Ricci loss - simplified for speed.

    Builds a per-sample (batch, 7, 7) surrogate tensor whose (i, j) entry is

        sum_k metric_inv[b, i, k] * d(sum of all metric entries)/d coords[b, j]

    Only first derivatives of the summed metric entries are used, so this
    is a cheap proxy rather than a full curvature computation.

    Args:
        metric: (batch, 7, 7) tensor computed from ``coords`` with autograd
            tracking enabled.
        coords: (batch, 7) tensor with ``requires_grad=True`` that ``metric``
            was computed from.

    Returns:
        (batch, 7, 7) tensor, differentiable with respect to the parameters
        that produced ``metric``.
    '''
    metric_inv = torch.linalg.inv(metric)

    # This gradient does not depend on the tensor indices, so take it once.
    # create_graph keeps it differentiable for the later backward pass.
    grad = torch.autograd.grad(
        metric.sum(), coords,
        create_graph=True, retain_graph=True
    )[0]

    # Reduce over the last matrix axis only; reducing over every axis would
    # mix all samples of the batch into a single shared scalar.
    inv_row_sums = metric_inv.sum(dim=-1)
    return inv_row_sums.unsqueeze(-1) * grad.unsqueeze(1)

print('✓ Ricci computation defined')

In [None]:
class Trainer:
    '''Owns the model, optimizer, LR scheduler and loss history for one run.'''

    def __init__(self, device='cuda'):
        self.device = device
        self.model = CompactG2Network().to(device)
        self.opt = torch.optim.AdamW(
            self.model.parameters(),
            lr=1e-4,
            weight_decay=1e-4,
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.opt,
            T_0=500,
            eta_min=1e-7,
        )
        # One list per tracked quantity, each appended to once per epoch.
        self.history = {key: [] for key in ('epoch', 'loss', 'ricci')}

    def train_epoch(self, epoch, batch_size=512):
        '''Run one optimization step on freshly sampled points.

        Args:
            epoch: Index of the current step; the Ricci term is weighted
                by 10 instead of 1 once it exceeds 1000.
            batch_size: Number of random 7-d points drawn for this step.

        Returns:
            Tuple ``(total_loss, ricci_loss)`` as Python floats.
        '''
        self.model.train()

        # Normal samples scaled by 5; gradients w.r.t. them are needed below.
        points = (torch.randn(batch_size, 7, device=self.device) * 5.0).requires_grad_(True)
        metric = self.model(points)

        ricci_loss = ricci_loss_fast(metric, points).pow(2).mean()
        identity = torch.eye(7, device=self.device)
        reg = (metric - identity).pow(2).mean()

        if epoch > 1000:
            ricci_weight = 10.0
        else:
            ricci_weight = 1.0
        total = ricci_weight * ricci_loss + 0.01 * reg

        self.opt.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()
        self.scheduler.step()

        total_value = total.item()
        ricci_value = ricci_loss.item()
        for key, value in (('epoch', epoch), ('loss', total_value), ('ricci', ricci_value)):
            self.history[key].append(value)
        return total_value, ricci_value

print('✓ Trainer defined')

In [None]:
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

def _plot_history(history):
    '''Draw total and Ricci loss side by side on log-scaled y axes.'''
    fig, axes = plt.subplots(1, 2, figsize=(12,4))
    panels = (('loss', 'b-', 'Total Loss'), ('ricci', 'g-', 'Ricci Loss'))
    for axis, (key, fmt, title) in zip(axes, panels):
        axis.semilogy(history['epoch'], history[key], fmt, lw=2)
        axis.set_title(title)
        axis.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()


# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
trainer = Trainer(device)
print(f'Device: {device}')
n_params = sum(p.numel() for p in trainer.model.parameters())
print(f'Parameters: {n_params:,}')
print('\nTraining...\n')

total_epochs = 10000
last_epoch = total_epochs - 1
start = time.time()

for epoch in range(total_epochs):
    loss, ricci = trainer.train_epoch(epoch)

    # Text progress every 50 steps and on the final step.
    if epoch % 50 == 0 or epoch == last_epoch:
        elapsed = time.time() - start
        # Linear extrapolation from the average time per completed step.
        eta = elapsed / (epoch + 1) * (total_epochs - epoch - 1)
        print(f'Epoch {epoch:4d}/{total_epochs} | Loss: {loss:.6e} | Ricci: {ricci:.6e} | {elapsed/60:.1f}min | ETA: {eta/60:.1f}min')

    # Refresh the live plots every 300 steps (skipping step 0).
    if epoch > 0 and epoch % 300 == 0:
        clear_output(wait=True)
        _plot_history(trainer.history)
        print(f'Epoch {epoch} | Loss: {loss:.6e} | Ricci: {ricci:.6e}')

print(f'\n✓ Complete! Time: {(time.time()-start)/60:.1f}min')
loss_log = trainer.history['loss']
print(f'Final loss: {loss_log[-1]:.6e}')
print(f'Reduction: {loss_log[0]/loss_log[-1]:.1f}x')

In [None]:
# Final summary figure: loss curves on the left, smoothed convergence rate
# on the right. Saved to training_curves.png and shown inline.
fig, ax = plt.subplots(1, 2, figsize=(14,5))
epochs = np.array(trainer.history['epoch'])
losses = np.array(trainer.history['loss'])
ricci = np.array(trainer.history['ricci'])

ax[0].semilogy(epochs, losses, 'b-', lw=2, label='Total')
ax[0].semilogy(epochs, ricci, 'g-', lw=2, alpha=0.7, label='Ricci')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].set_title('Training Curves')
ax[0].legend()
ax[0].grid(alpha=0.3)

smooth_window = 100
if len(losses) > smooth_window:
    # Per-epoch decrease of log-loss (positive means the loss is falling);
    # the small offset guards against log(0).
    conv = -np.gradient(np.log(losses+1e-10))
    smooth = np.convolve(conv, np.ones(smooth_window)/smooth_window, 'valid')
    # 'valid' output k averages conv[k : k + smooth_window], so plot it at
    # the last epoch of that window. Pairing it with epochs[:len(smooth)]
    # would shift the curve smooth_window - 1 epochs too early.
    ax[1].plot(epochs[smooth_window-1:], smooth, 'r-', lw=2)
    ax[1].set_xlabel('Epoch')
    ax[1].set_ylabel('Convergence Rate')
    ax[1].set_title('Efficiency')
    ax[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()
print('Saved: training_curves.png')

In [None]:
# Evaluate the trained model at the origin and at the first two unit
# vectors, then print basic spectral diagnostics for each resulting matrix.
trainer.model.eval()
test_pts = torch.zeros(3, 7, device=device)
test_pts[1, 0] = 1.
test_pts[2, 1] = 1.
with torch.no_grad():
    metrics = trainer.model(test_pts)

print('\nLearned Metric:')
for number, (pt, metric_at_pt) in enumerate(zip(test_pts, metrics), start=1):
    g = metric_at_pt.cpu().numpy()
    # eigvalsh treats g as symmetric and returns eigenvalues in ascending order.
    eig = np.linalg.eigvalsh(g)
    condition = eig.max()/eig.min()
    print(f'\nPoint {number}: {pt.cpu().numpy()}')
    print(f'  Diagonal: {np.diag(g)}')
    print(f'  Eigenvalues: {eig}')
    print(f'  Det: {np.linalg.det(g):.6f}')
    print(f'  Condition: {condition:.2f}')

In [None]:
# Persist the run: a PyTorch checkpoint with model/optimizer state and the
# loss history, plus the history as a CSV file.
run_history = trainer.history
ckpt = dict(
    model=trainer.model.state_dict(),
    optimizer=trainer.opt.state_dict(),
    history=run_history,
    final_loss=run_history['loss'][-1],
)
torch.save(ckpt, 'k7_metric_final.pt')
print('✓ Saved: k7_metric_final.pt')

import pandas as pd
history_frame = pd.DataFrame(run_history)
history_frame.to_csv('history.csv', index=False)
print('✓ Saved: history.csv')

print('\n Download files from left panel')