In [37]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh

import matplotlib.pyplot as plt


# Use double precision as the default floating dtype for all tensors created below
# (better numerical stability), and run on the GPU when one is available.
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the mesh from the OBJ file (Mesh is a project-local class; presumably it
# exposes vertex coordinates as `verts` and the element list as `connectivity`).
m = Mesh('bunny.obj')

# Normalise the geometry: centre on the per-axis mean and divide every axis by the
# largest per-axis standard deviation, so the scaling is isotropic.
# assumes m.verts is (num_vertices, 3) — TODO confirm
centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

# Rebuild the mesh on the normalised vertices, keeping the original connectivity.
m = Mesh(verts = verts_new, connectivity = m.connectivity)

print('Computing Laplacian')
K, M = m.computeLaplacian()

# Following the finite-element methodology:
# K is the stiffness matrix, M is the mass matrix.
# The problem to solve is the generalized eigenproblem
#     K*u = lambda * M*u
# NOTE(review): linalg.eigh solves for ALL eigenpairs of the dense matrices; only
# the lowest few are inspected later in the notebook.
print('Computing eigen values')
eigvals, eigvecs = linalg.eigh(K,M)


# Send all relevant NumPy arrays to torch tensors on the selected device.
# K and M are rebound from NumPy arrays to tensors here; the reference
# eigvals/eigvecs stay as NumPy arrays.
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)
# Number of rows of X (one row per mesh vertex, under the shape assumption above).
N = X.shape[0]

# Number of network outputs per vertex (eigenfunctions learned at once).
# NOTE(review): the original note said the paper used 50 eigenvalues, but this
# notebook sets k = 128 — confirm which value is intended.
k = 128

def newton_schulz_orthogonalize(U, M, num_iters=5):
    """
    Approximately M-orthonormalise the columns of U using matmul-only
    Newton-Schulz iterations (no eigendecomposition or QR).

    Input:
        U: (N, k) tensor, the raw output of the network
        M: (N, N) tensor, the mass matrix
        num_iters: number of Newton-Schulz steps; more steps give a tighter
            approximation, especially when U.T @ M @ U is ill-conditioned
    Output:
        U_orth: (N, k) tensor with U_orth.T @ M @ U_orth close to the identity
            (exact only in the limit of many iterations)
    """
    _, n_cols = U.shape

    # Small (k, k) Gram matrix of the columns of U under the M inner product.
    gram = U.T @ M @ U

    # Rescale by the Frobenius norm, which bounds the spectral norm from above,
    # so the iteration below starts inside its convergence region. The epsilon
    # guards against a zero Gram matrix.
    scale = torch.sqrt(torch.norm(gram, p='fro') + 1e-6)
    U_scaled = U / scale
    gram_scaled = gram / (scale**2)

    # Iterate Z <- 0.5 * Z (3I - Z G Z) from Z = I; Z approaches the inverse
    # square root of the scaled Gram matrix G.
    eye = torch.eye(n_cols, device=U.device, dtype=U.dtype)
    inv_sqrt = eye.clone()
    for _step in range(num_iters):
        correction = 3 * eye - inv_sqrt @ (gram_scaled @ inv_sqrt)
        inv_sqrt = 0.5 * inv_sqrt @ correction

    # Whitening the rescaled columns; the scale factors cancel, so this equals
    # U times the inverse square root of the unscaled Gram matrix.
    return U_scaled @ inv_sqrt

# Build the neural network that maps coordinates -> k outputs per node
class MLP(nn.Module):
    """
    Fully connected network: a stack of Linear + SiLU layers followed by a
    final Linear layer with no activation.

    in_dim:  size of each input row (3 by default)
    out_dim: number of outputs per input row; the default is the value the
             module-level `k` had when this class was defined
    hidden:  widths of the hidden layers, in order (the default list is only
             read, never modified)
    All Linear layers are created in double precision.
    """

    def __init__(self, in_dim=3, out_dim=k, hidden=[64, 64, 64]):
        super().__init__()
        # widths[i] -> widths[i + 1] is the shape of the i-th hidden Linear layer.
        widths = [in_dim, *hidden]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out, dtype=torch.double), nn.SiLU()]
        # Output head: plain linear map from the last width to out_dim.
        modules.append(nn.Linear(widths[-1], out_dim, dtype=torch.double))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to x; for an (N, in_dim) input the result is (N, out_dim)."""
        return self.net(x)

# Instantiate model and optimizer.
# Seed torch's RNG first so the weight initialisation, and therefore the whole
# training trajectory, is repeatable across kernel restarts.
torch.manual_seed(0)
model = MLP().to(device)

lr_start = 0.01
lr_end = 0.0001
max_epochs = 5_000
# NOTE(review): print_every is never read; the loop below logs every 2500 epochs.
print_every = 1_000
loss_history = []

optimizer = optim.Adam(model.parameters(), lr=lr_start)
# Per-step multiplicative decay chosen so the learning rate goes from lr_start
# to lr_end in exactly max_epochs scheduler steps.
decay_factor = (lr_end / lr_start) ** (1 / max_epochs)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_factor)

# Loop-invariant work hoisted out of the loop: the k x k identity used by the
# orthonormality penalty, and switching the model to training mode.
eye_k = torch.eye(k, device=device)
model.train()

# NOTE(review): the loop runs for up to 22000 epochs while the decay horizon is
# max_epochs (5000), so the learning rate keeps shrinking below lr_end after
# epoch 5000 — confirm which of the two numbers is intended.
for epoch in range(1, 22001): # , max_epochs+1
    optimizer.zero_grad()
    U = model(X)  # N x k

    KU = K @ U
    MU = M @ U

    UKU = U.T @ KU
    UMU = U.T @ MU        # k x k

    # Elementwise ratio of the two k x k matrices; only its diagonal is used
    # below, i.e. the per-column Rayleigh quotient (u_i^T K u_i) / (u_i^T M u_i).
    rayleigh = UKU / (UMU + 1E-6)

    # Eigen-residual K u_i - lambda_i M u_i for every column (the length-k
    # diagonal broadcasts across the rows of MU). The residual is squared
    # elementwise and then reduced with a Frobenius norm.
    # NOTE(review): torch.norm already returns a scalar, so torch.mean is a no-op here.
    loss_1 = torch.mean(torch.norm((KU - torch.diag(rayleigh) * MU )**2))
    # Squared entrywise deviation of U^T M U from the identity.
    UMU_I = (UMU - eye_k)**2
    # NOTE(review): despite the names, both reductions run over ALL entries:
    # off_diag_loss is the max entry (diagonal included), diag_loss is the mean entry.
    off_diag_loss = torch.max(UMU_I)
    diag_loss = torch.mean(UMU_I)
    orth_loss = diag_loss + off_diag_loss
    loss = loss_1 + orth_loss


    loss.backward()
    optimizer.step()
    scheduler.step()
    # Read the scalar back from the device once per step and reuse it.
    loss_value = loss.item()
    loss_history.append(loss_value)
    if epoch % 2500 == 0 or epoch == 1:
        print(f"Epoch {epoch:4d}, total loss={loss_value:.3f}, loss_1: {loss_1.item():.3f}, diag loss: {diag_loss.item():.3f}, off diag loss: {off_diag_loss.item():.3f}")

    # Early exit once the total loss is negligible.
    if loss_value < 1E-6:
        break


np.set_printoptions(3, suppress=True)
# Largest squared deviation of U^T M U from the identity at the last completed step.
print(torch.max(UMU_I))

Computing Laplacian
Computing eigen values
Epoch    1, total loss=2.419, loss_1: 0.001, diag loss: 0.139, off diag loss: 2.280
Epoch 2500, total loss=0.379, loss_1: 0.018, diag loss: 0.066, off diag loss: 0.295
Epoch 5000, total loss=0.327, loss_1: 0.018, diag loss: 0.067, off diag loss: 0.241


KeyboardInterrupt: 

In [38]:
# Squared entrywise deviation of the Gram matrix U^T M U from the identity,
# recomputed from the current UMU and moved to NumPy for inspection.
tmp = ((UMU - torch.eye(k, device=device))**2).detach().cpu().numpy()
np.set_printoptions(3, suppress=True)
# Report the worst entry from the freshly computed array rather than from the
# training-loop variable UMU_I, which can be out of step with UMU when the
# loop was interrupted part-way through an iteration.
print(tmp.max())

0.2389868677623058


In [None]:
- are the points uniquely distributed --> try DBSCAN
- network oneshot learning all eigenfunctions by the X points
- try the L-BFGS optimizer, as the loss is fully convex

In [16]:
# Concatenate the tensors collected in `found_eigenvectors` along dim 1 into one tensor U.
# NOTE(review): `found_eigenvectors` is not defined in any cell shown here, and this
# cell's execution count (16) is lower than the training cell's (37) — confirm where
# it comes from before relying on Restart & Run All.
U = torch.cat(found_eigenvectors, dim=1)

In [17]:
# Congruence of K with U (U^T K U); its diagonal holds u_i^T K u_i for each column u_i of U.
UKU = U.T @ K @ U

In [18]:
# Gram matrix of the columns of U under M (U^T M U); it equals the identity
# exactly when the columns are M-orthonormal.
UMU = U.T @ M @ U

In [19]:
# Print U^T M U as a NumPy array with three decimals and no scientific notation,
# to eyeball how close it is to the identity.
np.set_printoptions(3, suppress=True)
print(UMU.detach().cpu().numpy())

[[ 1. -0.  0. ...  0.  0.  0.]
 [-0.  1.  0. ... -0. -0.  0.]
 [ 0. -0.  1. ...  0.  0. -0.]
 ...
 [ 0. -0.  0. ...  1.  0. -0.]
 [ 0. -0.  0. ...  0.  1.  0.]
 [ 0.  0. -0. ... -0.  0.  1.]]


In [None]:
 5 x 2503 @ 2503

In [None]:
# Multiply M (copied back to the CPU as a NumPy array) by the first five columns
# of `eigvecs`, the reference eigenvectors returned by linalg.eigh(K, M).
M.cpu().numpy() @ eigvecs[:, :5]

In [None]:
# Gram matrix of the reference eigenvectors under M. SciPy documents the
# generalized eigh normalisation as v.T @ b @ v = I, so this should display
# (approximately) the identity.
eigvecs.T @ M.cpu().numpy() @ eigvecs