In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh

import matplotlib.pyplot as plt


# Convert to torch tensors (double precision for better numerical stability)
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

m = Mesh('data/coil_1.2_MM.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)

print('Computing Laplacian')
K, M = m.computeLaplacian()

# following Finite Elements methodology 
# K is stiffness matrix, M is mass matrix
# The problem to solve becomes 
# K*u = lambda * M*u
print('Computing eigen values')
eigvals, eigvecs = linalg.eigh(K,M)


# send all relevant numpy arrays to torch tensors
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)


# in the paper we used 50 eigenvalues so set k to 50
k = 50

Computing Laplacian
Computing eigen values


In [2]:
class MLP(nn.Module):
    """
    Multilayer Perceptron for mapping coordinates to k eigenmodes.
    Uses SiLU (Swish) activation for better gradient flow than Tanh.
    """
    def __init__(self, in_dim=3, out_dim=k, hidden=[64, 64]):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            # Using nn.SiLU (Swish) instead of nn.Tanh
            layers.append(nn.Linear(last, h))
            layers.append(nn.SiLU())
            last = h
        layers.append(nn.Linear(last, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # returns (N, k)

# --- 3. Model Instantiation and Initialization ---

# Instantiate model
model = MLP().to(device)

# Initialize all layers (Xavier), final layer small (Best practice for PINNs)
for name, p in model.named_parameters():
    if 'net' in name:
        # Standard Xavier for hidden layers (weights)
        if p.dim() > 1 and name.split('.')[1] != str(len(model.net) - 1):
            nn.init.xavier_uniform_(p.data)
        # Final Linear layer: Small weights and zero bias
        if name.split('.')[1] == str(len(model.net) - 1):
            if p.ndim == 2:
                # Weights: Very small normal distribution
                nn.init.normal_(p.data, std=1e-3)
            else:
                # Biases: Zero
                nn.init.zeros_(p.data)

# --- 4. Training Setup ---

# Hyperparameters
# CRITICAL FIX: Since we are using M-orthogonalization (SVD projection), 
# lambda_orth can be set very low, focusing the optimizer on eig_loss.
lambda_orth = 0.01         
lr_start = 0.01
lr_end = 0.0001
max_epochs = 100_000 # Reverting to a more manageable epoch count for testing         
print_every = 500
loss_history = []

optimizer = optim.Adam(model.parameters(), lr=lr_start)
decay_factor = (lr_end / lr_start) ** (1 / max_epochs)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_factor)

# --- 5. Training Loop ---

print("\n--- Starting Full-Batch Training with SVD Stabilization ---")
identity_k = torch.eye(k, device=device)

for epoch in range(1, max_epochs + 1):
    model.train()
    optimizer.zero_grad()
    
    # Forward Pass (Full Batch)
    U = model(X)  # N x k (Basis functions evaluated at coordinates X)
    
    # --- M-Orthogonalization via SVD ---
    B = U.T @ (M @ U)        # k x k 
    V, S, _ = torch.linalg.svd(B)
    S_inv_sqrt = torch.diag_embed(1.0 / torch.sqrt(torch.clamp(S, min=1e-8)))
    B_inv_sqrt = V @ S_inv_sqrt @ V.T
    U_orth = U @ B_inv_sqrt
    
    # --- Rayleigh matrix ---
    R = U_orth.T @ (K @ U_orth)   # k x k

    # --- Eigenvalue Loss (unsupervised) ---
    # Option A: Penalize off-diagonals only
    eig_loss = torch.norm(R - torch.diag(torch.diag(R)), p='fro')**2

    # Option B: Eigen-equation residual (alternative form, comment/uncomment)
    # Lambda = torch.diag(torch.diag(R))
    # residual = K @ U_orth - (M @ U_orth) @ Lambda
    # eig_loss = torch.norm(residual, p='fro')**2 / k

    # --- Orthogonality Loss ---
    B_orth = U_orth.T @ (M @ U_orth)        
    orth_loss = torch.norm(B_orth - identity_k, p='fro')**2

    # --- Total Loss ---
    loss = eig_loss + lambda_orth * orth_loss

    # Backpropagation and Step
    loss.backward()
    optimizer.step()
    scheduler.step()
    
    # Logging
    loss_history.append(loss.item())

    if epoch % print_every == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            eigvals = torch.diag(R).cpu().numpy()   # read off eigenvalues
            eigvals.sort()
            current_lr = scheduler.get_last_lr()[0]
        
        print(
            f"Epoch {epoch:<5}, LR={current_lr:.6f}, "
            f"Total Loss={loss.item():.4f}, "
            f"Eig Loss={eig_loss.item():.4f}, "
            f"Orth Loss={orth_loss.item():.4f}"
        )
        # print(f"  Approx Eigenvalues: {eigvals[:5]}")

print("--- Training Complete ---")


--- Starting Full-Batch Training with SVD Stabilization ---
Epoch 1    , LR=0.010000, Total Loss=128.8134, Eig Loss=128.4884, Orth Loss=32.4945
Epoch 500  , LR=0.009772, Total Loss=0.4149, Eig Loss=0.1408, Orth Loss=27.4080
Epoch 1000 , LR=0.009550, Total Loss=0.4283, Eig Loss=0.1544, Orth Loss=27.3827
Epoch 1500 , LR=0.009333, Total Loss=0.3189, Eig Loss=0.0460, Orth Loss=27.2878
Epoch 2000 , LR=0.009120, Total Loss=0.2954, Eig Loss=0.0236, Orth Loss=27.1823
Epoch 2500 , LR=0.008913, Total Loss=0.3088, Eig Loss=0.0383, Orth Loss=27.0472
Epoch 3000 , LR=0.008710, Total Loss=0.2743, Eig Loss=0.0048, Orth Loss=26.9456
Epoch 3500 , LR=0.008511, Total Loss=0.2914, Eig Loss=0.0225, Orth Loss=26.8887
Epoch 4000 , LR=0.008318, Total Loss=0.3171, Eig Loss=0.0488, Orth Loss=26.8321
Epoch 4500 , LR=0.008128, Total Loss=0.3200, Eig Loss=0.0522, Orth Loss=26.7817
Epoch 5000 , LR=0.007943, Total Loss=0.2818, Eig Loss=0.0146, Orth Loss=26.7207
Epoch 5500 , LR=0.007762, Total Loss=0.2720, Eig Loss=0

In [6]:
# --- Final Eigenvalue Check ---
model.eval()
with torch.no_grad():
    U_final = model(X)
    B_final = U_final.T @ (M @ U_final)
    V_final, S_final, _ = torch.linalg.svd(B_final)
    S_inv_sqrt_final = torch.diag_embed(1.0 / torch.sqrt(torch.clamp(S_final, min=1e-8)))
    B_inv_sqrt_final = V_final @ S_inv_sqrt_final @ V_final.T
    U_orth_final = U_final @ B_inv_sqrt_final
    
    R_final = U_orth_final.T @ (K @ U_orth_final)
    final_eigenvalues = torch.diag(R_final).cpu().numpy()
    final_eigenvalues.sort()
    final_ortho_matrix = U_orth_final.T @ (M @ U_orth_final)

    print("\n--- Final Results ---")
    print("Final Approximate Eigenvalues (Sorted):", 
      [f"{val:.6f}" for val in final_eigenvalues[:5]])

    print("Reference eigenvalues (first k):", 
        [f"{val:.6f}" for val in eigvals[:5]])

    print("\nFinal Orthogonality Matrix (U_orth^T M U_orth):")
    print(final_ortho_matrix.cpu().numpy().round(4))


--- Final Results ---
Final Approximate Eigenvalues (Sorted): ['0.000036', '0.000038', '0.000038', '0.000038', '0.000039']
Reference eigenvalues (first k): ['0.000036', '0.000038', '0.000038', '0.000038', '0.000039']

Final Orthogonality Matrix (U_orth^T M U_orth):
[[ 9.707e-01 -3.700e-03  2.000e-04 ... -3.300e-03  1.200e-03  0.000e+00]
 [-3.700e-03  9.687e-01 -1.000e-03 ... -7.100e-03 -2.700e-03 -0.000e+00]
 [ 2.000e-04 -1.000e-03  9.991e-01 ... -1.500e-03 -0.000e+00  1.000e-04]
 ...
 [-3.300e-03 -7.100e-03 -1.500e-03 ...  9.837e-01  7.000e-04  0.000e+00]
 [ 1.200e-03 -2.700e-03 -0.000e+00 ...  7.000e-04  9.949e-01  1.000e-04]
 [ 0.000e+00 -0.000e+00  1.000e-04 ...  0.000e+00  1.000e-04  5.680e-02]]
