In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh

import matplotlib.pyplot as plt


# Convert to torch tensors (double precision for better numerical stability)
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

m = Mesh('data/coil_1.2_MM.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)

print('Computing Laplacian')
K, M = m.computeLaplacian()

# following Finite Elements methodology 
# K is stiffness matrix, M is mass matrix
# The problem to solve becomes 
# K*u = lambda * M*u
print('Computing eigen values')
eigvals, eigvecs = linalg.eigh(K,M)


# send all relevant numpy arrays to torch tensors
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)
N = X.shape[0]

# in the paper we used 50 eigenvalues so set k to 50
k = 50
epsilon = 1e-4

# Apply the regularization
K_reg = K + epsilon * torch.eye(N, device=device)
print(f"\nUsing epsilon={epsilon}, final condition number: {torch.linalg.cond(K_reg).item():.2e}")

print("\n=== Matrix Normalization ===")
K_scale = torch.norm(K_reg, p='fro')

K = K_reg / K_scale
M = M / K_scale


Computing Laplacian
Computing eigen values

Using epsilon=0.0001, final condition number: 1.47e+05

=== Matrix Normalization ===


In [2]:
# Build the neural network that maps coordinates -> k outputs per node
class MLP(nn.Module):
    def __init__(self, in_dim=3, out_dim=k, hidden=[64,64]):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers.append(nn.Linear(last, h, dtype=torch.double))
            layers.append(nn.SiLU())
            last = h
        layers.append(nn.Linear(last, out_dim, dtype=torch.double))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # returns (N, k)

# Instantiate model and optimizer
model = MLP().to(device)
# Initialize all layers (Xavier), final layer small
for name, p in model.named_parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
for p in model.net[-1].parameters():  # last Linear
    if p.ndim == 2:
        nn.init.normal_(p, std=1e-3)
    else:
        nn.init.zeros_(p)

# Helper: given A (N x k), return M-orthonormalized U = A (A^T M A)^{-1/2}
def m_orthonormalize(A, M):
    # A: (N, k), M: (N, N)
    # compute B = A^T M A (k x k)
    B = A.T @ (M @ A)  # k x k
    # symmetrize B
    B = 0.5*(B + B.T)
    # compute inverse sqrt of B via eigendecomposition (k small)
    s, Q = torch.linalg.eigh(B)  # s are eigenvalues
    # regularize small eigenvalues
    s_clamped = torch.clamp(s, min=1e-12)
    inv_sqrt = Q @ torch.diag(1.0/torch.sqrt(s_clamped)) @ Q.T
    U = A @ inv_sqrt
    return U

lr_start = 0.01
lr_end = 0.0001
max_epochs = 100_000
print_every = 1_000
loss_history = []

optimizer = optim.Adam(model.parameters(), lr=lr_start)
decay_factor = (lr_end / lr_start) ** (1 / max_epochs)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_factor)

for epoch in range(1, max_epochs+1):
    model.train()
    optimizer.zero_grad()
    A = model(X)  # N x k
    U = m_orthonormalize(A, M)  # U is M-orthonormal

    UKU = U.T @ (K @ U)
    UMU = U.T @ (M @ U)        # k x k
    
    orth_loss = torch.norm(UMU - torch.eye(k, device=device), p='fro')**2
    eig_loss = torch.norm(UKU, p='fro')**2

    loss = 100 * eig_loss + orth_loss
    loss.backward()
    optimizer.step()
    scheduler.step()
    loss_history.append(loss.item())
    if epoch % print_every == 0 or epoch == 1:
        print(f"Epoch {epoch:4d}, total loss={loss.item():.6f}, orthogonal loss={orth_loss.item():.4f}, eigen loss={eig_loss.item():.4f}")

Epoch    1, total loss=49430.347857, orthogonal loss=24.5302, eigen loss=494.0582
Epoch 1000, total loss=195127.295924, orthogonal loss=0.0000, eigen loss=1951.2730
Epoch 2000, total loss=184466.079235, orthogonal loss=0.0000, eigen loss=1844.6608
Epoch 3000, total loss=177873.340622, orthogonal loss=0.0000, eigen loss=1778.7334
Epoch 4000, total loss=173375.323495, orthogonal loss=0.0000, eigen loss=1733.7532
Epoch 5000, total loss=170370.003695, orthogonal loss=0.0000, eigen loss=1703.7000
Epoch 6000, total loss=167046.180719, orthogonal loss=0.0000, eigen loss=1670.4618
Epoch 7000, total loss=164315.378199, orthogonal loss=0.0000, eigen loss=1643.1538
Epoch 8000, total loss=162971.924109, orthogonal loss=0.0000, eigen loss=1629.7192
Epoch 9000, total loss=161534.146348, orthogonal loss=0.0000, eigen loss=1615.3415
Epoch 10000, total loss=161518.334778, orthogonal loss=0.0000, eigen loss=1615.1833
Epoch 11000, total loss=159758.849099, orthogonal loss=0.0000, eigen loss=1597.5885
Epo

In [4]:
np.set_printoptions(suppress=True, precision=6)

model.eval()
with torch.no_grad():
    A_final = model(X)

    U_final= m_orthonormalize(A_final, M)
    UKU = U_final.T @ (K @ U_final)

    print(f"The shape of U.t @ K @ U is: {UKU.shape}")

    approx_eigs = np.round(torch.diag(UKU).cpu().numpy(), 8)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
    # For better comparison, we can compute Ritz values from subspace U by solving small generalized eigenproblem
    # (U^T K U) c = mu (U^T M U) c, but U^T M U = I so just eig of UK_final
    mu, Wsmall = np.linalg.eigh(UKU.cpu().numpy())
    mu = np.real(mu)
    # sort
    idx = np.argsort(mu)
    mu = mu[idx]
    print("\nLearned Ritz values (from U^T K U):", np.round(mu[:5], 6))
    print("Reference eigenvalues (first k):   ", np.round(eigvals[:5], 6))

The shape of U.t @ K @ U is: torch.Size([50, 50])

Learned Ritz values (from U^T K U): [0.006524 0.073555 0.229321 0.262386 0.283282]
Reference eigenvalues (first k):    [0.       0.007574 0.030308 0.068146 0.121208]


In [5]:
for i, j in zip(mu, eigvals):
    print(f"--- approximation: {np.round(i, 4)} actual: {np.round(j, 4)} ---")

--- approximation: 0.0065 actual: 0.0 ---
--- approximation: 0.0736 actual: 0.0076 ---
--- approximation: 0.2293 actual: 0.0303 ---
--- approximation: 0.2624 actual: 0.0681 ---
--- approximation: 0.2833 actual: 0.1212 ---
--- approximation: 0.3811 actual: 0.1892 ---
--- approximation: 0.4652 actual: 0.2722 ---
--- approximation: 0.7194 actual: 0.3705 ---
--- approximation: 1.0522 actual: 0.4834 ---
--- approximation: 1.1064 actual: 0.6113 ---
--- approximation: 1.1753 actual: 0.754 ---
--- approximation: 1.7005 actual: 0.9117 ---
--- approximation: 1.8233 actual: 1.0836 ---
--- approximation: 1.904 actual: 1.27 ---
--- approximation: 2.2072 actual: 1.4713 ---
--- approximation: 2.4038 actual: 1.687 ---
--- approximation: 2.4726 actual: 1.9172 ---
--- approximation: 2.6417 actual: 2.1605 ---
--- approximation: 2.8411 actual: 2.419 ---
--- approximation: 3.1089 actual: 2.6903 ---
--- approximation: 3.5563 actual: 2.9745 ---
--- approximation: 3.8684 actual: 3.2753 ---
--- approximation: 

In [6]:
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
}, "simple_model.pth")



In [None]:
# Load
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

model.eval()