In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh

import matplotlib.pyplot as plt

# Make float64 the default dtype for tensors that torch creates itself
# (torch.eye, nn.Linear parameters, ...). Arrays converted below with
# torch.from_numpy keep whatever dtype NumPy gave them.
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

m = Mesh('data/coil_1.2_MM.obj')

# Normalize the geometry: move the centroid to the origin and divide by the
# largest per-axis standard deviation.
# assumes m.verts is an (n_vertices, n_dims) NumPy array -- TODO confirm against Mesh
centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

# Rebuild the mesh from the normalized vertices, reusing the original connectivity.
m = Mesh(verts = verts_new, connectivity = m.connectivity)

print('Computing Laplacian')
K, M = m.computeLaplacian()

# Following the finite-element formulation:
#   K is the stiffness matrix, M is the mass matrix,
#   and the generalized eigenproblem to solve is
#       K*u = lambda * M*u
# scipy.linalg.eigh(K, M) solves it directly (all eigenpairs, eigenvalues in
# ascending order); `eigvals` is the reference the network is compared against.
# NOTE: this reference is computed from the *unregularized, unscaled* K and M.
print('Computing eigen values')
eigvals, eigvecs = linalg.eigh(K,M)

# Move the NumPy arrays to torch tensors on the selected device.
# From here on K, M are tensors; X holds the (normalized) vertex coordinates
# that are fed to the network, and N is the number of rows of X.
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)
N = X.shape[0]


# Regularization strength added to the diagonal of K.
epsilon = 1e-4

# Apply the regularization: K_reg = K + epsilon * I.
# NOTE(review): the shift uses the identity, not M, so the eigenvalues of
# (K_reg, M) are not simply those of (K, M) plus epsilon -- keep this in mind
# when comparing against `eigvals`.
K_reg = K + epsilon * torch.eye(N, device=device)
print(f"\nUsing epsilon={epsilon}, final condition number: {torch.linalg.cond(K_reg).item():.2e}")

# Scale both matrices to unit Frobenius norm. The scale factors are kept so
# results can be mapped back: a Rayleigh quotient computed with the scaled
# matrices equals the unscaled one times M_scale / K_scale.
print("\n=== Matrix Normalization ===")
K_scale = torch.norm(K_reg, p='fro')
M_scale = torch.norm(M, p='fro')

# K and M are rebound here: from this point on they denote the regularized,
# Frobenius-normalized matrices used by the training loop.
K = K_reg / K_scale
M = M / M_scale

# Number of eigenpairs to learn; the paper used 50 eigenvalues, so k = 50.
k = 50

# Build the neural network that maps coordinates -> k outputs per node
class MLP(nn.Module):
    """Fully connected network evaluated row-wise on its input.

    The network is a stack of ``nn.Linear`` layers in double precision, each
    hidden layer followed by a ``SiLU`` activation, and a final ``nn.Linear``
    projecting to ``out_dim`` values (no activation on the output).

    Args:
        in_dim: Number of input features per row (3 by default).
        out_dim: Number of outputs per row. Defaults to the module-level
            ``k`` as it was when this class was defined.
        hidden: Sequence with the width of each hidden layer. The default is
            an immutable tuple so it cannot be shared and mutated between
            instances; lists are still accepted.
    """

    def __init__(self, in_dim=3, out_dim=k, hidden=(256, 256)):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers.append(nn.Linear(last, h, dtype=torch.double))
            layers.append(nn.SiLU())
            last = h
        # Output projection; `last` is in_dim when `hidden` is empty.
        layers.append(nn.Linear(last, out_dim, dtype=torch.double))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` of shape (N, in_dim); returns (N, out_dim)."""
        return self.net(x)

# Seed PyTorch's random number generators before any weight is drawn, so that
# a fresh "Restart Kernel -> Run All" starts from the same initialization.
SEED = 0
torch.manual_seed(SEED)

# Instantiate model and optimizer
model = MLP().to(device)

# Re-initialize every weight matrix (parameters with more than one dimension)
# with Xavier-uniform; 1-D parameters (biases) are left untouched here.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# Output layer: small Gaussian weights and zero bias, so the initial
# predictions start close to zero.
for p in model.net[-1].parameters():  # last Linear
    if p.ndim == 2:
        nn.init.normal_(p, std=1e-2)
    else:
        nn.init.zeros_(p)

# --- Training loop ---
max_epochs = 100_000   # number of full-batch optimizer steps
print_every = 1_000    # logging interval, in steps
loss_history = []      # total loss recorded at every step

# Learning-rate range swept by the cosine schedule.
lr_start = 0.05
lr_end = 0.001

optimizer = optim.AdamW(model.parameters(), lr=lr_start, weight_decay=1e-5)
# Cosine annealing with warm restarts: the first cycle lasts T_0 = 30,000
# steps and each following cycle is T_mult = 2 times longer, so the learning
# rate jumps back to lr_start after 30,000 and 90,000 steps.
# NOTE(review): with max_epochs = 100,000 the run stops 10,000 steps into the
# third cycle, where the learning rate is still close to lr_start -- confirm
# this is intended rather than stopping at the end of a cycle (90,000).
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=30000, T_mult=2, eta_min=lr_end)

# Quantities that never change between iterations are built once here instead
# of being re-allocated on every one of the max_epochs steps.
min_gap = 1e-4  # minimum spacing enforced between consecutive sorted Ritz values
identity_k = torch.eye(k, device=device)
off_diag_mask = 1 - torch.eye(k, device=device, dtype=torch.float64)

# Full-batch training: every step evaluates the network on all vertices X and
# minimizes an eigenvalue loss plus an M-orthogonality penalty.
for epoch in range(1, max_epochs+1):
    optimizer.zero_grad()

    # Predict eigenvectors
    U = model(X)  # N x k

    # --- Column-wise normalization w.r.t M ---
    # diag(U^T M U) holds the squared M-norm of each column; after dividing,
    # every column of U_normed has unit M-norm.
    col_norms = torch.sqrt(torch.diagonal(U.T @ M @ U))  # k
    U_normed = U / col_norms.unsqueeze(0)                # normalize each eigenvector

    # --- Compute matrices ---
    UMU = U_normed.T @ M @ U_normed   # k x k Gram matrix in the M inner product
    UKU = U_normed.T @ K @ U_normed   # k x k projected stiffness matrix

    # --- Eigenvalue loss ---
    # With unit M-norm columns, the diagonal of UKU holds the Rayleigh
    # quotients of the predicted vectors (in the scaled K, M units).
    eigenvalues_approx = torch.diag(UKU)
    sorted_eigs, _ = torch.sort(eigenvalues_approx)
    zero_eig_loss = sorted_eigs[0] ** 2          # push the smallest value towards zero
    eig_loss_trace = torch.sum(sorted_eigs[1:])  # minimize the sum of the remaining ones

    # Hinge penalty on consecutive values that are closer than min_gap.
    gaps = sorted_eigs[1:] - sorted_eigs[:-1]
    diversity_loss = torch.sum(torch.relu(min_gap - gaps))

    # Penalize off-diagonal entries of UKU (K-coupling between different columns).
    eig_loss_offdiag = torch.sum((UKU * off_diag_mask)**2)

    eig_loss = zero_eig_loss + eig_loss_trace + diversity_loss + eig_loss_offdiag

    # --- Orthogonality loss ---
    # The diagonal of UMU is 1 by construction (up to rounding), so this
    # effectively penalizes the off-diagonal M-inner products.
    orth_loss = torch.norm(UMU - identity_k, p='fro')**2

    loss = eig_loss + orth_loss

    # --- Backprop ---
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # optional gradient clipping
    optimizer.step()
    scheduler.step()  # the learning-rate schedule advances once per optimizer step

    loss_history.append(loss.item())

    if epoch % print_every == 0 or epoch == 1:
        # NOTE(review): the message ends with an unbalanced ')' -- kept as is
        # so logs stay comparable with earlier runs.
        print(f"Epoch {epoch}, total_loss={loss.item():.6f}, eig_loss={eig_loss.item():.6f}, "
              f"orth_loss={orth_loss.item():.6f})")


np.set_printoptions(suppress=True, precision=6)

# ==== Final result ====
# Evaluate the trained network once and turn each output column into an
# eigenvalue estimate via its Rayleigh quotient (u^T K u) / (u^T M u).
with torch.no_grad():
    U_final = model(X)
    UKU = U_final.T @ (K @ U_final)
    UMU_final = U_final.T @ (M @ U_final)

    # The raw network output is NOT M-normalized (training normalizes the
    # columns before forming U^T K U), so diag(U^T K U) alone carries an
    # arbitrary per-column scale. Dividing by diag(U^T M U) removes it.
    rayleigh = torch.diag(UKU) / torch.diag(UMU_final)

    # K and M were divided by their Frobenius norms; K_scale / M_scale maps
    # the quotients back to the units of the original matrices.
    # Note that K still contains the epsilon * I regularization, while the
    # reference `eigvals` was computed from the unregularized K.
    final_eigs = rayleigh.cpu().numpy() * (K_scale / M_scale).item()
    final_eigs.sort()

# FINAL EVALUATION
# Compare the sorted estimates with the k smallest reference eigenvalues.
abs_error = np.abs(final_eigs - eigvals[:k])
rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)  # 1e-10 guards against a zero reference eigenvalue

print("=" * 80)
print("FINAL RESULTS")
print("=" * 80)

# Eigenvalue comparison: same table layout for the first and the last 10 modes.
for title, modes in (
    ("first 10 modes", range(min(10, k))),
    ("last 10 modes", range(max(0, k-10), k)),
):
    print(f"\nEigenvalue Comparison ({title}):")
    print(f"{'Mode':<6} {'Predicted':<12} {'Reference':<12} {'Abs Error':<12} {'Rel Error':<12}")
    print("-" * 66)
    for i in modes:
        print(f"{i+1:<6} {final_eigs[i]:<12.6f} {eigvals[i]:<12.6f} "
              f"{abs_error[i]:<12.6f} {rel_error[i]:<12.4%}")

# Overall statistics
print(f"\nOverall Statistics (all {k} modes):")
print(f"  Mean Absolute Error:   {np.mean(abs_error):.6f}")
print(f"  Mean Relative Error:   {np.mean(rel_error):.4%}")
print(f"  Median Relative Error: {np.median(rel_error):.4%}")
print(f"  Max Relative Error:    {np.max(rel_error):.4%}")
print(f"  Modes with <5% error:  {np.sum(rel_error < 0.05)}/{k}")
print(f"  Modes with <10% error: {np.sum(rel_error < 0.10)}/{k}")

print("\n" + "=" * 80)

Computing Laplacian
Computing eigen values

Using epsilon=0.0001, final condition number: 1.47e+05

=== Matrix Normalization ===
Epoch 1, total_loss=1075.187125, eig_loss=0.077544, orth_loss=1075.109581)
Epoch 1000, total_loss=2.311047, eig_loss=2.205674, orth_loss=0.105373)
Epoch 2000, total_loss=2.163453, eig_loss=2.069941, orth_loss=0.093511)
Epoch 3000, total_loss=2.052492, eig_loss=1.997693, orth_loss=0.054800)
Epoch 4000, total_loss=2.031165, eig_loss=1.937679, orth_loss=0.093486)
Epoch 5000, total_loss=1.907106, eig_loss=1.860658, orth_loss=0.046448)
Epoch 6000, total_loss=1.886454, eig_loss=1.808165, orth_loss=0.078289)
Epoch 7000, total_loss=1.818474, eig_loss=1.773256, orth_loss=0.045218)
Epoch 8000, total_loss=1.790870, eig_loss=1.749750, orth_loss=0.041120)
Epoch 9000, total_loss=1.771857, eig_loss=1.726291, orth_loss=0.045565)
Epoch 10000, total_loss=1.751434, eig_loss=1.711365, orth_loss=0.040069)
Epoch 11000, total_loss=1.733853, eig_loss=1.702562, orth_loss=0.031291)
Ep