In [None]:
!pip install -q torch numpy matplotlib tqdm
import torch, torch.nn as nn, torch.nn.functional as F, numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
class CurvedLaplacian:
    """Laplace-Beltrami operator for a diagonal metric, evaluated with autograd.

    For diagonal metric entries g_1..g_d the value computed for a scalar
    field f is

        (1 / sqrt(det g)) * sum_i d/dx_i( sqrt(det g) * (1 / g_i) * df/dx_i )

    with det g = prod_i g_i.

    Args:
        metric_fn: Callable mapping points ``x`` of shape ``(batch, dim)`` to
            the diagonal metric entries, also of shape ``(batch, dim)``.
    """

    def __init__(self, metric_fn):
        self.metric_fn = metric_fn

    def __call__(self, f_net, x):
        """Evaluate the operator applied to ``f_net`` at the points ``x``.

        Args:
            f_net: Callable mapping ``x`` to the field values; a trailing
                singleton dimension of its output is squeezed away.
            x: Tensor of shape ``(batch, dim)``. ``requires_grad`` is switched
                on in place on this tensor.

        Returns:
            Tensor of shape ``(batch,)``, built with ``create_graph=True`` so
            it can be back-propagated through.
        """
        x = x.requires_grad_(True)
        batch_size, dim = x.shape
        g_diag = self.metric_fn(x)
        g_inv = 1.0 / (g_diag + 1e-10)  # small epsilon guards against division by zero
        sqrt_det_g = torch.sqrt(torch.prod(g_diag, dim=1))  # shape (batch,)
        f = f_net(x).squeeze(-1)
        grad_f = torch.autograd.grad(f.sum(), x, create_graph=True)[0]
        laplacian = torch.zeros(batch_size, device=x.device, dtype=x.dtype)
        for i in range(dim):
            flux = sqrt_det_g * g_inv[:, i] * grad_f[:, i]
            # A flux component that is not connected to x in the autograd graph
            # (e.g. linear f with a constant metric) has zero divergence.
            # Skip it instead of letting autograd.grad raise.
            if not flux.requires_grad:
                continue
            d_flux = torch.autograd.grad(
                flux.sum(), x, create_graph=True, allow_unused=True)[0]
            if d_flux is None:
                continue
            laplacian = laplacian + d_flux[:, i] / (sqrt_det_g + 1e-10)
        return laplacian

In [None]:
class G2TCSMetric:
    """Diagonal 7-D model metric on [0, T] x [0, 1]^6 with a cosh warp factor.

    With ``H_star = b2 + b3 + 1``, ``T = sqrt(H_star)``, ``T0 = T / 3`` and
    ``h(t) = cosh((t - T/2) / T0)``, the diagonal metric entries are
    ``vol_factor**2 * (1, h^2, h^2, h^2, h^2, h^2, h^2)``, where ``t`` is the
    first coordinate. ``vol_factor`` is chosen so that the integral of
    ``sqrt(det g)`` over the sampling domain equals 1.

    Args:
        b2, b3: Integers entering only through ``H_star`` (presumably Betti
            numbers of the manifold being modelled; confirm with the author).
    """

    def __init__(self, b2, b3):
        self.b2, self.b3 = b2, b3
        self.H_star = b2 + b3 + 1
        self.dim = 7
        self.T = np.sqrt(self.H_star)
        self.T0 = self.T / 3
        # Unnormalised volume: integral_0^T h(t)^6 dt (the other six
        # coordinates range over a unit cube). Substituting u = (t - T/2) / T0
        # and using
        #   cosh^6 u = (cosh 6u + 6 cosh 4u + 15 cosh 2u + 10) / 32
        # gives an exact closed form, so the normalisation is deterministic
        # rather than a Monte-Carlo estimate that depends on the RNG state.
        a = (self.T / 2) / self.T0  # half-width of the u interval
        integral_u = (np.sinh(6 * a) / 3 + 3 * np.sinh(4 * a)
                      + 15 * np.sinh(2 * a) + 20 * a) / 32
        vol_raw = self.T0 * integral_u
        self.vol_factor = float(vol_raw ** (-1 / 7))

    def h(self, t):
        """Warp factor cosh((t - T/2) / T0)."""
        return torch.cosh((t - self.T/2) / self.T0)

    def metric(self, x):
        """Return the diagonal metric entries, shape ``(batch, 7)``.

        Only the first column of ``x`` is read.
        """
        h_sq = self.h(x[:, 0]) ** 2
        g = torch.ones(x.shape[0], self.dim, device=x.device, dtype=x.dtype)
        g[:, 1:] = h_sq.unsqueeze(1).expand(-1, self.dim - 1)
        return g * self.vol_factor**2

    def sqrt_det_g(self, x):
        """Return sqrt(det g) = h^6 * vol_factor^7 per point, shape ``(batch,)``."""
        return self.h(x[:, 0])**(self.dim - 1) * self.vol_factor**self.dim

    def sample(self, n):
        """Draw ``n`` uniform points from [0, T] x [0, 1]^6 on the module-level ``device``."""
        x = torch.rand(n, self.dim, device=device)
        x[:, 0] *= self.T
        return x

In [None]:
class EigenfunctionNet(nn.Module):
    """MLP over random Fourier features of a 7-D point, plus a learnable eigenvalue.

    The input is concatenated with ``sin`` and ``cos`` of its projection onto
    a fixed random matrix ``B`` and fed through a GELU network with one
    output. The eigenvalue is stored as ``log_lambda``.
    """

    def __init__(self):
        super().__init__()
        in_dim, n_freq, width = 7, 64, 256
        # Frozen random projection (scale 2.0); registered as a parameter
        # with requires_grad=False so the optimiser never updates it.
        projection = torch.randn(in_dim, n_freq) * 2.0
        self.B = nn.Parameter(projection, requires_grad=False)
        # Input layer sees the raw point plus sin/cos features (7 + 128).
        layers = [nn.Linear(in_dim + 2 * n_freq, width), nn.GELU()]
        for _ in range(3):
            layers.extend([nn.Linear(width, width), nn.GELU()])
        layers.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*layers)
        # log(lambda) = 0, so the eigenvalue estimate starts at 1.0!
        self.log_lambda = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        """Map points of shape ``(batch, 7)`` to field values of shape ``(batch, 1)``."""
        phase = torch.matmul(x, self.B)
        features = torch.cat((x, phase.sin(), phase.cos()), dim=-1)
        return self.net(features)

    def get_lambda(self):
        """Return the current eigenvalue ``exp(log_lambda)`` as a Python float."""
        return self.log_lambda.exp().item()

In [None]:
def train(metric, n_epochs=4000, seed=42, batch_size=10000, lr=5e-4,
          weight_decay=1e-2, warmup=200, mean_weight=100, norm_weight=10,
          max_grad_norm=1.0):
    """Fit an ``EigenfunctionNet`` and its eigenvalue to ``lap f + lambda f = 0``.

    Each epoch draws a fresh batch from ``metric.sample``. The loss is the
    ``sqrt(det g)``-weighted mean squared PDE residual, plus penalties that
    push the weighted mean of ``f`` to 0 and its weighted second moment to 1.

    Args:
        metric: Object providing ``metric(x)``, ``sqrt_det_g(x)`` and
            ``sample(n)``.
        n_epochs: Number of optimisation steps.
        seed: Value passed to ``torch.manual_seed`` before the network is built.
        batch_size: Points sampled per step.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay. The default equals AdamW's own
            default. It applies to every trainable parameter, including
            ``log_lambda``; pass 0 to switch it off.
        warmup: ``best_lam`` only considers epochs with index ``> warmup``.
        mean_weight: Weight of the zero-mean penalty.
        norm_weight: Weight of the unit-norm penalty.
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        ``(hist, best_lam)``. ``hist`` is the list of eigenvalue estimates
        after each step. ``best_lam`` is the smallest estimate seen after
        warm-up, or ``inf`` if ``n_epochs`` is too small for any epoch to pass it.
    """
    torch.manual_seed(seed)
    net = EigenfunctionNet().to(device)
    lap = CurvedLaplacian(metric.metric)
    opt = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=weight_decay)
    hist = []
    best_lam = float('inf')
    for ep in tqdm(range(n_epochs)):
        x = metric.sample(batch_size)
        # Points are uniform in coordinates, so sqrt(det g) is the quadrature weight.
        sg = metric.sqrt_det_g(x)
        f = net(x).squeeze(-1)
        lam = torch.exp(net.log_lambda)
        lap_f = lap(net, x)
        loss_pde = torch.mean((lap_f + lam*f)**2 * sg) / torch.mean(sg)
        f_mean = torch.sum(f*sg)/torch.sum(sg)
        f_norm = torch.sum(f**2*sg)/torch.sum(sg)
        loss = loss_pde + mean_weight*f_mean**2 + norm_weight*(f_norm-1)**2
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
        opt.step()
        # Record the eigenvalue after the update.
        cur = net.get_lambda()
        hist.append(cur)
        if ep > warmup and cur < best_lam:
            best_lam = cur
    return hist, best_lam

In [None]:
# MAIN TEST
# For each manifold, train with three seeds and compare the mean of the
# per-seed minimum eigenvalue with the prediction 14 / H*.
print('='*60)
print('UNBIASED TEST: lambda starts at 1.0, NOT 14/H*')
print('='*60)
# (name, b2, b3) triples; H* = b2 + b3 + 1.
manifolds = [
    ('K7', 21, 77),
    ('J1', 12, 43),
    ('J4', 0, 103),
    ('Kov', 0, 71)]
results = []
for name, b2, b3 in manifolds:
    H = b2 + b3 + 1
    pred = 14.0 / H
    print(f'\n{name}: H*={H}, 14/H*={pred:.6f}')
    # Seed before building the metric so that any random sampling done in its
    # constructor is reproducible; train() reseeds itself for every run.
    torch.manual_seed(0)
    metric = G2TCSMetric(b2, b3)
    mins = []
    for seed in [42, 123, 456]:
        hist, best = train(metric, seed=seed)
        mins.append(best)
        print(f'  seed {seed}: min={best:.6f}')
    avg = np.mean(mins)
    # NOTE: `hist` here is the history of the last seed only.
    results.append((name, H, pred, avg, hist))
    print(f'  AVG min: {avg:.6f}, product={avg*H:.4f}')

In [None]:
# SUMMARY
# Report lambda_min * H* for each manifold, then the mean and standard
# deviation of that product across manifolds.
print('\n' + '='*60)
print('SUMMARY: lambda_min * H* = ?')
print('='*60)
prods = [avg * H for _, H, _, avg, _ in results]
for (name, H, pred, avg, _), p in zip(results, prods):
    print(f'{name}: {avg:.6f} * {H} = {p:.4f}')
print(f'\nMean: {np.mean(prods):.4f} (target: 14)')
print(f'Std:  {np.std(prods):.4f}')

In [None]:
# PLOT
# One panel per manifold: eigenvalue history (last seed trained), the 14/H*
# prediction, the seed-averaged minimum, and the initial value lambda = 1.
n_cols = 2
n_rows = max(1, (len(results) + n_cols - 1) // n_cols)
# Size the grid from the number of results so none is silently dropped;
# four results give the original 2x2, 12x8 figure.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False)
for ax, (name, H, pred, avg, hist) in zip(axes.flat, results):
    ax.plot(hist, 'b-', alpha=0.7)
    ax.axhline(pred, color='r', ls='--', label=f'14/H*={pred:.4f}')
    ax.axhline(avg, color='g', ls=':', label=f'min={avg:.4f}')
    ax.axhline(1.0, color='gray', ls='--', alpha=0.3)
    ax.set_title(f'{name} (H*={H})')
    ax.set_xlabel('epoch')
    ax.set_ylabel('lambda')
    ax.legend()
    ax.grid(True, alpha=0.3)
# Hide any panels left over when the number of results is odd.
for ax in axes.flat[len(results):]:
    ax.set_visible(False)
fig.tight_layout()
fig.savefig('unbiased.png', dpi=150)
plt.show()