In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh

import matplotlib.pyplot as plt


In [2]:
# Convert to torch tensors (double precision for better numerical stability)
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
m = Mesh('data/coil_1.2_MM.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)

print('Computing Laplacian')
K, M = m.computeLaplacian()

# following Finite Elements methodology 
# K is stiffness matrix, M is mass matrix
# The problem to solve becomes 
# K*u = lambda * M*u
print('Computing eigen values')
eigvals, eigvecs = linalg.eigh(K,M)

Computing Laplacian
Computing eigen values


In [4]:
# send all relevant numpy arrays to torch tensors
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)
N = X.shape[0]


# 1. Start with stronger regularization
epsilon = 1e-4

# Apply the regularization
K_reg = K + epsilon * torch.eye(N, device=device)
print(f"\nUsing epsilon={epsilon}, final condition number: {torch.linalg.cond(K_reg).item():.2e}")

print("\n=== Matrix Normalization ===")
K_scale = torch.norm(K_reg, p='fro')
M_scale = torch.norm(M, p='fro')

K = K_reg / K_scale
M = M / K_scale


Using epsilon=0.0001, final condition number: 1.47e+05

=== Matrix Normalization ===


In [5]:
# in the paper we used 50 eigenvalues so set k to 50
k = 50

In [None]:
# Build the neural network that maps coordinates -> k outputs per node
class MLP(nn.Module):
    def __init__(self, in_dim=3, out_dim=k, hidden=[64,64]):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers.append(nn.Linear(last, h, dtype=torch.double))
            layers.append(nn.SiLU())
            last = h
        layers.append(nn.Linear(last, out_dim, dtype=torch.double))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # returns (N, k)

# Instantiate model and optimizer
model = MLP().to(device)
# Initialize all layers (Xavier), final layer small
for name, p in model.named_parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
for p in model.net[-1].parameters():  # last Linear
    if p.ndim == 2:
        nn.init.normal_(p, std=1e-3)
    else:
        nn.init.zeros_(p)

lr_start = 0.05
lr_end = 0.001
max_epochs = 100_000

optimizer = optim.AdamW(model.parameters(), lr=lr_start, weight_decay=1e-5)

# Cosine annealing with warm restarts for better convergence
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=30000, T_mult=2, eta_min=lr_end)
print_every = 1_000
loss_history = []

#optimizer = optim.Adam(model.parameters(), lr=0.01)
#decay_factor = (lr_end / lr_start) ** (1 / max_epochs)
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_factor)

In [7]:
for epoch in range(1, max_epochs+1):
    optimizer.zero_grad()
    U = model(X)  # N x k

    # losses
    UMU = U.T @ (M @ U)        # k x k
    UKU = U.T @ (K @ U)

    eigenvalues_approx = torch.diag(UKU)
    sorted_eigs, _ = torch.sort(eigenvalues_approx)
    zero_eig_loss = sorted_eigs[0] ** 2
    eig_loss_trace = torch.sum(sorted_eigs[1:])
    gaps = sorted_eigs[1:] - sorted_eigs[:-1]
    min_gap = 1e-4  # Minimum separation in normalized units (very small)
    diversity_loss = torch.sum(torch.relu(min_gap - gaps))
    off_diag_mask = 1 - torch.eye(k, device=device, dtype=torch.float64)
    eig_loss_offdiag = torch.sum((UKU * off_diag_mask)**2)


    orth_loss = torch.norm(UMU - torch.eye(k, device=device), p='fro')**2
    #eig_loss = torch.norm(UKU, p='fro')**2
    eig_loss = 2*zero_eig_loss + 5*eig_loss_trace + 5*diversity_loss + eig_loss_offdiag

    loss = eig_loss + 10 * orth_loss

    loss.backward()
    optimizer.step()
    scheduler.step()
    loss_history.append(loss.item())

    if epoch % print_every == 0 or epoch == 1:
        approx_vals = torch.diag(U.T @ (K @ U)).detach().cpu().numpy()
        print(
        f"Epoch {epoch}, "
        f"total loss={loss.item():.6f}, "
        f"eig_loss={eig_loss.item():.6f}, "
        f"orth_loss={orth_loss.item():.6f}"
    )


Epoch 1, total loss=500.021039, eig_loss=0.024632, orth_loss=49.999641
Epoch 1000, total loss=453.643545, eig_loss=21.684457, orth_loss=43.195909
Epoch 2000, total loss=434.481217, eig_loss=38.047090, orth_loss=39.643413
Epoch 3000, total loss=1018.130197, eig_loss=604.390280, orth_loss=41.373992
Epoch 4000, total loss=518.048194, eig_loss=69.187764, orth_loss=44.886043
Epoch 5000, total loss=484.558249, eig_loss=28.852698, orth_loss=45.570555
Epoch 6000, total loss=476.662288, eig_loss=19.774643, orth_loss=45.688764
Epoch 7000, total loss=473.327488, eig_loss=17.747808, orth_loss=45.557968
Epoch 8000, total loss=470.974342, eig_loss=17.748498, orth_loss=45.322584
Epoch 9000, total loss=469.038847, eig_loss=17.416100, orth_loss=45.162275
Epoch 10000, total loss=466.678177, eig_loss=17.353933, orth_loss=44.932424
Epoch 11000, total loss=464.419860, eig_loss=18.267429, orth_loss=44.615243
Epoch 12000, total loss=462.209522, eig_loss=18.455048, orth_loss=44.375447
Epoch 13000, total loss=

In [8]:
np.set_printoptions(suppress=True, precision=6)

# ==== Final result ====
with torch.no_grad():
    U_final = model(X)
    UKU = U_final.T @ (K @ U_final)
    final_eigs = torch.diag(UKU).cpu().numpy()
    final_eigs.sort()

    # FINAL EVALUATION
    abs_error = np.abs(final_eigs - eigvals[:k])
    rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)
    
    print("=" * 80)
    print("FINAL RESULTS")
    print("=" * 80)
    
    # Eigenvalue comparison
    print(f"\nEigenvalue Comparison (first 10 modes):")
    print(f"{'Mode':<6} {'Predicted':<12} {'Reference':<12} {'Abs Error':<12} {'Rel Error':<12}")
    print("-" * 66)
    for i in range(min(10, k)):
        print(f"{i+1:<6} {final_eigs[i]:<12.6f} {eigvals[i]:<12.6f} "
              f"{abs_error[i]:<12.6f} {rel_error[i]:<12.4%}")
    
    print(f"\nEigenvalue Comparison (last 10 modes):")
    print(f"{'Mode':<6} {'Predicted':<12} {'Reference':<12} {'Abs Error':<12} {'Rel Error':<12}")
    print("-" * 66)
    for i in range(max(0, k-10), k):
        print(f"{i+1:<6} {final_eigs[i]:<12.6f} {eigvals[i]:<12.6f} "
              f"{abs_error[i]:<12.6f} {rel_error[i]:<12.4%}")
    
    # Overall statistics
    print(f"\nOverall Statistics (all {k} modes):")
    print(f"  Mean Absolute Error:   {np.mean(abs_error):.6f}")
    print(f"  Mean Relative Error:   {np.mean(rel_error):.4%}")
    print(f"  Median Relative Error: {np.median(rel_error):.4%}")
    print(f"  Max Relative Error:    {np.max(rel_error):.4%}")
    print(f"  Modes with <5% error:  {np.sum(rel_error < 0.05)}/{k}")
    print(f"  Modes with <10% error: {np.sum(rel_error < 0.10)}/{k}")
    
    print("\n" + "=" * 80)

FINAL RESULTS

Eigenvalue Comparison (first 10 modes):
Mode   Predicted    Reference    Abs Error    Rel Error   
------------------------------------------------------------------
1      0.001015     0.000000     0.001015     1011352243.9958%
2      0.001172     0.007574     0.006402     84.5228%    
3      0.001203     0.030308     0.029105     96.0300%    
4      0.001301     0.068146     0.066846     98.0910%    
5      0.001353     0.121208     0.119855     98.8840%    
6      0.001386     0.189243     0.187857     99.2678%    
7      0.001460     0.272231     0.270772     99.4638%    
8      0.001489     0.370536     0.369047     99.5983%    
9      0.001523     0.483409     0.481886     99.6849%    
10     0.001559     0.611343     0.609784     99.7450%    

Eigenvalue Comparison (last 10 modes):
Mode   Predicted    Reference    Abs Error    Rel Error   
------------------------------------------------------------------
41     0.002840     7.422193     7.419353     99.9617%    
