In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh

import matplotlib.pyplot as plt


# Convert to torch tensors (double precision for better numerical stability)
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

m = Mesh('data/coil_1.2_MM.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)

print('Computing Laplacian')
K, M = m.computeLaplacian()

# following Finite Elements methodology 
# K is stiffness matrix, M is mass matrix
# The problem to solve becomes 
# K*u = lambda * M*u
print('Computing eigen values')
eigvals, eigvecs = linalg.eigh(K,M)


# send all relevant numpy arrays to torch tensors
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)


# in the paper we used 50 eigenvalues so set k to 50
k = 50

N = X.shape[0]

# Before training, verify:
print(f"N = {N}, k = {k}, ratio = {N/k:.1f}")
print(f"Condition number of K: {torch.linalg.cond(K).item():.2e}")
print(f"Condition number of M: {torch.linalg.cond(M).item():.2e}")
print(f"Target eigenvalue range: [{eigvals[0]:.4f}, {eigvals[k-1]:.4f}]")

# If eigenvalues span many orders of magnitude, normalize:
if eigvals[k-1] / eigvals[0] > 100:
    print("WARNING: Large eigenvalue spread, consider normalization")

Computing Laplacian
Computing eigen values
N = 1546, k = 50, ratio = 30.9
Condition number of K: 2.56e+16
Condition number of M: 3.70e+02
Target eigenvalue range: [0.0000, 7.8346]


In [2]:
# --- CRITICAL: Fix the ill-conditioning ---

# 1. Start with stronger regularization
epsilon = 1e-4

# Apply the regularization
K_reg = K + epsilon * torch.eye(N, device=device)
print(f"\nUsing epsilon={epsilon}, final condition number: {torch.linalg.cond(K_reg).item():.2e}")

print("\n=== Matrix Normalization ===")
K_scale = torch.norm(K_reg, p='fro')
M_scale = torch.norm(M, p='fro')

K = K_reg / K_scale
M = M / M_scale

print(f"K Frobenius norm (normalized): {torch.norm(K, p='fro').item():.4f}")
print(f"M Frobenius norm (normalized): {torch.norm(M, p='fro').item():.4f}")
print(f"Normalization: K_scale={K_scale.item():.2e}, M_scale={M_scale.item():.2e}")

# Keep reference eigenvalues as-is (original scale)
# We'll scale the predicted eigenvalues instead when comparing
print(f"Target eigenvalue range (original): [{eigvals[0]:.6f}, {eigvals[k-1]:.6f}]")


Using epsilon=0.0001, final condition number: 1.47e+05

=== Matrix Normalization ===
K Frobenius norm (normalized): 1.0000
M Frobenius norm (normalized): 1.0000
Normalization: K_scale=1.58e+02, M_scale=1.22e+00
Target eigenvalue range (original): [0.000000, 7.834566]


In [3]:
print("\n=== Matrix Diagnostics ===")

# Basic checks
print(f"K has NaN: {torch.isnan(K_reg).any()}")
print(f"K has Inf: {torch.isinf(K_reg).any()}")
print(f"K is symmetric: {torch.allclose(K_reg, K_reg.T, atol=1e-6)}")
print(f"M is positive definite: {torch.all(torch.linalg.eigvalsh(M) > 0)}")

# Matrix-vector multiply check
v = torch.randn(N, device=device, dtype=torch.float64)
Kv = K_reg @ v
print(f"||K*v|| / ||v|| = {torch.norm(Kv) / torch.norm(v):.2e}")

# Additional diagnostics
print(f"\nK_reg diagonal stats:")
K_diag = torch.diag(K_reg)
print(f"  min: {K_diag.min():.2e}, max: {K_diag.max():.2e}")
print(f"  mean: {K_diag.mean():.2e}, std: {K_diag.std():.2e}")
print(f"  negative entries: {(K_diag < 0).sum().item()}/{N}")

print(f"\nM diagonal stats:")
M_diag = torch.diag(M)
print(f"  min: {M_diag.min():.2e}, max: {M_diag.max():.2e}")
print(f"  mean: {M_diag.mean():.2e}, std: {M_diag.std():.2e}")

print(f"\nMatrix norms:")
print(f"  ||K_reg||_F: {torch.norm(K_reg, p='fro'):.2e}")
print(f"  ||M||_F: {torch.norm(M, p='fro'):.2e}")

# Estimate condition numbers (cheap approximation)
print(f"\nCondition number estimates:")
print(f"  K_reg: {torch.linalg.cond(K_reg).item():.2e}")
print(f"  M: {torch.linalg.cond(M).item():.2e}")

# Reference eigenvalues
print(f"\nReference eigenvalues:")
print(f"  First 10: {eigvals[:10]}")
print(f"  Last 10 of k=50: {eigvals[k-10:k]}")
print(f"  Range: [{eigvals[0]:.6f}, {eigvals[k-1]:.6f}]")

print("\n" + "="*60)


=== Matrix Diagnostics ===
K has NaN: False
K has Inf: False
K is symmetric: True
M is positive definite: True
||K*v|| / ||v|| = 4.13e+00

K_reg diagonal stats:
  min: 2.18e+00, max: 7.60e+00
  mean: 3.63e+00, std: 4.33e-01
  negative entries: 0/1546

M diagonal stats:
  min: 1.80e-04, max: 4.15e-02
  mean: 2.29e-02, std: 5.67e-03

Matrix norms:
  ||K_reg||_F: 1.58e+02
  ||M||_F: 1.00e+00

Condition number estimates:
  K_reg: 1.47e+05
  M: 3.70e+02

Reference eigenvalues:
  First 10: [3.50308915e-13 7.57414444e-03 3.03079128e-02 6.81464805e-02
 1.21207968e-01 1.89242748e-01 2.72231499e-01 3.70535951e-01
 4.83409417e-01 6.11342543e-01]
  Last 10 of k=50: [7.42219306 7.44948415 7.45560059 7.5081038  7.52679082 7.60739344
 7.61427628 7.70810371 7.72485585 7.8345656 ]
  Range: [0.000000, 7.834566]



In [4]:
# ============================================================================
# MODEL DEFINITION
# ============================================================================

class MLP(nn.Module):
    """
    MLP for approximating k=50 eigenmodes.
    Uses 3-layer architecture with increased capacity for k=50.
    """
    def __init__(self, in_dim=3, out_dim=50, hidden=[128, 128, 64]):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers.append(nn.Linear(last, h))
            layers.append(nn.SiLU())
            last = h
        layers.append(nn.Linear(last, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# ============================================================================
# MODEL INITIALIZATION
# ============================================================================

model = MLP().double().to(device)

# Xavier initialization for hidden layers, small weights for output layer
for name, p in model.named_parameters():
    if 'weight' in name:
        layer_idx = int(name.split('.')[1])
        if layer_idx < len(model.net) - 1:  # Hidden layers
            nn.init.xavier_uniform_(p.data)
        else:  # Final layer - very small initialization
            nn.init.normal_(p.data, std=1e-4)
    elif 'bias' in name:
        nn.init.zeros_(p.data)

print(f"\nModel: {sum(p.numel() for p in model.parameters()):,} parameters")

# ============================================================================
# TRAINING HYPERPARAMETERS
# ============================================================================

# Dynamic orthogonality weight scheduling
def get_lambda_orth(epoch):
    if epoch < 50000:
        return 1.0      # Strong constraint initially
    elif epoch < 100000:
        return 0.1      # Relax after orthogonality established
    else:
        return 0.01     # Focus on eigenvalues in final phase

# Learning rate and optimizer
lr_start = 0.05
lr_end = 0.0001
max_epochs = 300_000

optimizer = optim.AdamW(model.parameters(), lr=lr_start, weight_decay=1e-5)

# Cosine annealing with warm restarts for better convergence
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=30000, T_mult=2, eta_min=lr_end
)

# Logging
print_every = 2000
loss_history = []
identity_k = torch.eye(k, device=device, dtype=torch.float64)

# ============================================================================
# TRAINING LOOP
# ============================================================================

print("\n=== Starting Training ===")
print(f"Epochs: {max_epochs:,} | Print every: {print_every:,}\n")

for epoch in range(1, max_epochs + 1):
    model.train()
    optimizer.zero_grad()
    
    # --- Forward Pass ---
    U = model(X)  # (N, k) - basis functions at all points
    
    # --- M-Orthogonalization via SVD (differentiable, stable) ---
    # This projects U onto M-orthonormal subspace
    B = U.T @ (M @ U)  # (k, k) - Gram matrix
    V, S, _ = torch.linalg.svd(B)  # SVD of small k×k matrix (cheap!)
    
    # Compute B^(-1/2) for orthonormalization
    S_inv_sqrt = torch.diag_embed(1.0 / torch.sqrt(torch.clamp(S, min=1e-7)))
    B_inv_sqrt = V @ S_inv_sqrt @ V.T
    U_orth = U @ B_inv_sqrt  # M-orthonormalized basis
    
    # --- Loss Computation ---
    
    # 1. EIGENVALUE LOSS: Find the SMALLEST k eigenvalues
    rayleigh_matrix = U_orth.T @ (K @ U_orth)  # (k, k)
    eigenvalues_approx = torch.diag(rayleigh_matrix)
    
    # Sort eigenvalues for diversity penalty
    sorted_eigs, _ = torch.sort(eigenvalues_approx)
    
    # 1a. TRACE LOSS: Minimize sum of eigenvalues (finds smallest modes)
    eig_loss_trace = torch.sum(sorted_eigs) / k
    
    # 1b. DIVERSITY LOSS: Penalize eigenvalues being too similar
    #     Encourage separation: λ_i+1 should be > λ_i
    #     Compute gaps between consecutive eigenvalues
    gaps = sorted_eigs[1:] - sorted_eigs[:-1]
    target_gap = 0.1  # Minimum desired separation (scaled units)
    diversity_loss = torch.sum(torch.relu(target_gap - gaps)) / (k - 1)
    
    # 1c. Off-diagonal penalty: Force Rayleigh matrix to be diagonal
    off_diag_mask = 1 - torch.eye(k, device=device, dtype=torch.float64)
    eig_loss_offdiag = torch.sum((rayleigh_matrix * off_diag_mask)**2) / (k * (k-1))
    
    # Combined: Balance finding small eigenvalues with maintaining diversity
    eig_loss = 5.0 * eig_loss_trace + 10.0 * diversity_loss + eig_loss_offdiag
    
    # 2. ORTHOGONALITY LOSS: Residual check (should be ~0 due to SVD)
    B_orth = U_orth.T @ (M @ U_orth)
    orth_loss = torch.norm(B_orth - identity_k, p='fro')**2
    
    # 3. ORDERING LOSS: Encourage λ_1 ≤ λ_2 ≤ ... ≤ λ_k
    #    (Already sorted above, so this just penalizes inversions)
    ordering_loss = torch.sum(torch.relu(sorted_eigs[:-1] - sorted_eigs[1:])) / k
    
    # 4. STABILITY: Penalize if B becomes ill-conditioned
    # If singular values of B vary too much, SVD orthogonalization becomes unstable
    S_ratio = S.max() / (S.min() + 1e-10)
    stability_loss = torch.relu(S_ratio - 1e3) / 1e3  # Penalize if condition number > 1000
    
    # Dynamic weighting
    lambda_orth = get_lambda_orth(epoch)
    lambda_order = 0.05
    lambda_stability = 0.1
    
    # Total Loss
    loss = eig_loss + lambda_orth * orth_loss + lambda_order * ordering_loss + lambda_stability * stability_loss
    
    # --- Backpropagation ---
    loss.backward()
    
    # Gradient clipping for numerical stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    scheduler.step()
    
    # --- Logging ---
    loss_history.append(loss.item())
    
    if epoch % print_every == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            # Get current eigenvalue estimates
            approx_eigenvalues = torch.diag(rayleigh_matrix).cpu().numpy()
            approx_eigenvalues.sort()
            
            # Scale back to original units for comparison with reference
            approx_eigenvalues_original = approx_eigenvalues * (K_scale / M_scale).cpu().numpy()
            
            # Compute errors
            abs_error = np.abs(approx_eigenvalues_original[:k] - eigvals[:k])
            rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)
            mean_rel_error = np.mean(rel_error)
            
            current_lr = optimizer.param_groups[0]['lr']
            
        print(f"Epoch {epoch:>6} | LR={current_lr:.6f} | λ_orth={lambda_orth:.2f} | "
              f"Loss={loss.item():.6f}")
        print(f"           | Eig={eig_loss.item():.6f} (trace:{eig_loss_trace.item():.3f}, "
              f"div:{diversity_loss.item():.3f}) | Orth={orth_loss.item():.2e}")
        print(f"           | Order={ordering_loss.item():.6f} | Stab={stability_loss.item():.6f} | "
              f"SVD: σ_max/σ_min={S_ratio.item():.2e}")
        print(f"           | MeanRelErr={mean_rel_error:.4f} | "
              f"λ∈[{approx_eigenvalues_original[0]:.4f}, {approx_eigenvalues_original[k-1]:.4f}]")
        print(f"           | λ_spread={(approx_eigenvalues_original[-1] - approx_eigenvalues_original[0]):.4f}")
        
        # Detailed eigenvalue comparison every 10k epochs
        if epoch % (print_every * 5) == 0:
            print(f"           | Predicted (first 5): {approx_eigenvalues_original[:5].round(4)}")
            print(f"           | Reference (first 5): {eigvals[:5].round(4)}")
            print(f"           | Predicted (last  5): {approx_eigenvalues_original[-5:].round(4)}")
            print(f"           | Reference (last  5): {eigvals[k-5:k].round(4)}")
        print()

print("=== Training Complete ===\n")


Model: 28,530 parameters

=== Starting Training ===
Epochs: 300,000 | Print every: 2,000

Epoch      1 | LR=0.050000 | λ_orth=1.00 | Loss=47.936622
           | Eig=1.001236 (trace:0.000, div:0.100) | Orth=4.64e+01
           | Order=0.000000 | Stab=5.629483 | SVD: σ_max/σ_min=6.63e+03
           | MeanRelErr=1236720.8630 | λ∈[0.0062, 0.1159]
           | λ_spread=0.1097

Epoch   2000 | LR=0.049455 | λ_orth=1.00 | Loss=1.153948
           | Eig=1.153948 (trace:0.038, div:0.097) | Orth=2.41e-26
           | Order=0.000000 | Stab=0.000000 | SVD: σ_max/σ_min=9.26e+02
           | MeanRelErr=13586489.6952 | λ∈[0.0682, 21.8618]
           | λ_spread=21.7937

Epoch   4000 | LR=0.047843 | λ_orth=1.00 | Loss=1.147217
           | Eig=1.147217 (trace:0.037, div:0.096) | Orth=2.82e-26
           | Order=0.000000 | Stab=0.000000 | SVD: σ_max/σ_min=5.22e+02
           | MeanRelErr=5586654.4952 | λ∈[0.0280, 22.9592]
           | λ_spread=22.9312

Epoch   6000 | LR=0.045235 | λ_orth=1.00 | Loss=1.1

In [7]:
# ============================================================================
# FINAL EVALUATION
# ============================================================================

model.eval()
with torch.no_grad():
    U_final = model(X)
    
    # Final M-orthogonalization
    B_final = U_final.T @ (M @ U_final)
    V_final, S_final, _ = torch.linalg.svd(B_final)
    S_inv_sqrt_final = torch.diag_embed(1.0 / torch.sqrt(torch.clamp(S_final, min=1e-7)))
    B_inv_sqrt_final = V_final @ S_inv_sqrt_final @ V_final.T
    U_orth_final = U_final @ B_inv_sqrt_final
    
    # Final matrices
    final_rayleigh_matrix = U_orth_final.T @ (K @ U_orth_final)
    final_ortho_matrix = U_orth_final.T @ (M @ U_orth_final)
    
    # Extract and sort eigenvalues
    final_eigenvalues_scaled = torch.diag(final_rayleigh_matrix).cpu().numpy()
    final_eigenvalues_scaled.sort()
    
    # Scale back to original units
    final_eigenvalues = final_eigenvalues_scaled * (K_scale / M_scale).cpu().numpy()
    
    # Compute errors
    abs_error = np.abs(final_eigenvalues - eigvals[:k])
    rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)
    
    print("=" * 80)
    print("FINAL RESULTS")
    print("=" * 80)
    
    # Orthogonality check
    ortho_residual = torch.norm(final_ortho_matrix - identity_k, p='fro').item()
    ortho_diag = torch.diag(final_ortho_matrix).cpu().numpy()
    print(f"\nOrthogonality Quality:")
    print(f"  ||U^T M U - I||_F = {ortho_residual:.2e}")
    print(f"  Diagonal range: [{ortho_diag.min():.6f}, {ortho_diag.max():.6f}] (should be ~1.0)")
    
    # Rayleigh matrix structure
    rayleigh_diag = torch.diag(final_rayleigh_matrix).cpu().numpy()
    rayleigh_offdiag = (final_rayleigh_matrix - torch.diag(torch.diag(final_rayleigh_matrix))).cpu().numpy()
    print(f"\nRayleigh Quotient Matrix:")
    print(f"  Diagonal norm: {np.linalg.norm(rayleigh_diag):.6f}")
    print(f"  Off-diagonal norm: {np.linalg.norm(rayleigh_offdiag, 'fro'):.6f} (should be small)")
    
    # Eigenvalue comparison
    print(f"\nEigenvalue Comparison (first 10 modes):")
    print(f"{'Mode':<6} {'Predicted':<12} {'Reference':<12} {'Abs Error':<12} {'Rel Error':<12}")
    print("-" * 66)
    for i in range(min(20, k)):
        print(f"{i+1:<6} {final_eigenvalues[i]:<12.6f} {eigvals[i]:<12.6f} "
              f"{abs_error[i]:<12.6f} {rel_error[i]:<12.4%}")
    
    print(f"\nEigenvalue Comparison (last 10 modes):")
    print(f"{'Mode':<6} {'Predicted':<12} {'Reference':<12} {'Abs Error':<12} {'Rel Error':<12}")
    print("-" * 66)
    for i in range(max(0, k-10), k):
        print(f"{i+1:<6} {final_eigenvalues[i]:<12.6f} {eigvals[i]:<12.6f} "
              f"{abs_error[i]:<12.6f} {rel_error[i]:<12.4%}")
    
    # Overall statistics
    print(f"\nOverall Statistics (all {k} modes):")
    print(f"  Mean Absolute Error:   {np.mean(abs_error):.6f}")
    print(f"  Mean Relative Error:   {np.mean(rel_error):.4%}")
    print(f"  Median Relative Error: {np.median(rel_error):.4%}")
    print(f"  Max Relative Error:    {np.max(rel_error):.4%}")
    print(f"  Modes with <5% error:  {np.sum(rel_error < 0.05)}/{k}")
    print(f"  Modes with <10% error: {np.sum(rel_error < 0.10)}/{k}")
    
    print("\n" + "=" * 80)

FINAL RESULTS

Orthogonality Quality:
  ||U^T M U - I||_F = 1.38e-13
  Diagonal range: [1.000000, 1.000000] (should be ~1.0)

Rayleigh Quotient Matrix:
  Diagonal norm: 0.404483
  Off-diagonal norm: 0.119736 (should be small)

Eigenvalue Comparison (first 10 modes):
Mode   Predicted    Reference    Abs Error    Rel Error   
------------------------------------------------------------------
1      0.002012     0.000000     0.002012     2004992368.4949%
2      0.288125     0.007574     0.280551     3704.0559%  
3      0.320587     0.030308     0.290279     957.7679%   
4      0.357158     0.068146     0.289012     424.1034%   
5      0.373633     0.121208     0.252425     208.2576%   
6      0.374031     0.189243     0.184789     97.6463%    
7      0.395269     0.272231     0.123037     45.1958%    
8      0.460161     0.370536     0.089625     24.1879%    
9      0.619759     0.483409     0.136350     28.2059%    
10     0.751307     0.611343     0.139965     22.8946%    
11     0.8690

In [6]:
final_eigenvalues

array([2.01201604e-03, 2.88124694e-01, 3.20587381e-01, 3.57158054e-01,
       3.73632728e-01, 3.74031356e-01, 3.95268651e-01, 4.60160852e-01,
       6.19759317e-01, 7.51307118e-01, 8.69004582e-01, 9.56841574e-01,
       1.07852232e+00, 1.14297995e+00, 1.27981064e+00, 1.33402746e+00,
       1.62734130e+00, 2.12937044e+00, 2.35154644e+00, 2.74612125e+00,
       2.98446296e+00, 3.41304258e+00, 3.57027107e+00, 3.58906393e+00,
       4.19116252e+00, 4.22471689e+00, 5.22911190e+00, 5.61958304e+00,
       5.94564640e+00, 6.50567618e+00, 6.60841742e+00, 7.00709300e+00,
       7.19406886e+00, 7.20105877e+00, 7.25728258e+00, 7.26437860e+00,
       7.26560591e+00, 7.28165822e+00, 7.30073049e+00, 7.33064576e+00,
       7.37279871e+00, 7.37444745e+00, 7.37898770e+00, 7.38228787e+00,
       7.38926960e+00, 7.42931378e+00, 7.44089643e+00, 7.71908793e+00,
       2.06188012e+01, 3.35920843e+01])