In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh

import matplotlib.pyplot as plt


# Convert to torch tensors (double precision for better numerical stability)
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

m = Mesh('bunny.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)

print('Computing Laplacian')
K, M = m.computeLaplacian()

# following Finite Elements methodology 
# K is stiffness matrix, M is mass matrix
# The problem to solve becomes 
# K*u = lambda * M*u
print('Computing eigen values')
eigvals, eigvecs = linalg.eigh(K,M)


# send all relevant numpy arrays to torch tensors
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)
N = X.shape[0]

# in the paper we used 50 eigenvalues so set k to 50
k = 5

# Build the neural network that maps coordinates -> k outputs per node
class MLP(nn.Module):
    def __init__(self, in_dim=3, out_dim=k, hidden=[64, 64, 64]):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers.append(nn.Linear(last, h, dtype=torch.double))
            layers.append(nn.SiLU())
            last = h
        layers.append(nn.Linear(last, out_dim, dtype=torch.double))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # returns (N, k)

# Instantiate model and optimizer2503
model = MLP().to(device)

lr_start = 0.01
lr_end = 0.0001
max_epochs = 5_000
print_every = 1_000
loss_history = []

optimizer = optim.Adam(model.parameters(), lr=lr_start)
decay_factor = (lr_end / lr_start) ** (1 / max_epochs)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_factor)

for epoch in range(1, 22001): # , max_epochs+1
    model.train()
    optimizer.zero_grad()
    U = model(X)  # N x k

    KU = K @ U
    MU = M @ U

    UKU = U.T @ KU
    UMU = U.T @ MU        # k x k

    rayleigh = UKU / (UMU + 1E-6)

    loss_1 = torch.mean(torch.norm((KU - torch.diag(rayleigh) * MU )**2))
    UMU_I = (UMU - torch.eye(k, device=device))**2
    off_diag_loss = torch.max(UMU_I) * 10
    diag_loss = torch.mean(UMU_I) * 10  
    orth_loss = diag_loss + off_diag_loss
    loss = loss_1 + orth_loss


    loss.backward()
    optimizer.step()
    scheduler.step()
    loss_history.append(loss.item())
    if epoch % 2500 == 0 or epoch == 1:
        print(f"Epoch {epoch:4d}, total loss={loss.item():.3f}, loss_1: {loss_1.item():.3f}, diag loss: {diag_loss.item():.3f}, off diag loss: {off_diag_loss.item():.3f}")

    if loss.item() < 1E-6:
        break


Computing Laplacian
Computing eigen values
Epoch    1, total loss=11.008, loss_1: 0.000, diag loss: 1.758, off diag loss: 9.250
Epoch 2500, total loss=0.011, loss_1: 0.011, diag loss: 0.000, off diag loss: 0.000
Epoch 5000, total loss=0.010, loss_1: 0.010, diag loss: 0.000, off diag loss: 0.000
Epoch 7500, total loss=0.010, loss_1: 0.010, diag loss: 0.000, off diag loss: 0.000
Epoch 10000, total loss=0.010, loss_1: 0.010, diag loss: 0.000, off diag loss: 0.000
Epoch 12500, total loss=0.010, loss_1: 0.010, diag loss: 0.000, off diag loss: 0.000
Epoch 15000, total loss=0.010, loss_1: 0.010, diag loss: 0.000, off diag loss: 0.000
Epoch 17500, total loss=0.010, loss_1: 0.010, diag loss: 0.000, off diag loss: 0.000
Epoch 20000, total loss=0.010, loss_1: 0.010, diag loss: 0.000, off diag loss: 0.000


In [4]:
tmp = (UMU - torch.eye(k, device=device))**2
tmp = tmp.detach().cpu().numpy()
np.set_printoptions(6, suppress=True)
print(UMU.detach().cpu().numpy())
print()

[[ 0.999543  0.000209 -0.000578 -0.000471  0.000032]
 [ 0.000209  0.999312  0.000688  0.000688  0.000255]
 [-0.000578  0.000688  0.999312 -0.000688  0.000688]
 [-0.000471  0.000688 -0.000688  0.999312  0.000688]
 [ 0.000032  0.000255  0.000688  0.000688  0.999312]]



In [None]:
(UMU_I.detach().cpu().numpy())** 2

In [None]:
tmp

In [None]:
print(rayleigh)

In [None]:
UKU

In [None]:
UMU

In [None]:
np.set_printoptions(3)
eigvecs[:, :5].T @ M.cpu().numpy() @ eigvecs[:, :5]

In [None]:
 5 x 2503 @ 2503

In [None]:
M.cpu().numpy() @ eigvecs[:, :5]

In [None]:
eigvecs.T @ M.cpu().numpy() @ eigvecs

In [None]:
np.set_printoptions(suppress=True, precision=6)

model.eval()
with torch.no_grad():
    A_final = model(X)

    U_final= m_orthonormalize(A_final, M)
    UKU = U_final.T @ (K @ U_final)

    print(f"The shape of U.t @ K @ U is: {UKU.shape}")

    approx_eigs = np.round(torch.diag(UKU).cpu().numpy(), 8)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
    # For better comparison, we can compute Ritz values from subspace U by solving small generalized eigenproblem
    # (U^T K U) c = mu (U^T M U) c, but U^T M U = I so just eig of UK_final
    mu, Wsmall = np.linalg.eigh(UKU.cpu().numpy())
    mu = np.real(mu)
    # sort
    idx = np.argsort(mu)
    mu = mu[idx]
    print("\nLearned Ritz values (from U^T K U):", np.round(mu[:5], 6))
    print("Reference eigenvalues (first k):   ", np.round(eigvals[:5], 6))

In [None]:
for i, j in zip(mu, eigvals):
    print(f"--- approximation: {np.round(i, 4)} actual: {np.round(j, 4)} ---")

In [None]:
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
}, "simple_model.pth")



In [None]:
# Load
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

model.eval()