In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg

In [8]:
# Convert to torch tensors (double precision for better numerical stability)
torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
####### inputs
N, k = 20, 3  # toy size
# random SPD matrices K, M
A = np.random.randn(N, N)
K = A.T @ A + np.eye(N)   # make SPD
B = np.random.randn(N, N)
M = B.T @ B + np.eye(N)   # make SPD

K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)

# reference solution
w_all, V_all = linalg.eigh(K.cpu().numpy(), M.cpu().numpy())
print("Reference eigs:", np.round(w_all[:k], 6))


X = torch.randn(N, 3, device=device)

Reference eigs: [0.038217 0.056592 0.086685]


In [10]:
# Build the neural network that maps coordinates -> k outputs per node
class MLP(nn.Module):
    def __init__(self, in_dim=3, out_dim=k, hidden=[64,64]):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers.append(nn.Linear(last, h, dtype=torch.double))
            layers.append(nn.Tanh())
            last = h
        layers.append(nn.Linear(last, out_dim, dtype=torch.double))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # returns (N, k)

# Instantiate model and optimizer
model = MLP().to(device)
# initialize final layer small
for name, p in model.named_parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [11]:
max_epochs = 20_000
print_every = 1_000
loss_history = []

In [12]:
# Helper: given A (N x k), return M-orthonormalized U = A (A^T M A)^{-1/2}
def m_orthonormalize(A, M):
    # A: (N, k), M: (N, N)
    # compute B = A^T M A (k x k)
    B = A.T @ (M @ A)  # k x k
    # symmetrize B
    B = 0.5*(B + B.T)
    # compute inverse sqrt of B via eigendecomposition (k small)
    s, Q = torch.linalg.eigh(B)  # s are eigenvalues
    # regularize small eigenvalues
    s_clamped = torch.clamp(s, min=1e-12)
    inv_sqrt = Q @ torch.diag(1.0/torch.sqrt(s_clamped)) @ Q.T
    U = A @ inv_sqrt
    return U, B, s

In [13]:
for epoch in range(1, max_epochs+1):
    model.train()
    optimizer.zero_grad()
    A = model(X)  # N x k
    U, B, s = m_orthonormalize(A, M)  # U is M-orthonormal
    # compute objective: trace(U^T K U) -> equals sum of Rayleighs for columns
    UK = U.T @ (K @ U)
    # symmetrize for numerical safety
    UK = 0.5*(UK + UK.T)
    loss = torch.trace(UK)
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())
    if epoch % print_every == 0 or epoch==1:
        # compute approximate eigenvalues by Rayleigh on current U
        with torch.no_grad():
            # the diagonal entries of UK are the Rayleigh quotients for each column if columns are M-orthonormal
            approx_vals = torch.diag(UK).cpu().numpy()
        print(f"Epoch {epoch:4d}, loss(trace)={loss.item():.6f}, approx_vals={np.round(approx_vals,6)}")

Epoch    1, loss(trace)=3.954603, approx_vals=[0.776102 1.130515 2.047986]
Epoch 1000, loss(trace)=0.533638, approx_vals=[0.172237 0.190246 0.171154]
Epoch 2000, loss(trace)=0.364480, approx_vals=[0.077848 0.143504 0.143128]
Epoch 3000, loss(trace)=0.306609, approx_vals=[0.078265 0.09962  0.128724]
Epoch 4000, loss(trace)=0.254188, approx_vals=[0.086205 0.080382 0.087601]
Epoch 5000, loss(trace)=0.183291, approx_vals=[0.049457 0.074871 0.058963]
Epoch 6000, loss(trace)=0.184610, approx_vals=[0.048033 0.07462  0.061957]
Epoch 7000, loss(trace)=0.181764, approx_vals=[0.047927 0.074351 0.059485]
Epoch 8000, loss(trace)=0.181684, approx_vals=[0.047822 0.073946 0.059917]
Epoch 9000, loss(trace)=0.182084, approx_vals=[0.048197 0.073847 0.060039]
Epoch 10000, loss(trace)=0.182961, approx_vals=[0.048166 0.073447 0.061347]
Epoch 11000, loss(trace)=0.182699, approx_vals=[0.048333 0.07391  0.060456]
Epoch 12000, loss(trace)=0.181848, approx_vals=[0.048607 0.072488 0.060753]
Epoch 13000, loss(trac

In [17]:
model.eval()
with torch.no_grad():
    A_final = model(X)
    U_final, B_final, s_final = m_orthonormalize(A_final, M)
    UK_final = U_final.T @ (K @ U_final)

    print(f"The shape of U.t @ K @ U is: {UK_final.shape}")

    approx_eigs = np.round(torch.diag(UK_final).cpu().numpy(), 8)
    # For better comparison, we can compute Ritz values from subspace U by solving small generalized eigenproblem
    # (U^T K U) c = mu (U^T M U) c, but U^T M U = I so just eig of UK_final
    mu, Wsmall = np.linalg.eigh(UK_final.cpu().numpy())
    mu = np.real(mu)
    # sort
    idx = np.argsort(mu)
    mu = mu[idx]
    print("\nLearned Ritz values (from U^T K U):", np.round(mu, 6))
    print("Reference eigenvalues (first k):   ", np.round(w_all[:k], 6))

The shape of U.t @ K @ U is: torch.Size([3, 3])

Learned Ritz values (from U^T K U): [0.038219 0.056608 0.086786]
Reference eigenvalues (first k):    [0.038217 0.056592 0.086685]
