In [3]:
# multigrid_gnn_refine_fixed.py
"""
Complete working rewrite of your multigrid + GNN eigen-refinement pipeline.

Assumptions:
 - `Mesh` class exists and Mesh('bunny.obj') loads .verts (n x 3) and .connectivity (triangles).
 - `robust_laplacian.point_cloud_laplacian(X)` returns (L, M) as scipy sparse matrices where L and M are compatible with eigsh.
 - scikit-learn, scipy, numpy, matplotlib, torch are installed.

Key features:
 - All classes and functions defined in one file (no missing names).
 - Stable training: column normalization, small correction scale, normalized losses, grad clipping, configurable weights.
 - Auto-detect input feature dimension to avoid matmul mismatches.
"""

import os
import numpy as np
from scipy.linalg import eigh
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh
from sklearn.neighbors import NearestNeighbors

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

# Project imports (must be available in your environment)
from Mesh import Mesh
import robust_laplacian


# ------------------------
# Utility helpers
# ------------------------
def sp_to_torch_sparse(A):
    """Convert scipy sparse matrix to torch.sparse_coo_tensor (CPU or GPU depending on .to(device))."""
    A = A.tocoo()
    indices = np.vstack((A.row, A.col)).astype(np.int64)
    i = torch.LongTensor(indices)
    v = torch.FloatTensor(A.data)
    return torch.sparse_coo_tensor(i, v, A.shape).coalesce()


def normalize_columns_np(U, eps=1e-12):
    """Normalize numpy matrix columns to have unit L2 norm. Returns normalized U and norms."""
    norms = np.linalg.norm(U, axis=0) + eps
    return U / norms, norms


def normalize_columns_torch(U, eps=1e-12):
    """Normalize torch tensor columns to have unit L2 norm. Returns normalized U and norms (torch)."""
    norms = torch.norm(U, dim=0) + eps
    return U / norms, norms


# ------------------------
# Simple per-node corrector (message-passing via neighbor mean + MLP)
# ------------------------
class SimpleCorrector(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_sizes=(128, 64, 32), dropout=0.0):
        super().__init__()
        layers = []
        prev = in_dim * 2  # because we will concat self + neighbor-mean in forward
        for h in hidden_sizes:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU(inplace=True))
            if dropout > 0.0:
                layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, edge_index):
        """
        x: [n, in_dim]
        edge_index: LongTensor shape [2, n_edges] with (row=target, col=source) semantics
        """
        row, col = edge_index  # both LongTensor
        n = x.shape[0]
        # aggregate neighbor features: mean aggregator
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        deg = torch.bincount(row, minlength=n).unsqueeze(1).to(x.dtype).to(x.device)
        deg = deg.clamp(min=1.0)
        agg = agg / deg
        h = torch.cat([x, agg], dim=1)  # shape [n, 2*in_dim]
        return self.net(h)


# ------------------------
# Multigrid eigensolver with GNN corrector
# ------------------------
class MultigridEigensolver:
    def __init__(self, device=None, checkpoint_dir="./checkpoints"):
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.model = None
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    @staticmethod
    def normalize_mesh(mesh):
        centroid = mesh.verts.mean(0)
        std_max = mesh.verts.std(0).max() + 1e-12
        verts_normalized = (mesh.verts - centroid) / std_max
        return Mesh(verts=verts_normalized, connectivity=mesh.connectivity)

    @staticmethod
    def build_prolongation(X_coarse, X_fine, k=1):
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X_coarse)
        distances, indices = nbrs.kneighbors(X_fine)
        n_fine, n_coarse = X_fine.shape[0], X_coarse.shape[0]
        rows, cols, vals = [], [], []
        for i in range(n_fine):
            weights = 1.0 / (distances[i] + 1e-12)
            weights /= weights.sum()
            for j, idx in enumerate(indices[i]):
                rows.append(i)
                cols.append(idx)
                vals.append(weights[j])
        return coo_matrix((vals, (rows, cols)), shape=(n_fine, n_coarse))

    @staticmethod
    def build_knn_graph(X, k=4):
        n_points = X.shape[0]
        nbrs = NearestNeighbors(n_neighbors=k + 1).fit(X)
        _, neighbors = nbrs.kneighbors(X)
        rows, cols = [], []
        for i in range(n_points):
            for j in neighbors[i][1:]:
                rows.append(i)
                cols.append(j)
        return torch.LongTensor([rows, cols]).to(torch.long)

    def solve_eigenvalue_problem(self, X, n_modes):
        L, M = robust_laplacian.point_cloud_laplacian(X)
        # use eigsh with M as mass
        vals, vecs = eigsh(L, k=n_modes, M=M, which='SM')
        return vals, np.array(vecs), L, M

    # ------------------------
    # Core training routine
    # ------------------------
    def train_gnn(self, model, x_feats, edge_index, U_init, L_fine, M_fine, U_coarse, P,
                  n_modes,
                  epochs=200,
                  lr=1e-3,
                  corr_scale=1e-2,
                  w_res=10.0,
                  w_orth=1.0,
                  w_proj=1e-3,
                  grad_clip=1.0,
                  weight_decay=1e-6,
                  log_every=200):
        """
        Train corrector model:
          - x_feats: torch.FloatTensor [n_fine, in_dim] on device
          - edge_index: torch.LongTensor [2, n_edges] on device
          - U_init: numpy array [n_fine, n_modes] (will be normalized inside)
          - L_fine, M_fine: scipy sparse matrices
          - U_coarse: numpy array [n_coarse, n_modes]
          - P: scipy sparse prolongation (n_fine x n_coarse)
        Returns U_pred (numpy array [n_fine, n_modes]) - denormalized to original U_init scale.
        """
        device = self.device

        # Convert sparse matrices to torch sparse on device
        L_t = sp_to_torch_sparse(L_fine).to(device)
        M_t = sp_to_torch_sparse(M_fine).to(device)
        R_t = sp_to_torch_sparse(P.T).to(device)

        # Normalize columns of U_init and U_coarse (keep original norms for rescaling)
        U_init_normed, uinit_norms = normalize_columns_np(U_init)
        U_coarse_normed, ucoarse_norms = normalize_columns_np(U_coarse)

        U_init_t = torch.FloatTensor(U_init_normed).to(device)   # [n_fine, n_modes]
        U_coarse_t = torch.FloatTensor(U_coarse_normed).to(device)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)

        n_fine = U_init_t.shape[0]
        n_coarse = U_coarse_t.shape[0]
        denom_res = float(max(1, n_fine * n_modes))
        denom_proj = float(max(1, n_coarse * n_modes))
        I = torch.eye(n_modes, device=device)

        model.train()
        for ep in range(epochs):
            optimizer.zero_grad()

            corr_raw = model(x_feats, edge_index)  # [n_fine, n_modes]
            corr = corr_scale * corr_raw
            U_pred = U_init_t + corr

            # Rayleigh-related tensors
            Lu = torch.sparse.mm(L_t, U_pred)
            Mu = torch.sparse.mm(M_t, U_pred)
            num = torch.sum(U_pred * Lu, dim=0)
            den = torch.sum(U_pred * Mu, dim=0) + 1e-12
            lambdas = num / den

            # Residual loss (normalized)
            res = Lu - Mu * lambdas.unsqueeze(0)
            L_res = torch.sum(res**2) / denom_res

            # Orthonormality loss (M-weighted Gram)
            MUt = torch.sparse.mm(M_t, U_pred)
            Gram = U_pred.t() @ MUt
            L_orth = torch.sum((Gram - I)**2) / (n_modes * n_modes)

            # Projection loss
            proj = torch.sparse.mm(R_t, U_pred)
            L_proj = torch.sum((proj - U_coarse_t)**2) / denom_proj

            loss = w_res * L_res + w_orth * L_orth + w_proj * L_proj
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), grad_clip)

            optimizer.step()

            if (ep % log_every == 0) or (ep == epochs - 1):
                with torch.no_grad():
                    u_norm = float(U_pred.norm().cpu().item())
                    corr_std = float(corr.std().cpu().item())
                print(f"    Epoch {ep:4d}: Loss={loss.item():.6f} (Res={L_res.item():.6f}, Orth={L_orth.item():.6f}, Proj={L_proj.item():.6f}) U_norm={u_norm:.4f} corr_std={corr_std:.6f}")

        # Denormalize: multiply columns by original column norms of U_init
        U_pred_np = U_pred.detach().cpu().numpy() * uinit_norms.reshape(1, -1)
        return U_pred_np

    # ------------------------
    # Rayleigh-Ritz refinement
    # ------------------------
    def refine_eigenvectors(self, U_pred, L_fine, M_fine):
        L_t = sp_to_torch_sparse(L_fine).to(self.device)
        M_t = sp_to_torch_sparse(M_fine).to(self.device)
        U = torch.FloatTensor(U_pred).to(self.device)
        A = (U.t() @ torch.sparse.mm(L_t, U)).cpu().numpy()
        B = (U.t() @ torch.sparse.mm(M_t, U)).cpu().numpy()
        vals, C = eigh(A, B)
        U_refined = U.cpu().numpy() @ C
        return vals, U_refined

    # ------------------------
    # Refine one level (coarse -> fine)
    # ------------------------
    def refine_level(self, X_coarse, U_coarse, X_fine, n_modes,
                     hidden_sizes=(128, 64, 32),
                     dropout=0.0,
                     k_neighbors=4,
                     epochs=200,
                     lr=1e-3,
                     corr_scale=1e-2,
                     w_res=10.0,
                     w_orth=1.0,
                     w_proj=1e-3,
                     freeze_layers=0,
                     checkpoint_name=None):
        """
        Single-level refinement.
        - Automatically sets input dim from features.
        - Creates model if not existing; reuses and optionally freezes layers if existing.
        """
        device = self.device
        print(f"  Computing Laplacian for {X_fine.shape[0]} points...")
        L_fine, M_fine = robust_laplacian.point_cloud_laplacian(X_fine)

        print("  Building prolongation operator...")
        P = self.build_prolongation(X_coarse, X_fine, k=1)

        print("  Building kNN graph...")
        edge_index = self.build_knn_graph(X_fine, k=k_neighbors).to(device)

        # Build U_init on fine grid
        U_init = P @ U_coarse  # shape [n_fine, n_modes]

        # Build features: coords + U_init (we pass raw U_init; normalization happens inside train_gnn)
        x_feats = torch.FloatTensor(np.hstack([X_fine, U_init])).to(device)

        in_dim = x_feats.shape[1]
        out_dim = n_modes

        if self.model is None:
            print(f"  Creating new corrector model (in_dim={in_dim}, out_dim={out_dim})...")
            self.model = SimpleCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, dropout=dropout).to(device)
        else:
            # If model exists but input dimension changed, re-create model to match new in_dim
            # (safer than trying to partially load weights with mismatched shapes)
            existing_in_dim = None
            # try to infer existing in_dim by checking first Linear in model.net if present
            for m in self.model.net:
                if isinstance(m, nn.Linear):
                    existing_in_dim = m.in_features
                    break
            if existing_in_dim != in_dim:
                print(f"  Recreating model to match new input dim (was {existing_in_dim}, now {in_dim})...")
                # Optionally copy weights for layers that match by size
                old_state = self.model.state_dict()
                self.model = SimpleCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, dropout=dropout).to(device)
                # attempt to copy subset of weights where shapes match
                new_state = self.model.state_dict()
                for k, v in old_state.items():
                    if k in new_state and old_state[k].shape == new_state[k].shape:
                        new_state[k] = old_state[k]
                self.model.load_state_dict(new_state)

        # Optionally freeze first few linear layers (count of Linear modules)
        if freeze_layers > 0:
            linear_count = 0
            for module in self.model.net:
                if isinstance(module, nn.Linear):
                    linear_count += 1
                    if linear_count <= freeze_layers:
                        for p in module.parameters():
                            p.requires_grad = False
            print(f"  Frozen first {freeze_layers} linear layers.")

        print(f"  Training corrector: epochs={epochs}, lr={lr}, corr_scale={corr_scale}")
        U_pred = self.train_gnn(self.model, x_feats, edge_index, U_init, L_fine, M_fine, U_coarse, P,
                                n_modes,
                                epochs=epochs,
                                lr=lr,
                                corr_scale=corr_scale,
                                w_res=w_res,
                                w_orth=w_orth,
                                w_proj=w_proj)

        print("  Rayleigh-Ritz refinement...")
        lambda_refined, U_refined = self.refine_eigenvectors(U_pred, L_fine, M_fine)

        if checkpoint_name is not None:
            ckpt = {"model_state": self.model.state_dict(), "lambda_refined": lambda_refined}
            torch.save(ckpt, os.path.join(self.checkpoint_dir, checkpoint_name))
            print(f"  Saved checkpoint: {os.path.join(self.checkpoint_dir, checkpoint_name)}")

        return lambda_refined, U_refined, L_fine, M_fine


# ------------------------
# Visualization helper
# ------------------------
def visualize_mesh(mesh, title='Mesh Visualization', highlight_indices=None, show=True):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(projection='3d')
    ax.plot_trisurf(mesh.verts[:, 0], mesh.verts[:, 1], mesh.verts[:, 2],
                    triangles=mesh.connectivity, alpha=0.35)
    if highlight_indices is not None:
        hv = mesh.verts[highlight_indices]
        ax.scatter(hv[:, 0], hv[:, 1], hv[:, 2], s=6, label=f"{len(highlight_indices)} pts")
        ax.legend()
    ax.set_title(title)
    ax.view_init(elev=120, azim=-90)
    if show:
        plt.show()


# ------------------------
# Main script
# ------------------------
def main():
    mesh_path = "bunny.obj"
    n_modes = 10
    hidden_sizes = (128, 64, 32)
    dropout = 0.0

    # schedule and hyperparams
    epochs_schedule = {0: 0, 1: 1500, 2: 1000, 3: 800, 4: 800}
    hierarchy = [32, 128, 512, 1024]  # final level will append full
    k_neighbors = 4
    lr_start = 1e-3
    lr_min = 5e-4
    corr_scale = 1e-2
    w_res = 10.0
    w_orth = 10.0
    w_proj = 1e-3
    freeze_schedule = {1: 0, 2: 1, 3: 1, 4: 2}

    print("Loading mesh...")
    mesh = Mesh(mesh_path)
    mesh = MultigridEigensolver.normalize_mesh(mesh)
    X_full = mesh.verts
    n_total = X_full.shape[0]
    print(f"Mesh loaded: {n_total} vertices")

    hierarchy = [n for n in hierarchy if n <= n_total]
    if hierarchy[-1] != n_total:
        hierarchy.append(n_total)
    print("Hierarchy:", hierarchy)

    rng = np.random.default_rng(seed=42)
    all_idx = np.arange(n_total)
    rng.shuffle(all_idx)
    indices_per_level = {}
    for i, n_points in enumerate(hierarchy):
        indices_per_level[i] = all_idx[:n_points].copy()
        print(f"  Level {i}: {n_points} points (nested)")

    solver = MultigridEigensolver()

    # Level 0 coarse solve
    idx0 = indices_per_level[0]
    X0 = X_full[idx0]
    print("\nLEVEL 0: coarse solving...")
    lambda_cur, U_cur, L_cur, M_cur = solver.solve_eigenvalue_problem(X0, n_modes)
    print("Coarse eigenvalues:", np.round(lambda_cur, 6))

    # iterative refinement
    for level in range(1, len(hierarchy)):
        idx_coarse = indices_per_level[level - 1]
        idx_fine = indices_per_level[level]
        Xc = X_full[idx_coarse]
        Xf = X_full[idx_fine]
        epochs = epochs_schedule.get(level, 1000)

        print(f"\nLEVEL {level}: refine {Xc.shape[0]} -> {Xf.shape[0]}, epochs={epochs}")
        freeze_layers = freeze_schedule.get(level, 0)

        total_levels = len(hierarchy)
        decay = (level - 1) / max(1, total_levels - 1)
        lr = lr_start * ((lr_min / lr_start) ** decay)

        lambda_cur, U_cur, L_cur, M_cur = solver.refine_level(
            Xc, U_cur, Xf, n_modes,
            hidden_sizes=hidden_sizes,
            dropout=dropout,
            k_neighbors=k_neighbors,
            epochs=epochs,
            lr=lr,
            corr_scale=corr_scale,
            w_res=w_res,
            w_orth=w_orth,
            w_proj=w_proj,
            freeze_layers=freeze_layers,
            checkpoint_name=f"level_{level}_ckpt.pt"
        )

        print("GNN-refined eigenvalues:", np.round(lambda_cur, 6))

        # exact eigenvalues for verification
        print("  computing exact eigenvalues for verification...")
        lambda_exact, _, _, _ = solver.solve_eigenvalue_problem(Xf, n_modes)
        rel_err = np.abs(lambda_cur - lambda_exact) / (np.abs(lambda_exact) + 1e-12)
        print("  Exact eigenvalues:", np.round(lambda_exact, 6))
        print("  Relative errors:  ", np.round(rel_err, 6))

    print("\nDone. Final eigenvalues:", np.round(lambda_cur, 6))


if __name__ == "__main__":
    main()


Loading mesh...
Mesh loaded: 2503 vertices
Hierarchy: [32, 128, 512, 1024, 2503]
  Level 0: 32 points (nested)
  Level 1: 128 points (nested)
  Level 2: 512 points (nested)
  Level 3: 1024 points (nested)
  Level 4: 2503 points (nested)

LEVEL 0: coarse solving...
Coarse eigenvalues: [-0.        0.359878  0.848913  1.158359  1.507636  1.78121   1.971736
  2.292232  2.718251  2.782262]

LEVEL 1: refine 32 -> 128, epochs=1500
  Computing Laplacian for 128 points...
  Building prolongation operator...
  Building kNN graph...
  Creating new corrector model (in_dim=13, out_dim=10)...
  Training corrector: epochs=1500, lr=0.001, corr_scale=0.01
    Epoch    0: Loss=0.732555 (Res=0.012540, Orth=0.060707, Proj=0.075425) U_norm=3.1661 corr_std=0.000681
    Epoch  200: Loss=0.291884 (Res=0.009555, Orth=0.019600, Proj=0.331179) U_norm=5.7654 corr_std=0.085166
    Epoch  400: Loss=0.055486 (Res=0.005403, Orth=0.000102, Proj=0.442424) U_norm=6.6783 corr_std=0.113448
    Epoch  600: Loss=0.023319 (R

In [23]:
# multigrid_gnn_multires_physics.py
"""
Physics-informed Multigrid + GNN eigen-refinement
- Exact solve only on coarsest mesh
- Multiresolution GNN with residual + orthonormality + projection loss
- Coarse-to-fine prolongation only
"""

import os
import numpy as np
from scipy.linalg import eigh
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh
from sklearn.neighbors import NearestNeighbors

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

from Mesh import Mesh
import robust_laplacian

# ------------------------
# Utilities
# ------------------------
def sp_to_torch_sparse(A):
    A = A.tocoo()
    indices = np.vstack((A.row, A.col)).astype(np.int64)
    i = torch.LongTensor(indices)
    v = torch.FloatTensor(A.data)
    return torch.sparse_coo_tensor(i, v, A.shape).coalesce()

def normalize_columns_np(U, eps=1e-12):
    norms = np.linalg.norm(U, axis=0) + eps
    return U / norms, norms

def normalize_columns_torch(U, eps=1e-12):
    norms = torch.norm(U, dim=0) + eps
    return U / norms, norms

# ------------------------
# Simple neighbor-mean corrector
# ------------------------
class SimpleCorrector(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_sizes=(128,64,32), dropout=0.0):
        super().__init__()
        layers = []
        prev = in_dim * 2
        for h in hidden_sizes:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU(inplace=True))
            if dropout > 0.0:
                layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, edge_index):
        row, col = edge_index
        n = x.shape[0]
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        deg = torch.bincount(row, minlength=n).unsqueeze(1).to(x.dtype).to(x.device).clamp(min=1.0)
        agg = agg / deg
        h = torch.cat([x, agg], dim=1)
        return self.net(h)

# ------------------------
# Multigrid GNN solver
# ------------------------
class MultigridGNN:
    def __init__(self, device=None, checkpoint_dir="./checkpoints"):
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.model = None
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    @staticmethod
    def normalize_mesh(mesh):
        centroid = mesh.verts.mean(0)
        std_max = mesh.verts.std(0).max() + 1e-12
        verts_normalized = (mesh.verts - centroid) / std_max
        return Mesh(verts=verts_normalized, connectivity=mesh.connectivity)

    @staticmethod
    def build_prolongation(X_coarse, X_fine, k=1):
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X_coarse)
        distances, indices = nbrs.kneighbors(X_fine)
        n_fine, n_coarse = X_fine.shape[0], X_coarse.shape[0]
        rows, cols, vals = [], [], []
        for i in range(n_fine):
            weights = 1.0 / (distances[i] + 1e-12)
            weights /= weights.sum()
            for j, idx in enumerate(indices[i]):
                rows.append(i)
                cols.append(idx)
                vals.append(weights[j])
        return coo_matrix((vals, (rows, cols)), shape=(n_fine, n_coarse))

    @staticmethod
    def build_knn_graph(X, k=4):
        n_points = X.shape[0]
        nbrs = NearestNeighbors(n_neighbors=k+1).fit(X)
        _, neighbors = nbrs.kneighbors(X)
        rows, cols = [], []
        for i in range(n_points):
            for j in neighbors[i][1:]:
                rows.append(i)
                cols.append(j)
        return torch.LongTensor([rows, cols]).to(torch.long)

    def solve_eigenvalue_problem(self, X, n_modes):
        L, M = robust_laplacian.point_cloud_laplacian(X)
        vals, vecs = eigsh(L, k=n_modes, M=M, which='SM')
        return vals, np.array(vecs), L, M

    # ------------------------
    # Physics-informed GNN training
    # ------------------------
    def train_multiresolution(self, X_list, U_init_list, edge_index_list,
                              epochs=1000, lr=1e-3, corr_scale=1e-2,
                              w_res=10.0, w_orth=1.0, w_proj=1e-3,
                              grad_clip=1.0, weight_decay=1e-6, log_every=100):
        device = self.device
        n_modes = U_init_list[0].shape[1]

        # Build torch tensors and resolution indicators
        x_feats_all, U_all, edge_index_all = [], [], []
        node_offset = 0
        max_nodes = max([X.shape[0] for X in X_list])
        for X, U_init, edge_index in zip(X_list, U_init_list, edge_index_list):
            res_feat = np.full((X.shape[0], 1), X.shape[0]/max_nodes)
            x_feats_all.append(np.hstack([X, U_init, res_feat]))
            U_all.append(U_init)
            edge_index_all.append(edge_index + node_offset)
            node_offset += X.shape[0]

        x_feats_all = torch.FloatTensor(np.vstack(x_feats_all)).to(device)
        U_all_tensor = torch.FloatTensor(np.vstack(U_all)).to(device)
        edge_index_all = torch.cat(edge_index_all, dim=1).to(device)

        in_dim = x_feats_all.shape[1]
        if self.model is None:
            self.model = SimpleCorrector(in_dim, n_modes).to(device)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=lr, weight_decay=weight_decay)
        self.model.train()

        # ------------------------
        # Precompute Laplacians per level
        L_list, M_list = [], []
        node_offset = 0
        for X in X_list:
            L, M = robust_laplacian.point_cloud_laplacian(X)
            L_list.append(sp_to_torch_sparse(L).to(device))
            M_list.append(sp_to_torch_sparse(M).to(device))

        for ep in range(epochs):
            optimizer.zero_grad()
            corr_raw = self.model(x_feats_all, edge_index_all)
            corr = corr_scale * corr_raw
            U_pred = U_all_tensor + corr

            # Physics-informed loss
            loss = 0.0
            node_offset = 0
            for i, (L_t, M_t, U_init) in enumerate(zip(L_list, M_list, U_init_list)):
                n_nodes = U_init.shape[0]
                U_level = U_pred[node_offset:node_offset+n_nodes]

                # Rayleigh residual
                Lu = torch.sparse.mm(L_t, U_level)
                Mu = torch.sparse.mm(M_t, U_level)
                num = torch.sum(U_level * Lu, dim=0)
                den = torch.sum(U_level * Mu, dim=0) + 1e-12
                lambdas = num / den
                res = Lu - Mu * lambdas.unsqueeze(0)
                L_res = torch.mean(res**2)

                # Orthonormality
                Gram = U_level.t() @ Mu
                L_orth = torch.mean((Gram - torch.eye(n_modes, device=device))**2)

                loss += w_res * L_res + w_orth * L_orth
                node_offset += n_nodes

            loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, self.model.parameters()), grad_clip)
            optimizer.step()

            if ep % log_every == 0 or ep == epochs-1:
                print(f"Epoch {ep:4d}: Loss={loss.item():.6f}")

        return U_pred.detach().cpu().numpy()

    # ------------------------
    # Rayleigh-Ritz refinement
    # ------------------------
    def refine_eigenvectors(self, U_pred, L, M):
        U = torch.FloatTensor(U_pred).to(self.device)
        L_t = sp_to_torch_sparse(L).to(self.device)
        M_t = sp_to_torch_sparse(M).to(self.device)
        A = (U.t() @ torch.sparse.mm(L_t, U)).cpu().numpy()
        B = (U.t() @ torch.sparse.mm(M_t, U)).cpu().numpy()
        vals, C = eigh(A, B)
        U_refined = U.cpu().numpy() @ C
        return vals, U_refined

# ------------------------
# Main
# ------------------------
def main():
    mesh_path = "bunny.obj"
    n_modes = 10
    hierarchy = [128, 512, 1024]  # final level is full mesh
    k_neighbors = 4
    epochs = 1000

    print("Loading mesh...")
    mesh = Mesh(mesh_path)
    mesh = MultigridGNN.normalize_mesh(mesh)
    X_full = mesh.verts
    n_total = X_full.shape[0]
    hierarchy = [n for n in hierarchy if n <= n_total]
    if hierarchy[-1] != n_total:
        hierarchy.append(n_total)
    print("Hierarchy:", hierarchy)

    rng = np.random.default_rng(seed=42)
    all_idx = np.arange(n_total)
    rng.shuffle(all_idx)
    indices_per_level = {i: all_idx[:n].copy() for i,n in enumerate(hierarchy)}

    solver = MultigridGNN()

    # ------------------------
    # Level 0: exact coarse solve
    # ------------------------
    idx0 = indices_per_level[0]
    X0 = X_full[idx0]
    print(f"\nLEVEL 0: exact solve on {X0.shape[0]} points...")
    lambda0, U0, L0, M0 = solver.solve_eigenvalue_problem(X0, n_modes)
    print("Coarse eigenvalues:", np.round(lambda0,6))

    # ------------------------
    # Coarse-to-fine prolongation
    # ------------------------
    U_prev = U0.copy()
    X_list, U_init_list, edge_index_list = [X0], [U0], [solver.build_knn_graph(X0, k=k_neighbors)]
    for level in range(1, len(hierarchy)):
        idx_coarse = indices_per_level[level-1]
        idx_fine = indices_per_level[level]
        Xc = X_full[idx_coarse]
        Xf = X_full[idx_fine]

        P = solver.build_prolongation(Xc, Xf, k=1)
        U_init = P @ U_prev
        edge_index = solver.build_knn_graph(Xf, k=k_neighbors)

        X_list.append(Xf)
        U_init_list.append(U_init)
        edge_index_list.append(edge_index)

        U_prev = U_init.copy()

    # ------------------------
    # Train physics-informed GNN
    # ------------------------
    print("\nTraining physics-informed multiresolution GNN...")
    U_pred_all = solver.train_multiresolution(X_list, U_init_list, edge_index_list,
                                              epochs=epochs)

    # ------------------------
    # Rayleigh-Ritz refinement per level
    # ------------------------
    node_offset = 0
    for level, X in enumerate(X_list):
        n_nodes = X.shape[0]
        U_pred = U_pred_all[node_offset:node_offset+n_nodes]
        node_offset += n_nodes
        L, M = robust_laplacian.point_cloud_laplacian(X)
        vals_refined, _ = solver.refine_eigenvectors(U_pred, L, M)
        print(f"Level {level} refined eigenvalues: {np.round(vals_refined,3)}")

if __name__ == "__main__":
    main()


Loading mesh...
Hierarchy: [128, 512, 1024, 2503]

LEVEL 0: exact solve on 128 points...
Coarse eigenvalues: [0.       0.464393 0.860068 1.196359 1.413617 1.861207 2.273839 2.759782
 2.87012  3.25778 ]

Training physics-informed multiresolution GNN...
Epoch    0: Loss=0.554847
Epoch  100: Loss=0.402844
Epoch  200: Loss=0.218859
Epoch  300: Loss=0.112217
Epoch  400: Loss=0.056183
Epoch  500: Loss=0.037726
Epoch  600: Loss=0.028144
Epoch  700: Loss=0.023266
Epoch  800: Loss=0.020314
Epoch  900: Loss=0.018311
Epoch  999: Loss=0.016806
Level 0 refined eigenvalues: [0.    0.473 0.877 1.211 1.43  1.878 2.299 2.789 2.903 3.366]
Level 1 refined eigenvalues: [0.    0.457 0.927 1.269 1.435 1.89  2.057 2.912 3.185 3.799]
Level 2 refined eigenvalues: [0.    0.461 0.951 1.179 1.278 1.448 1.948 2.99  3.298 3.702]
Level 3 refined eigenvalues: [0.    0.505 1.009 1.378 1.39  1.585 2.087 3.245 3.536 3.838]


In [16]:
#Exact eigenvalues: [0.000, 0.288, 0.722, 0.842, 1.039, 1.202, 1.762, 2.600, 2.923, 2.973]

In [33]:
# multigrid_gnn_refine_fixed.py
"""
Complete working rewrite of your multigrid + GNN eigen-refinement pipeline.

Assumptions:
 - `Mesh` class exists and Mesh('bunny.obj') loads .verts (n x 3) and .connectivity (triangles).
 - `robust_laplacian.point_cloud_laplacian(X)` returns (L, M) as scipy sparse matrices where L and M are compatible with eigsh.
 - scikit-learn, scipy, numpy, matplotlib, torch are installed.

Key features:
 - All classes and functions defined in one file (no missing names).
 - Stable training: column normalization, small correction scale, normalized losses, grad clipping, configurable weights.
 - Auto-detect input feature dimension to avoid matmul mismatches.
"""

import os
import numpy as np
from scipy.linalg import eigh
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh
from sklearn.neighbors import NearestNeighbors

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

# Project imports (must be available in your environment)
from Mesh import Mesh
import robust_laplacian


# ------------------------
# Utility helpers
# ------------------------
def sp_to_torch_sparse(A):
    """Convert scipy sparse matrix to torch.sparse_coo_tensor (CPU or GPU depending on .to(device))."""
    A = A.tocoo()
    indices = np.vstack((A.row, A.col)).astype(np.int64)
    i = torch.LongTensor(indices)
    v = torch.FloatTensor(A.data)
    return torch.sparse_coo_tensor(i, v, A.shape).coalesce()


def normalize_columns_np(U, eps=1e-12):
    """Normalize numpy matrix columns to have unit L2 norm. Returns normalized U and norms."""
    norms = np.linalg.norm(U, axis=0) + eps
    return U / norms, norms


def normalize_columns_torch(U, eps=1e-12):
    """Normalize torch tensor columns to have unit L2 norm. Returns normalized U and norms (torch)."""
    norms = torch.norm(U, dim=0) + eps
    return U / norms, norms


# ------------------------
# Simple per-node corrector (message-passing via neighbor mean + MLP)
# ------------------------
class SimpleCorrector(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_sizes=(128, 64, 32), dropout=0.0):
        super().__init__()
        layers = []
        prev = in_dim * 2  # because we will concat self + neighbor-mean in forward
        for h in hidden_sizes:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU(inplace=True))
            if dropout > 0.0:
                layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, edge_index):
        """
        x: [n, in_dim]
        edge_index: LongTensor shape [2, n_edges] with (row=target, col=source) semantics
        """
        row, col = edge_index  # both LongTensor
        n = x.shape[0]
        # aggregate neighbor features: mean aggregator
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        deg = torch.bincount(row, minlength=n).unsqueeze(1).to(x.dtype).to(x.device)
        deg = deg.clamp(min=1.0)
        agg = agg / deg
        h = torch.cat([x, agg], dim=1)  # shape [n, 2*in_dim]
        return self.net(h)


# ------------------------
# Multigrid eigensolver with GNN corrector
# ------------------------
class MultigridEigensolver:
    def __init__(self, device=None, checkpoint_dir="./checkpoints"):
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.model = None
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    @staticmethod
    def normalize_mesh(mesh):
        centroid = mesh.verts.mean(0)
        std_max = mesh.verts.std(0).max() + 1e-12
        verts_normalized = (mesh.verts - centroid) / std_max
        return Mesh(verts=verts_normalized, connectivity=mesh.connectivity)

    @staticmethod
    def build_prolongation(X_coarse, X_fine, k=1):
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X_coarse)
        distances, indices = nbrs.kneighbors(X_fine)
        n_fine, n_coarse = X_fine.shape[0], X_coarse.shape[0]
        rows, cols, vals = [], [], []
        for i in range(n_fine):
            weights = 1.0 / (distances[i] + 1e-12)
            weights /= weights.sum()
            for j, idx in enumerate(indices[i]):
                rows.append(i)
                cols.append(idx)
                vals.append(weights[j])
        return coo_matrix((vals, (rows, cols)), shape=(n_fine, n_coarse))

    @staticmethod
    def build_knn_graph(X, k=4):
        n_points = X.shape[0]
        nbrs = NearestNeighbors(n_neighbors=k + 1).fit(X)
        _, neighbors = nbrs.kneighbors(X)
        rows, cols = [], []
        for i in range(n_points):
            for j in neighbors[i][1:]:
                rows.append(i)
                cols.append(j)
        return torch.LongTensor([rows, cols]).to(torch.long)

    def solve_eigenvalue_problem(self, X, n_modes):
        L, M = robust_laplacian.point_cloud_laplacian(X)
        # use eigsh with M as mass
        vals, vecs = eigsh(L, k=n_modes, M=M, which='SM')
        return vals, np.array(vecs), L, M

    # ------------------------
    # Core training routine
    # ------------------------
    def train_gnn(self, model, x_feats, edge_index, U_init, L_fine, M_fine, U_coarse, P,
                  n_modes,
                  epochs=200,
                  lr=1e-3,
                  corr_scale=1e-2,
                  w_res=10.0,
                  w_orth=1.0,
                  w_proj=1e-3,
                  grad_clip=1.0,
                  weight_decay=1e-6,
                  log_every=200):
        """
        Train corrector model:
          - x_feats: torch.FloatTensor [n_fine, in_dim] on device
          - edge_index: torch.LongTensor [2, n_edges] on device
          - U_init: numpy array [n_fine, n_modes] (will be normalized inside)
          - L_fine, M_fine: scipy sparse matrices
          - U_coarse: numpy array [n_coarse, n_modes]
          - P: scipy sparse prolongation (n_fine x n_coarse)
        Returns U_pred (numpy array [n_fine, n_modes]) - denormalized to original U_init scale.
        """
        device = self.device

        # Convert sparse matrices to torch sparse on device
        L_t = sp_to_torch_sparse(L_fine).to(device)
        M_t = sp_to_torch_sparse(M_fine).to(device)
        R_t = sp_to_torch_sparse(P.T).to(device)

        # Normalize columns of U_init and U_coarse (keep original norms for rescaling)
        U_init_normed, uinit_norms = normalize_columns_np(U_init)
        U_coarse_normed, ucoarse_norms = normalize_columns_np(U_coarse)

        U_init_t = torch.FloatTensor(U_init_normed).to(device)   # [n_fine, n_modes]
        U_coarse_t = torch.FloatTensor(U_coarse_normed).to(device)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)

        n_fine = U_init_t.shape[0]
        n_coarse = U_coarse_t.shape[0]
        denom_res = float(max(1, n_fine * n_modes))
        denom_proj = float(max(1, n_coarse * n_modes))
        I = torch.eye(n_modes, device=device)

        model.train()
        for ep in range(epochs):
            optimizer.zero_grad()

            corr_raw = model(x_feats, edge_index)  # [n_fine, n_modes]
            corr = corr_scale * corr_raw
            U_pred = U_init_t + corr

            # Rayleigh-related tensors
            Lu = torch.sparse.mm(L_t, U_pred)
            Mu = torch.sparse.mm(M_t, U_pred)
            num = torch.sum(U_pred * Lu, dim=0)
            den = torch.sum(U_pred * Mu, dim=0) + 1e-12
            lambdas = num / den

            # Residual loss (normalized)
            res = Lu - Mu * lambdas.unsqueeze(0)
            L_res = torch.sum(res**2) / denom_res

            # Orthonormality loss (M-weighted Gram)
            MUt = torch.sparse.mm(M_t, U_pred)
            Gram = U_pred.t() @ MUt
            L_orth = torch.sum((Gram - I)**2) / (n_modes * n_modes)

            # Projection loss
            proj = torch.sparse.mm(R_t, U_pred)
            L_proj = torch.sum((proj - U_coarse_t)**2) / denom_proj

            loss = w_res * L_res + w_orth * L_orth + w_proj * L_proj
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), grad_clip)

            optimizer.step()

            if (ep % log_every == 0) or (ep == epochs - 1):
                with torch.no_grad():
                    u_norm = float(U_pred.norm().cpu().item())
                    corr_std = float(corr.std().cpu().item())
                print(f"    Epoch {ep:4d}: Loss={loss.item():.6f} (Res={L_res.item():.6f}, Orth={L_orth.item():.6f}, Proj={L_proj.item():.6f}) U_norm={u_norm:.4f} corr_std={corr_std:.6f}")

        # Denormalize: multiply columns by original column norms of U_init
        U_pred_np = U_pred.detach().cpu().numpy() * uinit_norms.reshape(1, -1)
        return U_pred_np

    # ------------------------
    # Rayleigh-Ritz refinement
    # ------------------------
    def refine_eigenvectors(self, U_pred, L_fine, M_fine):
        L_t = sp_to_torch_sparse(L_fine).to(self.device)
        M_t = sp_to_torch_sparse(M_fine).to(self.device)
        U = torch.FloatTensor(U_pred).to(self.device)
        A = (U.t() @ torch.sparse.mm(L_t, U)).cpu().numpy()
        B = (U.t() @ torch.sparse.mm(M_t, U)).cpu().numpy()
        vals, C = eigh(A, B)
        U_refined = U.cpu().numpy() @ C
        return vals, U_refined

    # ------------------------
    # Refine one level (coarse -> fine)
    # ------------------------
    def refine_level(self, X_coarse, U_coarse, X_fine, n_modes,
                     hidden_sizes=(128, 64, 32),
                     dropout=0.0,
                     k_neighbors=4,
                     epochs=200,
                     lr=1e-3,
                     corr_scale=1e-2,
                     w_res=10.0,
                     w_orth=1.0,
                     w_proj=1e-3,
                     freeze_layers=0,
                     checkpoint_name=None):
        """
        Single-level refinement.
        - Automatically sets input dim from features.
        - Creates model if not existing; reuses and optionally freezes layers if existing.
        """
        device = self.device
        print(f"  Computing Laplacian for {X_fine.shape[0]} points...")
        L_fine, M_fine = robust_laplacian.point_cloud_laplacian(X_fine)

        print("  Building prolongation operator...")
        P = self.build_prolongation(X_coarse, X_fine, k=1)

        print("  Building kNN graph...")
        edge_index = self.build_knn_graph(X_fine, k=k_neighbors).to(device)

        # Build U_init on fine grid
        U_init = P @ U_coarse  # shape [n_fine, n_modes]

        # Build features: coords + U_init (we pass raw U_init; normalization happens inside train_gnn)
        x_feats = torch.FloatTensor(np.hstack([X_fine, U_init])).to(device)

        in_dim = x_feats.shape[1]
        out_dim = n_modes

        if self.model is None:
            print(f"  Creating new corrector model (in_dim={in_dim}, out_dim={out_dim})...")
            self.model = SimpleCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, dropout=dropout).to(device)
        else:
            # If model exists but input dimension changed, re-create model to match new in_dim
            # (safer than trying to partially load weights with mismatched shapes)
            existing_in_dim = None
            # try to infer existing in_dim by checking first Linear in model.net if present
            for m in self.model.net:
                if isinstance(m, nn.Linear):
                    existing_in_dim = m.in_features
                    break
            if existing_in_dim != in_dim:
                print(f"  Recreating model to match new input dim (was {existing_in_dim}, now {in_dim})...")
                # Optionally copy weights for layers that match by size
                old_state = self.model.state_dict()
                self.model = SimpleCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, dropout=dropout).to(device)
                # attempt to copy subset of weights where shapes match
                new_state = self.model.state_dict()
                for k, v in old_state.items():
                    if k in new_state and old_state[k].shape == new_state[k].shape:
                        new_state[k] = old_state[k]
                self.model.load_state_dict(new_state)

        # Optionally freeze first few linear layers (count of Linear modules)
        if freeze_layers > 0:
            linear_count = 0
            for module in self.model.net:
                if isinstance(module, nn.Linear):
                    linear_count += 1
                    if linear_count <= freeze_layers:
                        for p in module.parameters():
                            p.requires_grad = False
            print(f"  Frozen first {freeze_layers} linear layers.")

        print(f"  Training corrector: epochs={epochs}, lr={lr}, corr_scale={corr_scale}")
        U_pred = self.train_gnn(self.model, x_feats, edge_index, U_init, L_fine, M_fine, U_coarse, P,
                                n_modes,
                                epochs=epochs,
                                lr=lr,
                                corr_scale=corr_scale,
                                w_res=w_res,
                                w_orth=w_orth,
                                w_proj=w_proj)

        print("  Rayleigh-Ritz refinement...")
        lambda_refined, U_refined = self.refine_eigenvectors(U_pred, L_fine, M_fine)

        if checkpoint_name is not None:
            ckpt = {"model_state": self.model.state_dict(), "lambda_refined": lambda_refined}
            torch.save(ckpt, os.path.join(self.checkpoint_dir, checkpoint_name))
            print(f"  Saved checkpoint: {os.path.join(self.checkpoint_dir, checkpoint_name)}")

        return lambda_refined, U_refined, L_fine, M_fine


# ------------------------
# Visualization helper
# ------------------------
def visualize_mesh(mesh, title='Mesh Visualization', highlight_indices=None, show=True):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(projection='3d')
    ax.plot_trisurf(mesh.verts[:, 0], mesh.verts[:, 1], mesh.verts[:, 2],
                    triangles=mesh.connectivity, alpha=0.35)
    if highlight_indices is not None:
        hv = mesh.verts[highlight_indices]
        ax.scatter(hv[:, 0], hv[:, 1], hv[:, 2], s=6, label=f"{len(highlight_indices)} pts")
        ax.legend()
    ax.set_title(title)
    ax.view_init(elev=120, azim=-90)
    if show:
        plt.show()


# ------------------------
# Main script
# ------------------------
def main():
    mesh_path = "bunny.obj"
    n_modes = 10
    hidden_sizes = (128, 128, 128)
    dropout = 0.0

    # schedule and hyperparams
    epochs_schedule = {0: 0, 1: 1500, 2: 1000, 3: 800, 4: 800}
    hierarchy = [128, 512, 1024]  # final level will append full
    k_neighbors = 4
    lr_start = 1e-3
    lr_min = 5e-4
    corr_scale = 1e-2
    w_res = 10.0
    w_orth = 10.0
    w_proj = 1e-3
    freeze_schedule = {1: 0, 2: 1, 3: 1, 4: 2}

    print("Loading mesh...")
    mesh = Mesh(mesh_path)
    mesh = MultigridEigensolver.normalize_mesh(mesh)
    X_full = mesh.verts
    n_total = X_full.shape[0]
    print(f"Mesh loaded: {n_total} vertices")

    hierarchy = [n for n in hierarchy if n <= n_total]
    if hierarchy[-1] != n_total:
        hierarchy.append(n_total)
    print("Hierarchy:", hierarchy)

    rng = np.random.default_rng(seed=42)
    all_idx = np.arange(n_total)
    rng.shuffle(all_idx)
    indices_per_level = {}
    for i, n_points in enumerate(hierarchy):
        indices_per_level[i] = all_idx[:n_points].copy()
        print(f"  Level {i}: {n_points} points (nested)")

    solver = MultigridEigensolver()

    # Level 0 coarse solve
    idx0 = indices_per_level[0]
    X0 = X_full[idx0]
    print("\nLEVEL 0: coarse solving...")
    lambda_cur, U_cur, L_cur, M_cur = solver.solve_eigenvalue_problem(X0, n_modes)
    print("Coarse eigenvalues:", np.round(lambda_cur, 6))

    # iterative refinement
    for level in range(1, len(hierarchy)):
        idx_coarse = indices_per_level[level - 1]
        idx_fine = indices_per_level[level]
        Xc = X_full[idx_coarse]
        Xf = X_full[idx_fine]
        epochs = epochs_schedule.get(level, 1000)

        print(f"\nLEVEL {level}: refine {Xc.shape[0]} -> {Xf.shape[0]}, epochs={epochs}")
        freeze_layers = freeze_schedule.get(level, 0)

        total_levels = len(hierarchy)
        decay = (level - 1) / max(1, total_levels - 1)
        lr = lr_start * ((lr_min / lr_start) ** decay)

        lambda_cur, U_cur, L_cur, M_cur = solver.refine_level(
            Xc, U_cur, Xf, n_modes,
            hidden_sizes=hidden_sizes,
            dropout=dropout,
            k_neighbors=k_neighbors,
            epochs=epochs,
            lr=lr,
            corr_scale=corr_scale,
            w_res=w_res,
            w_orth=w_orth,
            w_proj=w_proj,
            freeze_layers=freeze_layers,
            checkpoint_name=f"level_{level}_ckpt.pt"
        )

        print("GNN-refined eigenvalues:", np.round(lambda_cur, 6))

        # exact eigenvalues for verification
        print("  computing exact eigenvalues for verification...")
        lambda_exact, _, _, _ = solver.solve_eigenvalue_problem(Xf, n_modes)
        rel_err = np.abs(lambda_cur - lambda_exact) / (np.abs(lambda_exact) + 1e-12)
        print("  Exact eigenvalues:", np.round(lambda_exact, 6))
        print("  Relative errors:  ", np.round(rel_err, 6))

    print("\nDone. Final eigenvalues:", np.round(lambda_cur, 6))


if __name__ == "__main__":
    main()


Loading mesh...
Mesh loaded: 2503 vertices
Hierarchy: [128, 512, 1024, 2503]
  Level 0: 128 points (nested)
  Level 1: 512 points (nested)
  Level 2: 1024 points (nested)
  Level 3: 2503 points (nested)

LEVEL 0: coarse solving...
Coarse eigenvalues: [0.       0.464393 0.860068 1.196359 1.413617 1.861207 2.273839 2.759782
 2.87012  3.25778 ]

LEVEL 1: refine 128 -> 512, epochs=1500
  Computing Laplacian for 512 points...
  Building prolongation operator...
  Building kNN graph...
  Creating new corrector model (in_dim=13, out_dim=10)...
  Training corrector: epochs=1500, lr=0.001, corr_scale=0.01
    Epoch    0: Loss=0.896934 (Res=0.001216, Orth=0.088476, Proj=0.016340) U_norm=3.1672 corr_std=0.000563
    Epoch  200: Loss=0.175740 (Res=0.002795, Orth=0.014735, Proj=0.435962) U_norm=11.7478 corr_std=0.122789
    Epoch  400: Loss=0.018089 (Res=0.001746, Orth=0.000009, Proj=0.545564) U_norm=13.0934 corr_std=0.141240
    Epoch  600: Loss=0.010099 (Res=0.000952, Orth=0.000004, Proj=0.546351