In [1]:
# multigrid_gnn_refine_fixed.py
"""
Complete working rewrite of your multigrid + GNN eigen-refinement pipeline.

Assumptions:
 - `Mesh` class exists and Mesh('bunny.obj') loads .verts (n x 3) and .connectivity (triangles).
 - `robust_laplacian.point_cloud_laplacian(X)` returns (L, M) as scipy sparse matrices where L and M are compatible with eigsh.
 - scikit-learn, scipy, numpy, matplotlib, torch are installed.

Key features:
 - All classes and functions defined in one file (no missing names).
 - Stable training: column normalization, small correction scale, normalized losses, grad clipping, configurable weights.
 - Auto-detect input feature dimension to avoid matmul mismatches.
"""

import os
import numpy as np
from scipy.linalg import eigh
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh
from sklearn.neighbors import NearestNeighbors

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

# Project imports (must be available in your environment)
from Mesh import Mesh
import robust_laplacian


# ------------------------
# Utility helpers
# ------------------------
def sp_to_torch_sparse(A):
    """Convert scipy sparse matrix to torch.sparse_coo_tensor (CPU or GPU depending on .to(device))."""
    A = A.tocoo()
    indices = np.vstack((A.row, A.col)).astype(np.int64)
    i = torch.LongTensor(indices)
    v = torch.FloatTensor(A.data)
    return torch.sparse_coo_tensor(i, v, A.shape).coalesce()


def normalize_columns_np(U, eps=1e-12):
    """Normalize numpy matrix columns to have unit L2 norm. Returns normalized U and norms."""
    norms = np.linalg.norm(U, axis=0) + eps
    return U / norms, norms


def normalize_columns_torch(U, eps=1e-12):
    """Normalize torch tensor columns to have unit L2 norm. Returns normalized U and norms (torch)."""
    norms = torch.norm(U, dim=0) + eps
    return U / norms, norms


# ------------------------
# Simple per-node corrector (message-passing via neighbor mean + MLP)
# ------------------------
class SimpleCorrector(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_sizes=(128, 64, 32), dropout=0.0):
        super().__init__()
        layers = []
        prev = in_dim * 2  # because we will concat self + neighbor-mean in forward
        for h in hidden_sizes:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU(inplace=True))
            if dropout > 0.0:
                layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, edge_index):
        """
        x: [n, in_dim]
        edge_index: LongTensor shape [2, n_edges] with (row=target, col=source) semantics
        """
        row, col = edge_index  # both LongTensor
        n = x.shape[0]
        # aggregate neighbor features: mean aggregator
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        deg = torch.bincount(row, minlength=n).unsqueeze(1).to(x.dtype).to(x.device)
        deg = deg.clamp(min=1.0)
        agg = agg / deg
        h = torch.cat([x, agg], dim=1)  # shape [n, 2*in_dim]
        return self.net(h)


# ------------------------
# Multigrid eigensolver with GNN corrector
# ------------------------
class MultigridEigensolver:
    def __init__(self, device=None, checkpoint_dir="./checkpoints"):
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.model = None
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    @staticmethod
    def normalize_mesh(mesh):
        centroid = mesh.verts.mean(0)
        std_max = mesh.verts.std(0).max() + 1e-12
        verts_normalized = (mesh.verts - centroid) / std_max
        return Mesh(verts=verts_normalized, connectivity=mesh.connectivity)

    @staticmethod
    def build_prolongation(X_coarse, X_fine, k=1):
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X_coarse)
        distances, indices = nbrs.kneighbors(X_fine)
        n_fine, n_coarse = X_fine.shape[0], X_coarse.shape[0]
        rows, cols, vals = [], [], []
        for i in range(n_fine):
            weights = 1.0 / (distances[i] + 1e-12)
            weights /= weights.sum()
            for j, idx in enumerate(indices[i]):
                rows.append(i)
                cols.append(idx)
                vals.append(weights[j])
        return coo_matrix((vals, (rows, cols)), shape=(n_fine, n_coarse))

    @staticmethod
    def build_knn_graph(X, k=4):
        n_points = X.shape[0]
        nbrs = NearestNeighbors(n_neighbors=k + 1).fit(X)
        _, neighbors = nbrs.kneighbors(X)
        rows, cols = [], []
        for i in range(n_points):
            for j in neighbors[i][1:]:
                rows.append(i)
                cols.append(j)
        return torch.LongTensor([rows, cols]).to(torch.long)

    def solve_eigenvalue_problem(self, X, n_modes):
        L, M = robust_laplacian.point_cloud_laplacian(X)
        # use eigsh with M as mass
        vals, vecs = eigsh(L, k=n_modes, M=M, which='SM')
        return vals, np.array(vecs), L, M

    # ------------------------
    # Core training routine
    # ------------------------
    def train_gnn(self, model, x_feats, edge_index, U_init, L_fine, M_fine, U_coarse, P,
                  n_modes,
                  epochs=200,
                  lr=1e-3,
                  corr_scale=1e-2,
                  w_res=10.0,
                  w_orth=1.0,
                  w_proj=1e-3,
                  grad_clip=1.0,
                  weight_decay=1e-6,
                  log_every=200):
        """
        Train corrector model:
          - x_feats: torch.FloatTensor [n_fine, in_dim] on device
          - edge_index: torch.LongTensor [2, n_edges] on device
          - U_init: numpy array [n_fine, n_modes] (will be normalized inside)
          - L_fine, M_fine: scipy sparse matrices
          - U_coarse: numpy array [n_coarse, n_modes]
          - P: scipy sparse prolongation (n_fine x n_coarse)
        Returns U_pred (numpy array [n_fine, n_modes]) - denormalized to original U_init scale.
        """
        device = self.device

        # Convert sparse matrices to torch sparse on device
        L_t = sp_to_torch_sparse(L_fine).to(device)
        M_t = sp_to_torch_sparse(M_fine).to(device)
        R_t = sp_to_torch_sparse(P.T).to(device)

        # Normalize columns of U_init and U_coarse (keep original norms for rescaling)
        U_init_normed, uinit_norms = normalize_columns_np(U_init)
        U_coarse_normed, ucoarse_norms = normalize_columns_np(U_coarse)

        U_init_t = torch.FloatTensor(U_init_normed).to(device)   # [n_fine, n_modes]
        U_coarse_t = torch.FloatTensor(U_coarse_normed).to(device)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)

        n_fine = U_init_t.shape[0]
        n_coarse = U_coarse_t.shape[0]
        denom_res = float(max(1, n_fine * n_modes))
        denom_proj = float(max(1, n_coarse * n_modes))
        I = torch.eye(n_modes, device=device)

        model.train()
        for ep in range(epochs):
            optimizer.zero_grad()

            corr_raw = model(x_feats, edge_index)  # [n_fine, n_modes]
            corr = corr_scale * corr_raw
            U_pred = U_init_t + corr

            # Rayleigh-related tensors
            Lu = torch.sparse.mm(L_t, U_pred)
            Mu = torch.sparse.mm(M_t, U_pred)
            num = torch.sum(U_pred * Lu, dim=0)
            den = torch.sum(U_pred * Mu, dim=0) + 1e-12
            lambdas = num / den

            # Residual loss (normalized)
            res = Lu - Mu * lambdas.unsqueeze(0)
            L_res = torch.sum(res**2) / denom_res

            # Orthonormality loss (M-weighted Gram)
            MUt = torch.sparse.mm(M_t, U_pred)
            Gram = U_pred.t() @ MUt
            L_orth = torch.sum((Gram - I)**2) / (n_modes * n_modes)

            # Projection loss
            proj = torch.sparse.mm(R_t, U_pred)
            L_proj = torch.sum((proj - U_coarse_t)**2) / denom_proj

            loss = w_res * L_res + w_orth * L_orth + w_proj * L_proj
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), grad_clip)

            optimizer.step()

            if (ep % log_every == 0) or (ep == epochs - 1):
                with torch.no_grad():
                    u_norm = float(U_pred.norm().cpu().item())
                    corr_std = float(corr.std().cpu().item())
                print(f"    Epoch {ep:4d}: Loss={loss.item():.6f} (Res={L_res.item():.6f}, Orth={L_orth.item():.6f}, Proj={L_proj.item():.6f}) U_norm={u_norm:.4f} corr_std={corr_std:.6f}")

        # Denormalize: multiply columns by original column norms of U_init
        U_pred_np = U_pred.detach().cpu().numpy() * uinit_norms.reshape(1, -1)
        return U_pred_np

    # ------------------------
    # Rayleigh-Ritz refinement
    # ------------------------
    def refine_eigenvectors(self, U_pred, L_fine, M_fine):
        L_t = sp_to_torch_sparse(L_fine).to(self.device)
        M_t = sp_to_torch_sparse(M_fine).to(self.device)
        U = torch.FloatTensor(U_pred).to(self.device)
        A = (U.t() @ torch.sparse.mm(L_t, U)).cpu().numpy()
        B = (U.t() @ torch.sparse.mm(M_t, U)).cpu().numpy()
        vals, C = eigh(A, B)
        U_refined = U.cpu().numpy() @ C
        return vals, U_refined

    # ------------------------
    # Refine one level (coarse -> fine)
    # ------------------------
    def refine_level(self, X_coarse, U_coarse, X_fine, n_modes,
                     hidden_sizes=(128, 64, 32),
                     dropout=0.0,
                     k_neighbors=4,
                     epochs=200,
                     lr=1e-3,
                     corr_scale=1e-2,
                     w_res=10.0,
                     w_orth=1.0,
                     w_proj=1e-3,
                     freeze_layers=0,
                     checkpoint_name=None):
        """
        Single-level refinement.
        - Automatically sets input dim from features.
        - Creates model if not existing; reuses and optionally freezes layers if existing.
        """
        device = self.device
        print(f"  Computing Laplacian for {X_fine.shape[0]} points...")
        L_fine, M_fine = robust_laplacian.point_cloud_laplacian(X_fine)

        print("  Building prolongation operator...")
        P = self.build_prolongation(X_coarse, X_fine, k=1)

        print("  Building kNN graph...")
        edge_index = self.build_knn_graph(X_fine, k=k_neighbors).to(device)

        # Build U_init on fine grid
        U_init = P @ U_coarse  # shape [n_fine, n_modes]

        # Build features: coords + U_init (we pass raw U_init; normalization happens inside train_gnn)
        x_feats = torch.FloatTensor(np.hstack([X_fine, U_init])).to(device)

        in_dim = x_feats.shape[1]
        out_dim = n_modes

        if self.model is None:
            print(f"  Creating new corrector model (in_dim={in_dim}, out_dim={out_dim})...")
            self.model = SimpleCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, dropout=dropout).to(device)
        else:
            # If model exists but input dimension changed, re-create model to match new in_dim
            # (safer than trying to partially load weights with mismatched shapes)
            existing_in_dim = None
            # try to infer existing in_dim by checking first Linear in model.net if present
            for m in self.model.net:
                if isinstance(m, nn.Linear):
                    existing_in_dim = m.in_features
                    break
            if existing_in_dim != in_dim:
                print(f"  Recreating model to match new input dim (was {existing_in_dim}, now {in_dim})...")
                # Optionally copy weights for layers that match by size
                old_state = self.model.state_dict()
                self.model = SimpleCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, dropout=dropout).to(device)
                # attempt to copy subset of weights where shapes match
                new_state = self.model.state_dict()
                for k, v in old_state.items():
                    if k in new_state and old_state[k].shape == new_state[k].shape:
                        new_state[k] = old_state[k]
                self.model.load_state_dict(new_state)

        # Optionally freeze first few linear layers (count of Linear modules)
        if freeze_layers > 0:
            linear_count = 0
            for module in self.model.net:
                if isinstance(module, nn.Linear):
                    linear_count += 1
                    if linear_count <= freeze_layers:
                        for p in module.parameters():
                            p.requires_grad = False
            print(f"  Frozen first {freeze_layers} linear layers.")

        print(f"  Training corrector: epochs={epochs}, lr={lr}, corr_scale={corr_scale}")
        U_pred = self.train_gnn(self.model, x_feats, edge_index, U_init, L_fine, M_fine, U_coarse, P,
                                n_modes,
                                epochs=epochs,
                                lr=lr,
                                corr_scale=corr_scale,
                                w_res=w_res,
                                w_orth=w_orth,
                                w_proj=w_proj)

        print("  Rayleigh-Ritz refinement...")
        lambda_refined, U_refined = self.refine_eigenvectors(U_pred, L_fine, M_fine)

        if checkpoint_name is not None:
            ckpt = {"model_state": self.model.state_dict(), "lambda_refined": lambda_refined}
            torch.save(ckpt, os.path.join(self.checkpoint_dir, checkpoint_name))
            print(f"  Saved checkpoint: {os.path.join(self.checkpoint_dir, checkpoint_name)}")

        return lambda_refined, U_refined, L_fine, M_fine


# ------------------------
# Visualization helper
# ------------------------
def visualize_mesh(mesh, title='Mesh Visualization', highlight_indices=None, show=True):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(projection='3d')
    ax.plot_trisurf(mesh.verts[:, 0], mesh.verts[:, 1], mesh.verts[:, 2],
                    triangles=mesh.connectivity, alpha=0.35)
    if highlight_indices is not None:
        hv = mesh.verts[highlight_indices]
        ax.scatter(hv[:, 0], hv[:, 1], hv[:, 2], s=6, label=f"{len(highlight_indices)} pts")
        ax.legend()
    ax.set_title(title)
    ax.view_init(elev=120, azim=-90)
    if show:
        plt.show()


# ------------------------
# Main script
# ------------------------
def main():
    mesh_path = "bunny.obj"
    n_modes = 10
    hidden_sizes = (128, 128, 128)
    dropout = 0.0

    # schedule and hyperparams
    epochs_schedule = {0: 0, 1: 1500, 2: 1000, 3: 800, 4: 800}
    #hierarchy = [128, 512, 1024]  # final level will append full
    hierarchy = [1024, 2048, 4096]  # final level will append full
    k_neighbors = 4
    lr_start = 1e-3
    lr_min = 5e-4
    corr_scale = 1e-2
    w_res = 10.0
    w_orth = 10.0
    w_proj = 1e-3
    freeze_schedule = {1: 0, 2: 1, 3: 1, 4: 2}

    print("Loading mesh...")
    mesh = Mesh(mesh_path)
    mesh = MultigridEigensolver.normalize_mesh(mesh)
    X_full = mesh.verts
    n_total = X_full.shape[0]
    print(f"Mesh loaded: {n_total} vertices")

    hierarchy = [n for n in hierarchy if n <= n_total]
    if hierarchy[-1] != n_total:
        hierarchy.append(n_total)
    print("Hierarchy:", hierarchy)

    rng = np.random.default_rng(seed=42)
    all_idx = np.arange(n_total)
    rng.shuffle(all_idx)
    indices_per_level = {}
    for i, n_points in enumerate(hierarchy):
        indices_per_level[i] = all_idx[:n_points].copy()
        print(f"  Level {i}: {n_points} points (nested)")

    solver = MultigridEigensolver()

    # Level 0 coarse solve
    idx0 = indices_per_level[0]
    X0 = X_full[idx0]
    print("\nLEVEL 0: coarse solving...")
    lambda_cur, U_cur, L_cur, M_cur = solver.solve_eigenvalue_problem(X0, n_modes)
    print("Coarse eigenvalues:", np.round(lambda_cur, 6))

    # iterative refinement
    for level in range(1, len(hierarchy)):
        idx_coarse = indices_per_level[level - 1]
        idx_fine = indices_per_level[level]
        Xc = X_full[idx_coarse]
        Xf = X_full[idx_fine]
        epochs = epochs_schedule.get(level, 1000)

        print(f"\nLEVEL {level}: refine {Xc.shape[0]} -> {Xf.shape[0]}, epochs={epochs}")
        freeze_layers = freeze_schedule.get(level, 0)

        total_levels = len(hierarchy)
        decay = (level - 1) / max(1, total_levels - 1)
        lr = lr_start * ((lr_min / lr_start) ** decay)

        lambda_cur, U_cur, L_cur, M_cur = solver.refine_level(
            Xc, U_cur, Xf, n_modes,
            hidden_sizes=hidden_sizes,
            dropout=dropout,
            k_neighbors=k_neighbors,
            epochs=epochs,
            lr=lr,
            corr_scale=corr_scale,
            w_res=w_res,
            w_orth=w_orth,
            w_proj=w_proj,
            freeze_layers=freeze_layers,
            checkpoint_name=f"level_{level}_ckpt.pt"
        )

        print("GNN-refined eigenvalues:", np.round(lambda_cur, 3))

        # exact eigenvalues for verification
        #print("  computing exact eigenvalues for verification...")
        #lambda_exact, _, _, _ = solver.solve_eigenvalue_problem(Xf, n_modes)
        #rel_err = np.abs(lambda_cur - lambda_exact) / (np.abs(lambda_exact) + 1e-12)
        #print("  Exact eigenvalues:", np.round(lambda_exact, 6))
        #print("  Relative errors:  ", np.round(rel_err, 6))

    print("\nDone. Final eigenvalues:", np.round(lambda_cur, 3))

    return U_cur


In [2]:
U_pred = main()

Loading mesh...
Mesh loaded: 2503 vertices
Hierarchy: [1024, 2048, 2503]
  Level 0: 1024 points (nested)
  Level 1: 2048 points (nested)
  Level 2: 2503 points (nested)

LEVEL 0: coarse solving...
Coarse eigenvalues: [0.       0.333353 0.765464 0.8359   1.064607 1.2346   1.738815 2.626301
 2.899525 3.119024]

LEVEL 1: refine 1024 -> 2048, epochs=1500
  Computing Laplacian for 2048 points...
  Building prolongation operator...
  Building kNN graph...
  Creating new corrector model (in_dim=13, out_dim=10)...
  Training corrector: epochs=1500, lr=0.001, corr_scale=0.01
    Epoch    0: Loss=0.970057 (Res=0.000054, Orth=0.096952, Proj=0.000696) U_norm=3.1571 corr_std=0.000546
    Epoch  200: Loss=0.404786 (Res=0.000198, Orth=0.040273, Proj=0.077514) U_norm=18.9547 corr_std=0.112880
    Epoch  400: Loss=0.004728 (Res=0.000458, Orth=0.000001, Proj=0.141191) U_norm=25.3831 corr_std=0.156705
    Epoch  600: Loss=0.004114 (Res=0.000397, Orth=0.000001, Proj=0.141196) U_norm=25.3861 corr_std=0.156

In [4]:
np.round(U_pred.T @ U_pred, 3)

array([[ 7.5931e+01, -1.1700e+00, -2.3900e-01,  4.1250e+00, -1.6180e+00,
         2.7710e+00, -1.7580e+00, -5.0000e-02,  4.3860e+00,  3.0300e+00],
       [-1.1700e+00,  7.4879e+01, -2.5010e+00,  1.5660e+00, -2.0390e+00,
        -1.5900e-01, -2.1210e+00,  2.8400e-01,  7.9900e-01,  1.0640e+00],
       [-2.3900e-01, -2.5010e+00,  7.5439e+01,  4.6060e+00,  4.3600e-01,
         3.0010e+00, -2.2270e+00, -1.6500e+00,  1.9160e+00,  6.4100e-01],
       [ 4.1250e+00,  1.5660e+00,  4.6060e+00,  8.0014e+01, -1.3900e+00,
         2.5500e-01,  1.1350e+00,  3.4430e+00,  3.7500e-01, -3.9500e-01],
       [-1.6180e+00, -2.0390e+00,  4.3600e-01, -1.3900e+00,  7.6387e+01,
        -5.9100e-01,  2.8290e+00, -2.2240e+00,  2.0910e+00,  2.1900e+00],
       [ 2.7710e+00, -1.5900e-01,  3.0010e+00,  2.5500e-01, -5.9100e-01,
         8.0063e+01,  3.8430e+00, -1.4230e+00,  1.0540e+00, -1.7680e+00],
       [-1.7580e+00, -2.1210e+00, -2.2270e+00,  1.1350e+00,  2.8290e+00,
         3.8430e+00,  7.2624e+01, -1.8720e+00

In [5]:
import meshio

m = Mesh('bunny.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)

cells = [('triangle', m.connectivity)]
m_out = meshio.Mesh(m.verts, cells, point_data={f'v{i}': U_pred[:, i] for i in range(1, 10)})

m_out.write('bunny_eigfuncs_pred_seq.vtu')

In [23]:
"""
Complete fixed version of multigrid + GNN eigen-refinement pipeline.

Key improvements:
 - Fixed dimension consistency (U_coarse subsetting)
 - Adaptive per-mode correction scaling
 - Best model checkpointing during training
 - Enhanced monitoring and validation
 - Eigenvalue quality metrics
 - SMOOTHNESS REGULARIZATION to prevent high-frequency noise
"""

import os
import numpy as np
from scipy.linalg import eigh
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh, spsolve
from sklearn.neighbors import NearestNeighbors

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

# Project imports (must be available in your environment)
from Mesh import Mesh
import robust_laplacian


# ------------------------
# Utility helpers
# ------------------------
def sp_to_torch_sparse(A):
    """Convert scipy sparse matrix to torch.sparse_coo_tensor."""
    A = A.tocoo()
    indices = np.vstack((A.row, A.col)).astype(np.int64)
    i = torch.LongTensor(indices)
    v = torch.FloatTensor(A.data)
    return torch.sparse_coo_tensor(i, v, A.shape).coalesce()


def normalize_columns_np(U, eps=1e-12):
    """Normalize numpy matrix columns to have unit L2 norm. Returns normalized U and norms."""
    norms = np.linalg.norm(U, axis=0) + eps
    return U / norms, norms


def normalize_columns_torch(U, eps=1e-12):
    """Normalize torch tensor columns to have unit L2 norm. Returns normalized U and norms (torch)."""
    norms = torch.norm(U, dim=0) + eps
    return U / norms, norms


def validate_eigenvalues(U, L, M, lambda_vals):
    """
    Validate that U, lambda satisfy L*U = lambda*M*U
    Returns max relative residual norm
    """
    from scipy.sparse import issparse
    n_modes = U.shape[1]
    residuals = []
    
    for i in range(n_modes):
        u = U[:, i].reshape(-1, 1)
        lam = lambda_vals[i]
        
        Lu = L @ u if issparse(L) else L.dot(u)
        Mu = M @ u if issparse(M) else M.dot(u)
        residual = Lu - lam * Mu
        
        res_norm = np.linalg.norm(residual)
        u_norm = np.linalg.norm(Mu)
        
        rel_residual = res_norm / (abs(lam) * u_norm + 1e-12)
        residuals.append(rel_residual)
    
    return max(residuals), np.array(residuals)


def smooth_eigenfunctions(U, L, M, n_iters=3, tau=0.01):
    """
    Smooth eigenfunctions using implicit smoothing: (M + tau*L) * U_new = M * U_old
    This reduces high-frequency oscillations while preserving eigenspace.
    
    AGGRESSIVE VERSION: Uses larger tau and more iterations.
    """
    U_smooth = U.copy()
    
    # Use progressively larger tau for more smoothing
    for it in range(n_iters):
        # Increase tau over iterations for stronger smoothing
        current_tau = tau * (1.0 + 0.5 * it / max(1, n_iters))
        A = M + current_tau * L
        
        for i in range(U.shape[1]):
            rhs = M @ U_smooth[:, i]
            U_smooth[:, i] = spsolve(A, rhs)
    
    return U_smooth


def m_orthonormalize(U, M):
    """Explicit M-orthonormalization using Cholesky."""
    from scipy.linalg import cholesky, solve_triangular
    
    M_dense = M.toarray() if hasattr(M, 'toarray') else M
    
    # Compute M-weighted Gram matrix
    MU = M_dense @ U
    G = U.T @ MU
    
    # Cholesky factorization: G = L @ L.T
    try:
        L_chol = cholesky(G, lower=True)
        # M-orthonormalize: U_orth = U @ inv(L)
        U_orth = solve_triangular(L_chol.T, U.T, lower=False).T
        return U_orth
    except np.linalg.LinAlgError:
        print("    Warning: Cholesky failed, Gram matrix not positive definite. Skipping orthonormalization.")
        return U


# ------------------------
# Adaptive corrector with per-mode scaling
# ------------------------
class AdaptiveCorrector(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_sizes=(128, 64, 32), dropout=0.0, init_scale=0.01):
        super().__init__()
        layers = []
        prev = in_dim * 2  # because we concat self + neighbor-mean
        for h in hidden_sizes:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU(inplace=True))
            if dropout > 0.0:
                layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)
        
        # Learnable per-mode scaling factors
        self.mode_scales = nn.Parameter(torch.ones(out_dim) * init_scale)

    def forward(self, x, edge_index):
        """
        x: [n, in_dim]
        edge_index: LongTensor shape [2, n_edges]
        """
        row, col = edge_index
        n = x.shape[0]
        # aggregate neighbor features: mean aggregator
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        deg = torch.bincount(row, minlength=n).unsqueeze(1).to(x.dtype).to(x.device)
        deg = deg.clamp(min=1.0)
        agg = agg / deg
        h = torch.cat([x, agg], dim=1)
        correction = self.net(h)
        
        # Apply per-mode adaptive scaling
        return correction * self.mode_scales.unsqueeze(0)


# ------------------------
# Multigrid eigensolver with GNN corrector
# ------------------------
class MultigridEigensolver:
    def __init__(self, device=None, checkpoint_dir="./checkpoints"):
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.model = None
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    @staticmethod
    def normalize_mesh(mesh):
        centroid = mesh.verts.mean(0)
        std_max = mesh.verts.std(0).max() + 1e-12
        verts_normalized = (mesh.verts - centroid) / std_max
        return Mesh(verts=verts_normalized, connectivity=mesh.connectivity)

    @staticmethod
    def build_prolongation(X_coarse, X_fine, k=1):
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X_coarse)
        distances, indices = nbrs.kneighbors(X_fine)
        n_fine, n_coarse = X_fine.shape[0], X_coarse.shape[0]
        rows, cols, vals = [], [], []
        for i in range(n_fine):
            weights = 1.0 / (distances[i] + 1e-12)
            weights /= weights.sum()
            for j, idx in enumerate(indices[i]):
                rows.append(i)
                cols.append(idx)
                vals.append(weights[j])
        return coo_matrix((vals, (rows, cols)), shape=(n_fine, n_coarse))

    @staticmethod
    def build_knn_graph(X, k=4):
        n_points = X.shape[0]
        nbrs = NearestNeighbors(n_neighbors=k + 1).fit(X)
        _, neighbors = nbrs.kneighbors(X)
        rows, cols = [], []
        for i in range(n_points):
            for j in neighbors[i][1:]:
                rows.append(i)
                cols.append(j)
        return torch.LongTensor([rows, cols]).to(torch.long)

    def solve_eigenvalue_problem(self, X, n_modes):
        L, M = robust_laplacian.point_cloud_laplacian(X)
        vals, vecs = eigsh(L, k=n_modes, M=M, which='SM')
        return vals, np.array(vecs), L, M

    # ------------------------
    # Enhanced training routine WITH SMOOTHNESS REGULARIZATION
    # ------------------------
    def train_gnn(self, model, x_feats, edge_index, U_init, L_fine, M_fine, U_coarse, P,
                  n_modes,
                  epochs=200,
                  lr=1e-3,
                  w_res=10.0,
                  w_orth=1.0,
                  w_proj=1e-3,
                  w_smooth=1.0,  # NEW: smoothness regularization weight
                  grad_clip=1.0,
                  weight_decay=1e-6,
                  log_every=200):
        """
        Train corrector model with smoothness regularization to prevent high-frequency noise.
        
        NEW: w_smooth controls Laplacian smoothness penalty on corrections.
        Higher w_smooth = smoother corrections = less speckled eigenfunctions.
        """
        device = self.device

        # Convert sparse matrices to torch sparse on device
        L_t = sp_to_torch_sparse(L_fine).to(device)
        M_t = sp_to_torch_sparse(M_fine).to(device)
        R_t = sp_to_torch_sparse(P.T).to(device)

        # Normalize columns of U_init and U_coarse
        U_init_normed, uinit_norms = normalize_columns_np(U_init)
        U_coarse_normed, ucoarse_norms = normalize_columns_np(U_coarse)

        U_init_t = torch.FloatTensor(U_init_normed).to(device)
        U_coarse_t = torch.FloatTensor(U_coarse_normed).to(device)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 
                              lr=lr, weight_decay=weight_decay)

        n_fine = U_init_t.shape[0]
        n_coarse = U_coarse_t.shape[0]
        denom_res = float(max(1, n_fine * n_modes))
        denom_proj = float(max(1, n_coarse * n_modes))
        I = torch.eye(n_modes, device=device)

        # Track best model
        best_loss = float('inf')
        best_state = None
        best_epoch = 0

        model.train()
        for ep in range(epochs):
            optimizer.zero_grad()

            corr = model(x_feats, edge_index)  # Already scaled by learnable mode_scales
            U_pred = U_init_t + corr

            # Rayleigh-related tensors
            Lu = torch.sparse.mm(L_t, U_pred)
            Mu = torch.sparse.mm(M_t, U_pred)
            num = torch.sum(U_pred * Lu, dim=0)
            den = torch.sum(U_pred * Mu, dim=0) + 1e-12
            lambdas = num / den

            # Residual loss (normalized)
            res = Lu - Mu * lambdas.unsqueeze(0)
            L_res = torch.sum(res**2) / denom_res

            # Orthonormality loss (M-weighted Gram)
            MUt = torch.sparse.mm(M_t, U_pred)
            Gram = U_pred.t() @ MUt
            L_orth = torch.sum((Gram - I)**2) / (n_modes * n_modes)

            # Projection loss
            proj = torch.sparse.mm(R_t, U_pred)
            L_proj = torch.sum((proj - U_coarse_t)**2) / denom_proj

            # NEW: Smoothness loss - penalize high Laplacian energy of CORRECTION
            # This prevents high-frequency oscillations/speckles
            # We penalize corr.T @ L @ corr, which measures "roughness"
            L_corr = torch.sparse.mm(L_t, corr)
            L_smooth_corr = torch.sum(corr * L_corr) / denom_res
            
            # ADDITIONAL: Also penalize total Laplacian energy of U_pred
            # This enforces that final eigenfunctions are smooth
            L_smooth_total = torch.sum(U_pred * Lu) / denom_res

            # Combined loss with BOTH smoothness terms
            loss = w_res * L_res + w_orth * L_orth + w_proj * L_proj + w_smooth * (L_smooth_corr + L_smooth_total)
            
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), grad_clip)

            optimizer.step()

            # Track best model
            current_loss = loss.item()
            if current_loss < best_loss:
                best_loss = current_loss
                best_epoch = ep
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

            if (ep % log_every == 0) or (ep == epochs - 1):
                with torch.no_grad():
                    u_norm = float(U_pred.norm().cpu().item())
                    corr_norm = float(corr.norm().cpu().item())
                    corr_std = float(corr.std().cpu().item())
                    corr_mean = float(corr.mean().cpu().item())
                    lambdas_np = lambdas.cpu().numpy()
                    
                    # Get mode scales
                    scales = model.mode_scales.detach().cpu().numpy()
                    
                    print(f"    Epoch {ep:4d}: Loss={current_loss:.6f} "
                          f"(Res={L_res.item():.6f}, Orth={L_orth.item():.6f}, "
                          f"Proj={L_proj.item():.6f}, Smooth={L_smooth_corr.item():.4f}+{L_smooth_total.item():.4f})")
                    print(f"              U_norm={u_norm:.4f} corr_norm={corr_norm:.6f} "
                          f"corr_std={corr_std:.6f} corr_mean={corr_mean:.6f}")
                    print(f"              Lambdas: {np.round(lambdas_np[:5], 4)}")
                    print(f"              Scales:  {np.round(scales[:5], 6)}")

        # Restore best model
        if best_state is not None:
            model.load_state_dict(best_state)
            print(f"    --> Restored best model from epoch {best_epoch} (loss={best_loss:.6f})")

        # Final prediction with best model
        model.eval()
        with torch.no_grad():
            corr = model(x_feats, edge_index)
            U_pred = U_init_t + corr

        # Denormalize: multiply columns by original column norms of U_init
        U_pred_np = U_pred.detach().cpu().numpy() * uinit_norms.reshape(1, -1)
        return U_pred_np

    # ------------------------
    # Rayleigh-Ritz refinement with optional smoothing
    # ------------------------
    def refine_eigenvectors(self, U_pred, L_fine, M_fine, apply_smoothing=True, smooth_iters=3, tau=0.1):
        L_t = sp_to_torch_sparse(L_fine).to(self.device)
        M_t = sp_to_torch_sparse(M_fine).to(self.device)
        U = torch.FloatTensor(U_pred).to(self.device)
        A = (U.t() @ torch.sparse.mm(L_t, U)).cpu().numpy()
        B = (U.t() @ torch.sparse.mm(M_t, U)).cpu().numpy()
        vals, C = eigh(A, B)
        U_refined = U.cpu().numpy() @ C
        
        # Optional: Apply smoothing to remove any remaining high-frequency noise
        if apply_smoothing:
            print(f"    Applying {smooth_iters} iterations of implicit smoothing (tau={tau})...")
            U_refined = smooth_eigenfunctions(U_refined, L_fine, M_fine, n_iters=smooth_iters, tau=tau)
        
        # Explicit M-orthonormalization to ensure identity Gram matrix
        U_refined = m_orthonormalize(U_refined, M_fine)
        
        return vals, U_refined

    # ------------------------
    # Refine one level (coarse -> fine) - FIXED VERSION
    # ------------------------
    def refine_level(self, X_coarse, U_coarse, X_fine, n_modes,
                     hidden_sizes=(128, 64, 32),
                     dropout=0.0,
                     k_neighbors=4,
                     epochs=200,
                     lr=1e-3,
                     init_scale=0.01,
                     w_res=10.0,
                     w_orth=1.0,
                     w_proj=1e-3,
                     w_smooth=1.0,  # NEW: smoothness weight
                     freeze_layers=0,
                     checkpoint_name=None,
                     validate=False,
                     apply_smoothing=True,
                     smooth_iters=3,
                     tau=0.1):  # NEW: smoothing strength parameter
        """
        Single-level refinement with dimension fixes and smoothness regularization.
        """
        device = self.device
        print(f"  Computing Laplacian for {X_fine.shape[0]} points...")
        L_fine, M_fine = robust_laplacian.point_cloud_laplacian(X_fine)

        print("  Building prolongation operator...")
        P = self.build_prolongation(X_coarse, X_fine, k=1)

        print("  Building kNN graph...")
        edge_index = self.build_knn_graph(X_fine, k=k_neighbors).to(device)

        # CRITICAL FIX: Ensure U_coarse only uses first n_modes columns
        if U_coarse.shape[1] > n_modes:
            print(f"  WARNING: U_coarse has {U_coarse.shape[1]} modes, taking first {n_modes}")
            U_coarse_subset = U_coarse[:, :n_modes]
        else:
            U_coarse_subset = U_coarse
        
        # Build U_init on fine grid
        U_init = P @ U_coarse_subset  # shape [n_fine, n_modes]
        
        # Dimension validation
        assert U_init.shape == (X_fine.shape[0], n_modes), \
            f"U_init has wrong shape: {U_init.shape}, expected ({X_fine.shape[0]}, {n_modes})"

        # Build features: coords + U_init
        x_feats = torch.FloatTensor(np.hstack([X_fine, U_init])).to(device)

        in_dim = x_feats.shape[1]
        out_dim = n_modes
        
        expected_in_dim = X_fine.shape[1] + n_modes
        assert in_dim == expected_in_dim, f"Feature dim mismatch: {in_dim} != {expected_in_dim}"

        print(f"  Feature dimensions: coords={X_fine.shape[1]}, modes={n_modes}, total_in={in_dim}")

        if self.model is None:
            print(f"  Creating new corrector model (in_dim={in_dim}, out_dim={out_dim})...")
            self.model = AdaptiveCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, 
                                          dropout=dropout, init_scale=init_scale).to(device)
        else:
            # Check if input dimension matches
            existing_in_dim = None
            for m in self.model.net:
                if isinstance(m, nn.Linear):
                    existing_in_dim = m.in_features
                    break
            
            if existing_in_dim != in_dim:
                print(f"  WARNING: Input dim changed {existing_in_dim} -> {in_dim}. Recreating model...")
                old_state = self.model.state_dict()
                self.model = AdaptiveCorrector(in_dim, out_dim, hidden_sizes=hidden_sizes, 
                                              dropout=dropout, init_scale=init_scale).to(device)
                # Try to copy matching weights
                new_state = self.model.state_dict()
                for k, v in old_state.items():
                    if k in new_state and old_state[k].shape == new_state[k].shape:
                        new_state[k] = old_state[k]
                self.model.load_state_dict(new_state)
            else:
                print(f"  Reusing existing model (in_dim={in_dim})")

        # Optionally freeze layers
        if freeze_layers > 0:
            linear_count = 0
            for module in self.model.net:
                if isinstance(module, nn.Linear):
                    linear_count += 1
                    if linear_count <= freeze_layers:
                        for p in module.parameters():
                            p.requires_grad = False
            print(f"  Frozen first {freeze_layers} linear layers.")

        print(f"  Training corrector: epochs={epochs}, lr={lr}, init_scale={init_scale}, w_smooth={w_smooth}")
        U_pred = self.train_gnn(self.model, x_feats, edge_index, U_init, L_fine, M_fine, 
                                U_coarse_subset, P, n_modes,
                                epochs=epochs, lr=lr, w_res=w_res, w_orth=w_orth, 
                                w_proj=w_proj, w_smooth=w_smooth)

        print("  Rayleigh-Ritz refinement...")
        lambda_refined, U_refined = self.refine_eigenvectors(U_pred, L_fine, M_fine, 
                                                             apply_smoothing=apply_smoothing,
                                                             smooth_iters=smooth_iters,
                                                             tau=tau)

        # Validation
        if validate:
            print("  Validating eigenvalue quality...")
            max_res, residuals = validate_eigenvalues(U_refined, L_fine, M_fine, lambda_refined)
            print(f"    Max relative residual: {max_res:.2e}")
            print(f"    Residuals per mode: {np.round(residuals[:5], 8)}")

        if checkpoint_name is not None:
            ckpt = {"model_state": self.model.state_dict(), 
                   "lambda_refined": lambda_refined,
                   "max_residual": max_res if validate else None}
            torch.save(ckpt, os.path.join(self.checkpoint_dir, checkpoint_name))
            print(f"  Saved checkpoint: {os.path.join(self.checkpoint_dir, checkpoint_name)}")

        return lambda_refined, U_refined, L_fine, M_fine


# ------------------------
# Visualization helper
# ------------------------
def visualize_mesh(mesh, title='Mesh Visualization', highlight_indices=None, show=True):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(projection='3d')
    ax.plot_trisurf(mesh.verts[:, 0], mesh.verts[:, 1], mesh.verts[:, 2],
                    triangles=mesh.connectivity, alpha=0.35)
    if highlight_indices is not None:
        hv = mesh.verts[highlight_indices]
        ax.scatter(hv[:, 0], hv[:, 1], hv[:, 2], s=6, label=f"{len(highlight_indices)} pts")
        ax.legend()
    ax.set_title(title)
    ax.view_init(elev=120, azim=-90)
    if show:
        plt.show()


# ------------------------
# Main script
# ------------------------
def main():
    mesh_path = "bunny.obj"
    n_modes = 11
    hidden_sizes = (128, 128, 128)
    dropout = 0.0

    # Enhanced schedule with AGGRESSIVE smoothness regularization
    epochs_schedule = {0: 0, 1: 1500, 2: 1000, 3: 800, 4: 800}
    hierarchy = [1024, 2048, 4096]
    k_neighbors = 4
    lr_start = 5e-4    # Reduced learning rate
    lr_min = 1e-4
    init_scale = 0.0001  # VERY SMALL corrections
    
    # Loss weights - AGGRESSIVE SMOOTHNESS
    w_res = 1000.0     # Very high residual weight
    w_orth = 10.0
    w_proj = 0.1       # Reduced projection weight
    w_smooth = 1000.0  # VERY HIGH smoothness weight
    
    freeze_schedule = {1: 0, 2: 1, 3: 1, 4: 2}
    
    # Post-processing smoothing - MORE AGGRESSIVE
    apply_smoothing = True
    smooth_iters = 20  # Increased from 5 to 20
    tau = 0.1          # Larger tau = more aggressive smoothing

    print("Loading mesh...")
    mesh = Mesh(mesh_path)
    mesh = MultigridEigensolver.normalize_mesh(mesh)
    X_full = mesh.verts
    n_total = X_full.shape[0]
    print(f"Mesh loaded: {n_total} vertices")

    hierarchy = [n for n in hierarchy if n <= n_total]
    if hierarchy[-1] != n_total:
        hierarchy.append(n_total)
    print("Hierarchy:", hierarchy)

    rng = np.random.default_rng(seed=42)
    all_idx = np.arange(n_total)
    rng.shuffle(all_idx)
    indices_per_level = {}
    for i, n_points in enumerate(hierarchy):
        indices_per_level[i] = all_idx[:n_points].copy()
        print(f"  Level {i}: {n_points} points (nested)")

    solver = MultigridEigensolver()

    # Level 0 coarse solve
    idx0 = indices_per_level[0]
    X0 = X_full[idx0]
    print("\nLEVEL 0: coarse solving...")
    lambda_cur, U_cur, L_cur, M_cur = solver.solve_eigenvalue_problem(X0, n_modes)
    print("Coarse eigenvalues:", np.round(lambda_cur, 6))
    
    # Validate initial
    max_res, _ = validate_eigenvalues(U_cur, L_cur, M_cur, lambda_cur)
    print(f"Initial max residual: {max_res:.2e}")

    # Iterative refinement
    for level in range(1, len(hierarchy)):
        idx_coarse = indices_per_level[level - 1]
        idx_fine = indices_per_level[level]
        Xc = X_full[idx_coarse]
        Xf = X_full[idx_fine]
        epochs = epochs_schedule.get(level, 1000)

        print(f"\n{'='*60}")
        print(f"LEVEL {level}: refine {Xc.shape[0]} -> {Xf.shape[0]}, epochs={epochs}")
        print(f"{'='*60}")
        
        freeze_layers = freeze_schedule.get(level, 0)

        total_levels = len(hierarchy)
        decay = (level - 1) / max(1, total_levels - 1)
        lr = lr_start * ((lr_min / lr_start) ** decay)

        lambda_cur, U_cur, L_cur, M_cur = solver.refine_level(
            Xc, U_cur, Xf, n_modes,
            hidden_sizes=hidden_sizes,
            dropout=dropout,
            k_neighbors=k_neighbors,
            epochs=epochs,
            lr=lr,
            init_scale=init_scale,
            w_res=w_res,
            w_orth=w_orth,
            w_proj=w_proj,
            w_smooth=w_smooth,  # NEW
            freeze_layers=freeze_layers,
            checkpoint_name=f"level_{level}_ckpt.pt",
            validate=True,
            apply_smoothing=apply_smoothing,
            smooth_iters=smooth_iters,
            tau=tau  # NEW
        )

        print(f"\nGNN-refined eigenvalues: {np.round(lambda_cur, 4)}")

    print(f"\n{'='*60}")
    print("FINAL RESULTS")
    print(f"{'='*60}")
    print(f"Final eigenvalues: {np.round(lambda_cur, 4)}")
    
    return U_cur, L_cur, M_cur

In [24]:
U_final, L_final, M_final = main()

# Check M-orthonormality
M_dense = M_final.toarray()
gram = U_final.T @ M_dense @ U_final
print("\nM-orthonormal Gram matrix:")
np.set_printoptions(precision=2, suppress=True, linewidth=150)
print(gram)

Loading mesh...
Mesh loaded: 2503 vertices
Hierarchy: [1024, 2048, 2503]
  Level 0: 1024 points (nested)
  Level 1: 2048 points (nested)
  Level 2: 2503 points (nested)

LEVEL 0: coarse solving...
Coarse eigenvalues: [0.   0.33 0.77 0.84 1.06 1.23 1.74 2.63 2.9  3.12 3.27]
Initial max residual: 1.59e-01

LEVEL 1: refine 1024 -> 2048, epochs=1500
  Computing Laplacian for 2048 points...
  Building prolongation operator...
  Building kNN graph...
  Feature dimensions: coords=3, modes=11, total_in=14
  Creating new corrector model (in_dim=14, out_dim=11)...
  Training corrector: epochs=1500, lr=0.0005, init_scale=0.0001, w_smooth=1000.0
    Epoch    0: Loss=0.962880 (Res=0.000058, Orth=0.088140, Proj=0.000700, Smooth=0.0000+0.0000)
              U_norm=3.3167 corr_norm=0.001102 corr_std=0.000007 corr_mean=0.000000
              Lambdas: [-0.    0.59  2.08  1.71  2.19]
              Scales:  [ 0.  0. -0.  0. -0.]
    Epoch  200: Loss=0.731226 (Res=0.000017, Orth=0.057686, Proj=0.037780, Sm

In [25]:
import meshio

m = Mesh('bunny.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

verts_new = (m.verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)

cells = [('triangle', m.connectivity)]
m_out = meshio.Mesh(m.verts, cells,point_data={f'v{i}': U_final[:, i] for i in range(1, 11)})

m_out.write('bunny_eigfuncs_seq.vtu')