In [None]:
import numpy as np
from scipy.linalg import eigh
import scipy.sparse.linalg as sla
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import robust_laplacian
from Mesh import Mesh
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.sparse import csr_matrix, dia_matrix, csc_matrix
from scipy import sparse
from tqdm import trange

In [None]:
# ============ DEVICE SETUP ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ============ NETWORK ARCHITECTURE ============
class Sin(nn.Module):
    """Sine activation function"""
    def forward(self, x):
        return torch.sin(x)


class EigenfunctionNN(nn.Module):
    """
    Neural network to learn eigenfunctions on point clouds.
    Input: 3D coordinates (x, y, z)
    Output: eigenfunction value u(x,y,z) and eigenvalue λ
    """
    def __init__(self, hidden_dim=64, input_dim=3, initial_eigenvalue=0.0):
        super().__init__()
        self.activation = Sin()
        
        # Learnable eigenvalue with better initialization
        self.eigenvalue_layer = nn.Linear(1, 1, bias=False)
        with torch.no_grad():
            self.eigenvalue_layer.weight.fill_(initial_eigenvalue)
        
        # Network layers - concatenate eigenvalue at each layer
        self.fc1 = nn.Linear(input_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc2 = nn.Linear(hidden_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc3 = nn.Linear(hidden_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc4 = nn.Linear(hidden_dim + 1, 1)  # +1 for eigenvalue
        
    def forward(self, x):
        """
        Args:
            x: (N, 3) point cloud coordinates
        Returns:
            u: (N, 1) eigenfunction values
            eigenvalue: scalar learnable eigenvalue
        """
        # Learn eigenvalue and broadcast to match batch size
        eigenvalue = torch.abs(self.eigenvalue_layer(torch.ones(1, 1).to(x.device)))
        eigenvalue_expanded = eigenvalue.expand(x.shape[0], 1)  # (N, 1)
        
        # Forward pass - concatenate eigenvalue at each layer
        h = torch.cat([x, eigenvalue_expanded], dim=1)  # (N, input_dim+1)
        h = self.activation(self.fc1(h))
        
        h = torch.cat([h, eigenvalue_expanded], dim=1)  # (N, hidden_dim+1)
        h = self.activation(self.fc2(h))
        
        h = torch.cat([h, eigenvalue_expanded], dim=1)  # (N, hidden_dim+1)
        h = self.activation(self.fc3(h))
        
        h = torch.cat([h, eigenvalue_expanded], dim=1)  # (N, hidden_dim+1)
        u = self.fc4(h)
        
        return u, eigenvalue


# ============ LOSS COMPUTATION ============
def compute_eigenvalue_loss(u, eigenvalue, L, M, X, device):
    """
    Compute residual for Lu = λMu using discrete operators.
    
    Args:
        u: (N, 1) predicted eigenfunction values
        eigenvalue: scalar predicted eigenvalue
        L: (N, N) Laplacian matrix (scipy sparse: csr, csc, dia, etc.)
        M: (N, N) Mass matrix (scipy sparse: csr, csc, dia, etc.)
        X: (N, 3) point cloud coordinates
        device: torch device
    
    Returns:
        loss: MSE of residual ||Lu - λMu||²
    """
    u_flat = u.squeeze()  # (N,)
    
    # Convert sparse matrices to torch sparse tensors if needed
    if sparse.issparse(L):
        L_torch = sparse_to_torch(L, device)
        M_torch = sparse_to_torch(M, device)
    else:
        L_torch = L
        M_torch = M
    
    # Compute Lu and λMu
    Lu = torch.sparse.mm(L_torch, u_flat.unsqueeze(1)).squeeze()
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    lMu = eigenvalue * Mu
    
    # Residual loss
    residual = Lu - lMu
    loss = torch.mean(residual ** 2)
    
    return loss, Lu, Mu


def compute_normalization_loss(u, M, device):
    """
    Enforce u^T M u = 1 (mass-matrix normalization).
    """
    u_flat = u.squeeze()
    
    if sparse.issparse(M):
        M_torch = sparse_to_torch(M, device)
    else:
        M_torch = M
    
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    norm_squared = torch.dot(u_flat, Mu)
    
    # Penalize deviation from unit norm
    loss = (norm_squared - 1.0) ** 2
    
    return loss


def compute_orthogonality_loss(u, previous_eigenfunctions, M, device):
    """
    Enforce u ⊥ u_i for all previously found eigenfunctions.
    Uses M-orthogonality: u^T M u_i = 0
    """
    if len(previous_eigenfunctions) == 0:
        return torch.tensor(0.0, device=device)
    
    u_flat = u.squeeze()
    ortho_loss = torch.tensor(0.0, device=device)
    
    if sparse.issparse(M):
        M_torch = sparse_to_torch(M, device)
    else:
        M_torch = M
    
    for u_prev in previous_eigenfunctions:
        u_prev_flat = u_prev.squeeze()
        # Compute u^T M u_prev
        Mu_prev = torch.sparse.mm(M_torch, u_prev_flat.unsqueeze(1)).squeeze()
        overlap = torch.dot(u_flat, Mu_prev)
        ortho_loss += overlap ** 2
    
    return ortho_loss


# ============ UTILITY FUNCTIONS ============
def sparse_to_torch(sparse_matrix, device):
    """Convert scipy sparse matrix to torch sparse tensor."""
    # Handle any scipy sparse format
    if sparse.issparse(sparse_matrix):
        coo = sparse_matrix.tocoo()
    else:
        raise ValueError(f"Expected scipy sparse matrix, got {type(sparse_matrix)}")
    
    indices = torch.LongTensor(np.vstack((coo.row, coo.col)))
    values = torch.FloatTensor(coo.data)
    shape = coo.shape
    return torch.sparse_coo_tensor(indices, values, shape).to(device)


def initialize_weights(m):
    """Reinitialize network weights for finding next eigenfunction."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


# ============ MAIN TRAINING FUNCTION ============
def train_eigenvalue_pinn(X, L, M, hidden_dim=64, epochs=20000, 
                          lr=1e-3, num_eigenfunctions=5, 
                          convergence_threshold=1e-7,
                          ortho_weight=1.0):
    """print(f"The true Eigenvalues are: {np.array2string(eigvals[:10], formatter={'float': lambda x: f'{x:.3f}'})}")
print(f"The predicted Eigenvalues are: {np.array2string(np.array(eigenvalues), formatter={'float': lambda x: f'{x:.3f}'})}")
    Train PINN to solve Lu = λMu eigenvalue problem.
    
    Args:
        X: (N, 3) point cloud coordinates
        L: (N, N) Laplacian matrix (scipy sparse: csr, csc, dia, etc.)
        M: (N, N) Mass matrix (scipy sparse: csr, csc, dia, etc.)
        hidden_dim: Hidden layer dimension
        epochs: Training epochs per eigenfunction
        lr: Learning rate
        num_eigenfunctions: Number of eigenfunctions to find
        convergence_threshold: Threshold for detecting convergence
        ortho_weight: Weight for orthogonality loss (increase if eigenfunctions overlap)
    
    Returns:
        eigenvalues: List of found eigenvalues
        eigenfunctions: List of (N, 1) eigenfunctions
        loss_history: Training loss history
    """
    
    # Convert inputs to torch (handle both numpy arrays and torch tensors)
    if isinstance(X, np.ndarray):
        X_torch = torch.FloatTensor(X).to(device)
    elif isinstance(X, torch.Tensor):
        X_torch = X.float().to(device)  # Ensure float32
    else:
        raise ValueError(f"X must be numpy array or torch tensor, got {type(X)}")
    
    X_torch.requires_grad = True
    
    # Storage for results
    eigenvalues = []
    eigenfunctions = []
    all_models = []
    loss_history = {'total': [], 'eigenvalue': [], 'normalization': [], 'orthogonality': []}
    
    print(f"Training on device: {device}")
    print(f"Point cloud size: {X.shape[0]} points")
    print(f"Matrix format: L is {type(L).__name__}, M is {type(M).__name__}")
    
    # ============ ITERATIVE EIGENFUNCTION DISCOVERY ============
    for eig_idx in range(num_eigenfunctions):
        print(f"\n{'='*60}")
        print(f"Finding eigenfunction {eig_idx + 1}/{num_eigenfunctions}")
        print(f"{'='*60}")
        
        # Initialize network with progressively larger eigenvalue guess
        # For Laplacian: smallest eigenvalue is 0
        if eig_idx == 0:
            initial_eigenvalue = 0.0  # First eigenvalue for Laplacian
        elif eig_idx > 0:
            # Use previous eigenvalue as lower bound + small increment
            initial_eigenvalue = eigenvalues[-1] + 0.15
        else:
            initial_eigenvalue = eig_idx * 0.2
        
        model = EigenfunctionNN(hidden_dim=hidden_dim, input_dim=X.shape[1], 
                               initial_eigenvalue=initial_eigenvalue).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.999, 0.9999))
        
        best_model = None
        best_loss = float('inf')
        loss_slope_history = []
        
        for epoch in range(epochs):
            optimizer.zero_grad()
            
            # Forward pass
            u, eigenvalue = model(X_torch)
            
            # Compute losses
            eig_loss, Lu, Mu = compute_eigenvalue_loss(u, eigenvalue, L, M, X_torch, device)
            norm_loss = compute_normalization_loss(u, M, device)
            ortho_loss = compute_orthogonality_loss(u, eigenfunctions, M, device)
            
            # Total loss with weighting
            total_loss = eig_loss + norm_loss + ortho_weight * ortho_loss
            
            # Backward pass
            total_loss.backward()
            optimizer.step()
            
            # Track loss slope for convergence detection
            loss_slope_history.append(eig_loss.item())
            if len(loss_slope_history) > 1000:
                loss_slope_history.pop(0)
                slope = np.mean(np.diff(loss_slope_history))
            else:
                slope = 0.0
            
            # Save best model
            if eig_loss.item() < best_loss:
                best_loss = eig_loss.item()
                best_model = copy.deepcopy(model)
            
            # Logging
            if epoch % 500 == 0:
                print(f"Epoch {epoch:5d} | λ={eigenvalue.item():.6f} | "
                      f"Eig loss={eig_loss.item():.2e} | Norm loss={norm_loss.item():.2e} | "
                      f"Ortho loss={ortho_loss.item():.2e} | Slope={slope:.2e}")
            
            # Store history
            loss_history['total'].append(total_loss.item())
            loss_history['eigenvalue'].append(eig_loss.item())
            loss_history['normalization'].append(norm_loss.item())
            loss_history['orthogonality'].append(ortho_loss.item())
            
            # Check for convergence and reinitialize if stuck
            if epoch > 5000 and len(loss_slope_history) == 1000:
                if abs(slope) < convergence_threshold:
                    print(f"Converged at epoch {epoch}!")
                    break
        
        # Store results for this eigenfunction
        with torch.no_grad():
            u_final, eigenvalue_final = best_model(X_torch)
            eigenvalues.append(eigenvalue_final.item())
            eigenfunctions.append(u_final.detach())
            all_models.append(best_model)
        
        print(f"\nFound eigenvalue: λ_{eig_idx} = {eigenvalue_final.item():.6f}")
    
    return eigenvalues, eigenfunctions, all_models, loss_history


In [None]:
"""

THIS ONE ACTUALLY WORKS OK

"""



# ============ DEVICE SETUP ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ============ NETWORK ARCHITECTURE ============
class Sin(nn.Module):
    """Sine activation function"""
    def forward(self, x):
        return torch.sin(x)


class EigenfunctionNN(nn.Module):
    """
    Neural network to learn eigenfunctions on point clouds.
    Input: 3D coordinates (x, y, z)
    Output: eigenfunction value u(x,y,z) and eigenvalue λ
    """
    def __init__(self, hidden_dim=64, input_dim=3, initial_eigenvalue=0.0):
        super().__init__()
        self.activation = Sin()
        
        # Learnable eigenvalue with better initialization
        self.eigenvalue_layer = nn.Linear(1, 1, bias=False)
        with torch.no_grad():
            self.eigenvalue_layer.weight.fill_(initial_eigenvalue)
        
        # Network layers - concatenate eigenvalue at each layer
        self.fc1 = nn.Linear(input_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc2 = nn.Linear(hidden_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc3 = nn.Linear(hidden_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc4 = nn.Linear(hidden_dim + 1, 1)  # +1 for eigenvalue
        
    def forward(self, x):
        """
        Args:
            x: (N, 3) point cloud coordinates
        Returns:
            u: (N, 1) eigenfunction values
            eigenvalue: scalar learnable eigenvalue
        """
        # Learn eigenvalue and broadcast to match batch size
        eigenvalue = torch.abs(self.eigenvalue_layer(torch.ones(1, 1).to(x.device)))
        eigenvalue_expanded = eigenvalue.expand(x.shape[0], 1)  # (N, 1)
        
        # Forward pass - concatenate eigenvalue at each layer
        h = torch.cat([x, eigenvalue_expanded], dim=1)  # (N, input_dim+1)
        h = self.activation(self.fc1(h))
        
        h = torch.cat([h, eigenvalue_expanded], dim=1)  # (N, hidden_dim+1)
        h = self.activation(self.fc2(h))
        
        h = torch.cat([h, eigenvalue_expanded], dim=1)  # (N, hidden_dim+1)
        h = self.activation(self.fc3(h))
        
        h = torch.cat([h, eigenvalue_expanded], dim=1)  # (N, hidden_dim+1)
        u = self.fc4(h)
        
        return u, eigenvalue


# ============ LOSS COMPUTATION ============
def compute_eigenvalue_loss(u, eigenvalue, L_torch, M_torch):
    """
    Compute residual for Lu = λMu using discrete operators.
    Returns:
        loss: MSE of residual ||Lu - λMu||²
    """
    u_flat = u.squeeze()  # (N,)
    
    # Compute Lu and λMu
    Lu = torch.sparse.mm(L_torch, u_flat.unsqueeze(1)).squeeze()
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    residual = Lu - eigenvalue * Mu
    
    return torch.mean(residual ** 2), Lu, Mu


def compute_normalization_loss(u, M_torch):
    """
    Enforce u^T M u = 1 (mass-matrix normalization).
    """
    u_flat = u.squeeze()    
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    norm_squared = torch.dot(u_flat, Mu)

    return (norm_squared - 1.0) ** 2


def compute_orthogonality_loss(u, previous_eigenfunctions, M_torch):
    """
    Enforce u ⊥ u_i for all previously found eigenfunctions.
    Uses M-orthogonality: u^T M u_i = 0
    """
    if not previous_eigenfunctions:
        return torch.tensor(0.0, device=M_torch.device)
    
    u_flat = u.squeeze()
    ortho_loss = torch.tensor(0.0, device=M_torch.device)
    
    
    for u_prev in previous_eigenfunctions:
        u_prev_flat = u_prev.squeeze()
        # Compute u^T M u_prev
        Mu_prev = torch.sparse.mm(M_torch, u_prev_flat.unsqueeze(1)).squeeze()
        overlap = torch.dot(u_flat, Mu_prev)
        ortho_loss += overlap ** 2
    
    return ortho_loss


# ============ UTILITY FUNCTIONS ============
def sparse_to_torch(sparse_matrix, device):
    """Convert scipy sparse matrix to torch sparse tensor."""
    # Handle any scipy sparse format
    if sparse.issparse(sparse_matrix):
        coo = sparse_matrix.tocoo()
    else:
        raise ValueError(f"Expected scipy sparse matrix, got {type(sparse_matrix)}")
    
    indices = torch.LongTensor(np.vstack((coo.row, coo.col)))
    values = torch.FloatTensor(coo.data)
    shape = coo2,.shape
    return torch.sparse_coo_tensor(indices, values, shape).to(device)


def initialize_weights(m):
    """Reinitialize network weights for finding next eigenfunction."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


# ============ MAIN TRAINING FUNCTION ============
def train_eigenvalue_pinn(X, L, M, hidden_dim=64, epochs=20000, 
                          lr=1e-3, num_eigenfunctions=5, 
                          convergence_threshold=1e-7,
                          ortho_weight=1.0):
    """print(f"The true Eigenvalues are: {np.array2string(eigvals[:10], formatter={'float': lambda x: f'{x:.3f}'})}")
print(f"The predicted Eigenvalues are: {np.array2string(np.array(eigenvalues), formatter={'float': lambda x: f'{x:.3f}'})}")
    Train PINN to solve Lu = λMu eigenvalue problem.
    
    Args:
        X: (N, 3) point cloud coordinates
        L: (N, N) Laplacian matrix (scipy sparse: csr, csc, dia, etc.)
        M: (N, N) Mass matrix (scipy sparse: csr, csc, dia, etc.)
        hidden_dim: Hidden layer dimension
        epochs: Training epochs per eigenfunction
        lr: Learning rate
        num_eigenfunctions: Number of eigenfunctions to find
        convergence_threshold: Threshold for detecting convergence
        ortho_weight: Weight for orthogonality loss (increase if eigenfunctions overlap)
    
    Returns:
        eigenvalues: List of found eigenvalues
        eigenfunctions: List of (N, 1) eigenfunctions
        loss_history: Training loss history
    """
    
    # Prepare inputs
    X_torch = torch.as_tensor(X, dtype=torch.float32, device=device)
    X_torch.requires_grad = True

    # Pre-convert sparse matrices
    L_torch = sparse_to_torch(L, device) if sparse.issparse(L) else L.to(device)
    M_torch = sparse_to_torch(M, device) if sparse.issparse(M) else M.to(device)
    
    # Storage for results
    eigenvalues = []
    eigenfunctions = []
    all_models = []
    loss_history = {'total': [], 'eig': [], 'norm': [], 'ortho': []}
    
    print(f"Training on device: {device}")
    print(f"Point cloud size: {X.shape[0]} points")
    print(f"Matrix format: L is {type(L).__name__}, M is {type(M).__name__}")
    
    # ============ ITERATIVE EIGENFUNCTION DISCOVERY ============
    for eig_idx in range(num_eigenfunctions):
        print(f"\n{'='*60}")
        print(f"Finding eigenfunction {eig_idx + 1}/{num_eigenfunctions}")
        print(f"{'='*60}")
        
        # Initialize network with progressively larger eigenvalue guess
        # For Laplacian: smallest eigenvalue is 0
        if eig_idx == 0:
            initial_eigenvalue = 0.0  # First eigenvalue for Laplacian
        elif eig_idx > 0:
            # Use previous eigenvalue as lower bound + small increment
            initial_eigenvalue = eigenvalues[-1] + 0.15
        else:
            initial_eigenvalue = eig_idx * 0.2
        
        model = EigenfunctionNN(hidden_dim=hidden_dim, input_dim=X.shape[1], 
                               initial_eigenvalue=initial_eigenvalue).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.999, 0.9999))
        
        best_model = None
        best_loss = float('inf')
        ema_slope = 1.0
        prev_loss = None
        
        for epoch in trange(epochs, desc=f"Eigen {eig_idx+1}/{num_eigenfunctions}"):
            optimizer.zero_grad()
            
            # Forward pass
            u, eigenvalue = model(X_torch)
            
            # Compute losses
            eig_loss, _, _ = compute_eigenvalue_loss(u, eigenvalue, L_torch, M_torch)
            norm_loss = compute_normalization_loss(u, M_torch)
            ortho_loss = compute_orthogonality_loss(u, eigenfunctions, M_torch)
            
            # Total loss with weighting
            total_loss = eig_loss + norm_loss + ortho_weight * ortho_loss
            
            # Backward pass
            total_loss.backward()
            optimizer.step()
            
            # Convergence tracking
            if prev_loss is not None:
                ema_slope = 0.75 * ema_slope + 0.25 * abs(prev_loss - eig_loss.item())
            prev_loss = eig_loss.item()

            if ema_slope < convergence_threshold and epoch > 2000:
                print(f"Converged at epoch {epoch}")
                break

            # Save best
            if eig_loss.item() < best_loss:
                best_loss = eig_loss.item()
                best_model = copy.deepcopy(model)

            # Logging every 500 epochs
            if epoch % 500 == 0:
                print(f"Epoch {epoch:5d} | λ={eigenvalue.item():.6f} | "
                      f"Eig={eig_loss.item():.2e} | Norm={norm_loss.item():.2e} | "
                      f"Ortho={ortho_loss.item():.2e} | EMA slope={ema_slope:.2e}")

            # Lightweight history
            if epoch % 100 == 0:
                loss_history['total'].append(total_loss.item())
                loss_history['eig'].append(eig_loss.item())
                loss_history['norm'].append(norm_loss.item())
                loss_history['ortho'].append(ortho_loss.item())

        # Store results
        with torch.no_grad():
            u_final, eigenvalue_final = best_model(X_torch)
            eigenvalues.append(eigenvalue_final.item())
            eigenfunctions.append(u_final.detach())
            all_models.append(best_model)

        print(f"\nFound eigenvalue: λ_{eig_idx} = {eigenvalue_final.item():.6f}")
    
    return eigenvalues, eigenfunctions, all_models, loss_history


In [125]:
 # ============ DEVICE SETUP ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ============ NETWORK ARCHITECTURE ============
class Sin(nn.Module):
    """Sine activation function"""
    def forward(self, x):
        return torch.sin(x)


class EigenfunctionNN(nn.Module):
    """
    Neural network to learn eigenfunctions on point clouds.
    Input: 3D coordinates (x, y, z)
    Output: eigenfunction value u(x,y,z) and eigenvalue λ
    """
    def __init__(self, hidden_dim=64, input_dim=3, initial_eigenvalue=0.0):
        super().__init__()
        self.activation = Sin()
        
        # Learnable eigenvalue with better initialization
        self.eigenvalue_layer = nn.Linear(1, 1, bias=False)
        with torch.no_grad():
            self.eigenvalue_layer.weight.fill_(initial_eigenvalue)
        
        # Network layers - ONLY concatenate eigenvalue at the first layer
        # This keeps the eigenvalue informative but prevents over-coupling in deeper layers
        self.fc1 = nn.Linear(input_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        """
        Args:
            x: (N, 3) point cloud coordinates
        Returns:
            u: (N, 1) eigenfunction values
            eigenvalue: scalar learnable eigenvalue
        """
        # Learn eigenvalue and broadcast to match batch size
        # Use torch.clamp(..., min=0.0) to enforce non-negativity
        eigenvalue = torch.clamp(self.eigenvalue_layer(torch.ones(1, 1).to(x.device)), min=0.0)
        eigenvalue_expanded = eigenvalue.expand(x.shape[0], 1)  # (N, 1)
        
        # Forward pass - concatenate eigenvalue ONLY at the input
        h = torch.cat([x, eigenvalue_expanded], dim=1)  # (N, input_dim+1)
        h = self.activation(self.fc1(h))
        
        # No eigenvalue concatenation in hidden layers
        h = self.activation(self.fc2(h))
        h = self.activation(self.fc3(h))
        
        u = self.fc4(h)
        
        return u, eigenvalue


# ============ LOSS COMPUTATION ============
def compute_eigenvalue_loss(u, eigenvalue, L_torch, M_torch):
    """
    Compute residual for Lu = λMu using discrete operators.
    Returns:
        loss: MSE of residual ||Lu - λMu||²
    """
    u_flat = u.squeeze()  # (N,)
    
    # Compute Lu and λMu
    # Ensure u_flat is (N, 1) for sparse.mm, which expects a 2D tensor
    Lu = torch.sparse.mm(L_torch, u_flat.unsqueeze(1)).squeeze()
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    residual = Lu - eigenvalue * Mu
    
    return torch.mean(residual ** 2), Lu, Mu


def compute_normalization_loss(u, M_torch):
    """
    Enforce u^T M u = 1 (mass-matrix normalization).
    """
    u_flat = u.squeeze()    
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    norm_squared = torch.dot(u_flat, Mu)

    return (norm_squared - 1.0) ** 2


def compute_orthogonality_loss(u, previous_eigenfunctions, M_torch):
    """
    Enforce u ⊥ u_i for all previously found eigenfunctions.
    Uses M-orthogonality: u^T M u_i = 0
    """
    if not previous_eigenfunctions:
        return torch.tensor(0.0, device=M_torch.device)
    
    u_flat = u.squeeze()
    ortho_loss = torch.tensor(0.0, device=M_torch.device)
    
    
    for u_prev in previous_eigenfunctions:
        u_prev_flat = u_prev.squeeze()
        # Compute u^T M u_prev
        Mu_prev = torch.sparse.mm(M_torch, u_prev_flat.unsqueeze(1)).squeeze()
        overlap = torch.dot(u_flat, Mu_prev)
        ortho_loss += overlap ** 2
    
    return ortho_loss


def compute_ordering_loss(eigenvalue, previous_eigenvalue, margin=1e-3):
    """
    Enforce λ_i > λ_{i-1} + margin.
    This is crucial for preventing mode collapse to a lower-energy mode.
    """
    if previous_eigenvalue is None:
        return torch.tensor(0.0, device=eigenvalue.device)
    
    # max(0, λ_{i-1} + margin - λ_i)^2
    loss = torch.max(
        torch.tensor(0.0, device=eigenvalue.device), 
        previous_eigenvalue.detach() + margin - eigenvalue
    ) ** 2
    return loss.squeeze()


# ============ UTILITY FUNCTIONS ============
def sparse_to_torch(sparse_matrix, device):
    """Convert scipy sparse matrix to torch sparse tensor."""
    # Handle any scipy sparse format
    if sparse.issparse(sparse_matrix):
        coo = sparse_matrix.tocoo()
    else:
        raise ValueError(f"Expected scipy sparse matrix, got {type(sparse_matrix)}")
    
    indices = torch.LongTensor(np.vstack((coo.row, coo.col)))
    values = torch.FloatTensor(coo.data)
    shape = coo.shape
    return torch.sparse_coo_tensor(indices, values, shape).to(device)


def initialize_weights(m):
    """Reinitialize network weights for finding next eigenfunction."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


# ============ MAIN TRAINING FUNCTION ============
def train_eigenvalue_pinn(X, L, M, hidden_dim=64, epochs=20000, 
                          lr=1e-3, num_eigenfunctions=10, 
                          convergence_threshold=1e-7,
                          ortho_weight=100.0,  # CRITICAL INCREASE: From 1.0 to 100.0
                          order_weight=100.0,   # NEW: Weight for the ordering constraint
                          initial_lambda_step=0.3): # CRITICAL INCREASE: From 0.15 to 0.3
    """
    Train PINN to solve Lu = λMu eigenvalue problem.
    """
    
    # Prepare inputs
    X_torch = torch.as_tensor(X, dtype=torch.float32, device=device)
    X_torch.requires_grad = True

    # Pre-convert sparse matrices
    L_torch = sparse_to_torch(L, device) if sparse.issparse(L) else L.to(device)
    M_torch = sparse_to_torch(M, device) if sparse.issparse(M) else M.to(device)
    
    # Storage for results
    eigenvalues = []
    eigenfunctions = []
    all_models = []
    loss_history = {'total': [], 'eig': [], 'norm': [], 'ortho': [], 'order': []}
    
    print(f"Training on device: {device}")
    print(f"Point cloud size: {X.shape[0]} points")
    print(f"Matrix format: L is {type(L).__name__}, M is {type(M).__name__}")
    
    # ============ ITERATIVE EIGENFUNCTION DISCOVERY ============
    for eig_idx in range(num_eigenfunctions):
        print(f"\n{'='*60}")
        print(f"Finding eigenfunction {eig_idx + 1}/{num_eigenfunctions}")
        print(f"Hyperparams: λ_step={initial_lambda_step:.2f}, ortho_w={ortho_weight:.1f}, order_w={order_weight:.1f}")
        print(f"{'='*60}")
        
        # Determine previous eigenvalue and initial guess
        previous_eigenvalue = None
        if eig_idx == 0:
            initial_eigenvalue = 0.0
        else:
            previous_eigenvalue = torch.tensor(eigenvalues[-1], dtype=torch.float32, device=device)
            # Use larger, non-fixed step size
            initial_eigenvalue = eigenvalues[-1] + initial_lambda_step 
            
        
        # Initialize network
        model = EigenfunctionNN(hidden_dim=hidden_dim, input_dim=X.shape[1], 
                               initial_eigenvalue=initial_eigenvalue).to(device)
        
        # REINITIALIZE EVERYTHING FOR NOW (can be optimized later)
        model.apply(initialize_weights) 
        
        # Re-set initial eigenvalue guess, as initialize_weights overwrites it
        if previous_eigenvalue is not None:
             with torch.no_grad():
                model.eigenvalue_layer.weight.fill_(initial_eigenvalue)


        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.999, 0.9999))
        
        # --- Convergence Tracking Setup ---
        best_model = None
        best_eig_loss = float('inf')
        ema_slope = 1.0
        prev_loss = None
        
        # --- Training Loop ---
        for epoch in trange(epochs, desc=f"Eigen {eig_idx+1}/{num_eigenfunctions}"):
            optimizer.zero_grad()
            
            # Forward pass
            u, eigenvalue = model(X_torch)
            
            # Compute losses
            eig_loss, _, _ = compute_eigenvalue_loss(u, eigenvalue, L_torch, M_torch)
            norm_loss = compute_normalization_loss(u, M_torch)
            ortho_loss = compute_orthogonality_loss(u, eigenfunctions, M_torch)
            order_loss = compute_ordering_loss(eigenvalue, previous_eigenvalue) # NEW LOSS
            
            # Total loss with weighting
            total_loss = eig_loss + norm_loss + \
                         ortho_weight * ortho_loss + \
                         order_weight * order_loss # APPLY WEIGHT
            
            # Backward pass
            total_loss.backward()
            optimizer.step()
            
            # Convergence tracking (Use Eig loss for convergence)
            if prev_loss is not None:
                ema_slope = 0.75 * ema_slope + 0.25 * abs(prev_loss - eig_loss.item())
            prev_loss = eig_loss.item()

            # Wait longer for convergence check (e.g., after 5000 epochs)
            # And use a higher threshold (e.g., 1e-5) as convergence can be noisy
            if ema_slope < convergence_threshold * 10.0 and epoch > 5000:
                print(f"Converged at epoch {epoch}")
                break

            # Save best model based on Eig loss
            if eig_loss.item() < best_eig_loss:
                best_eig_loss = eig_loss.item()
                best_model = copy.deepcopy(model)

            # Logging every 500 epochs
            if epoch % 500 == 0:
                print(f"Epoch {epoch:5d} | λ={eigenvalue.item():.6f} | "
                      f"Eig={eig_loss.item():.2e} | Norm={norm_loss.item():.2e} | "
                      f"Ortho={ortho_loss.item():.2e} | Order={order_loss.item():.2e} | " # NEW LOG
                      f"EMA slope={ema_slope:.2e}")

            # Lightweight history
            if epoch % 100 == 0:
                loss_history['total'].append(total_loss.item())
                loss_history['eig'].append(eig_loss.item())
                loss_history['norm'].append(norm_loss.item())
                loss_history['ortho'].append(ortho_loss.item())
                loss_history['order'].append(order_loss.item())

        # Store results
        if best_model is None:
             print("Warning: No best model found (first epoch was best, using final model).")
             best_model = model
             
        with torch.no_grad():
            u_final, eigenvalue_final = best_model(X_torch)
            eigenvalues.append(eigenvalue_final.item())
            
            # Re-normalize the eigenfunction after training
            u_flat = u_final.squeeze()
            Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
            norm_squared = torch.dot(u_flat, Mu)
            u_normalized = u_final / torch.sqrt(norm_squared)

            eigenfunctions.append(u_normalized.detach())
            all_models.append(best_model)

        print(f"\nFound eigenvalue: λ_{eig_idx} = {eigenvalue_final.item():.6f}")
    
    return eigenvalues, eigenfunctions, all_models, loss_history

In [126]:
m = Mesh('bunny.obj')

centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()

X = (m.verts - centroid)/std_max

m = Mesh(verts = X, connectivity = m.connectivity)

L, M = robust_laplacian.point_cloud_laplacian(X)

print('Computing Laplacian')
K_igl, M_igl = m.computeLaplacian()

# following Finite Elements methodology 
# K is stiffness matrix, M is mass matrix
# The problem to solve becomes 
# K*u = lambda * M*u
print('Computing eigen values')
eigvals, eigvecs = eigh(K_igl,M_igl)

# send all relevant numpy arrays to torch tensors
K_ = torch.from_numpy(K_igl).float().to(device)
M_ = torch.from_numpy(M_igl).float().to(device)
X_ = torch.from_numpy(m.verts).float().to(device)



eigenvalues, eigenfunctions, models, history = train_eigenvalue_pinn(
    X_, K_, M_, 
    hidden_dim=128, 
    epochs=5000, 
    lr=1e-3, 
    num_eigenfunctions=10
)

print("\n" + "="*60)
print("RESULTS")
print("="*60)
for i, lam in enumerate(eigenvalues):
    print(f"λ_{i} = {lam:.6f}")

Computing Laplacian
Computing eigen values
Training on device: cuda
Point cloud size: 2503 points
Matrix format: L is Tensor, M is Tensor

Finding eigenfunction 1/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 1/10:   1%|          | 30/5000 [00:00<00:16, 292.40it/s]

Epoch     0 | λ=0.000000 | Eig=1.77e-06 | Norm=8.32e-01 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 1/10:  12%|█▏        | 590/5000 [00:01<00:09, 444.53it/s]

Epoch   500 | λ=0.000000 | Eig=9.06e-06 | Norm=3.35e-02 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=2.43e-07


Eigen 1/10:  21%|██▏       | 1064/5000 [00:02<00:08, 469.66it/s]

Epoch  1000 | λ=0.000000 | Eig=3.46e-05 | Norm=9.06e-02 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=1.44e-07


Eigen 1/10:  32%|███▏      | 1591/5000 [00:03<00:07, 462.05it/s]

Epoch  1500 | λ=0.000000 | Eig=1.44e-05 | Norm=1.13e-01 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=3.48e-08


Eigen 1/10:  41%|████▏     | 2071/5000 [00:04<00:06, 474.32it/s]

Epoch  2000 | λ=0.000000 | Eig=2.36e-05 | Norm=4.69e-02 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=9.34e-08


Eigen 1/10:  52%|█████▏    | 2589/5000 [00:05<00:05, 465.13it/s]

Epoch  2500 | λ=0.000000 | Eig=2.41e-05 | Norm=2.33e-04 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=1.04e-07


Eigen 1/10:  61%|██████▏   | 3065/5000 [00:06<00:04, 473.25it/s]

Epoch  3000 | λ=0.000000 | Eig=1.50e-05 | Norm=1.39e-02 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=1.82e-08


Eigen 1/10:  72%|███████▏  | 3592/5000 [00:07<00:02, 472.06it/s]

Epoch  3500 | λ=0.000000 | Eig=2.39e-05 | Norm=9.41e-03 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=4.64e-08


Eigen 1/10:  81%|████████▏ | 4064/5000 [00:08<00:02, 465.34it/s]

Epoch  4000 | λ=0.000000 | Eig=2.54e-05 | Norm=2.71e-03 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=3.28e-08


Eigen 1/10:  92%|█████████▏| 4593/5000 [00:10<00:00, 474.42it/s]

Epoch  4500 | λ=0.000000 | Eig=2.55e-05 | Norm=1.04e-03 | Ortho=0.00e+00 | Order=0.00e+00 | EMA slope=3.46e-08


Eigen 1/10: 100%|██████████| 5000/5000 [00:10<00:00, 459.47it/s]



Found eigenvalue: λ_0 = 0.000000

Finding eigenfunction 2/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 2/10:   0%|          | 10/5000 [00:00<00:51, 97.52it/s]

Epoch     0 | λ=0.300000 | Eig=4.81e-05 | Norm=1.41e-02 | Ortho=1.53e-02 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 2/10:  12%|█▏        | 576/5000 [00:01<00:11, 399.58it/s]

Epoch   500 | λ=0.323117 | Eig=1.58e-05 | Norm=4.56e-01 | Ortho=8.05e-02 | Order=0.00e+00 | EMA slope=2.28e-08


Eigen 2/10:  21%|██        | 1057/5000 [00:02<00:10, 383.66it/s]

Epoch  1000 | λ=0.326048 | Eig=1.86e-05 | Norm=2.14e-01 | Ortho=1.25e-02 | Order=0.00e+00 | EMA slope=4.59e-07


Eigen 2/10:  31%|███       | 1550/5000 [00:03<00:08, 414.66it/s]

Epoch  1500 | λ=0.309746 | Eig=2.16e-05 | Norm=2.23e-01 | Ortho=1.91e-03 | Order=0.00e+00 | EMA slope=3.07e-07


Eigen 2/10:  41%|████      | 2051/5000 [00:05<00:07, 411.07it/s]

Epoch  2000 | λ=0.268622 | Eig=5.87e-05 | Norm=1.86e-02 | Ortho=2.58e-06 | Order=0.00e+00 | EMA slope=6.40e-07


Eigen 2/10:  51%|█████     | 2544/5000 [00:06<00:06, 404.23it/s]

Epoch  2500 | λ=0.242385 | Eig=6.55e-05 | Norm=2.82e-01 | Ortho=7.45e-04 | Order=0.00e+00 | EMA slope=3.26e-07


Eigen 2/10:  61%|██████▏   | 3065/5000 [00:07<00:04, 389.80it/s]

Epoch  3000 | λ=0.224054 | Eig=2.02e-05 | Norm=1.16e-02 | Ortho=7.76e-04 | Order=0.00e+00 | EMA slope=1.06e-07


Eigen 2/10:  71%|███████   | 3552/5000 [00:09<00:03, 404.56it/s]

Epoch  3500 | λ=0.209146 | Eig=1.05e-05 | Norm=6.21e-06 | Ortho=1.46e-04 | Order=0.00e+00 | EMA slope=1.96e-09


Eigen 2/10:  81%|████████  | 4050/5000 [00:10<00:02, 401.78it/s]

Epoch  4000 | λ=0.195466 | Eig=1.50e-05 | Norm=4.67e-03 | Ortho=6.22e-06 | Order=0.00e+00 | EMA slope=2.04e-08


Eigen 2/10:  92%|█████████▏| 4575/5000 [00:11<00:01, 384.59it/s]

Epoch  4500 | λ=0.185902 | Eig=1.83e-05 | Norm=1.38e-02 | Ortho=1.21e-05 | Order=0.00e+00 | EMA slope=4.68e-08


Eigen 2/10: 100%|██████████| 5000/5000 [00:12<00:00, 395.07it/s]



Found eigenvalue: λ_1 = 0.320646

Finding eigenfunction 3/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 3/10:   1%|          | 33/5000 [00:00<00:15, 328.77it/s]

Epoch     0 | λ=0.620646 | Eig=5.43e-05 | Norm=1.80e+01 | Ortho=5.52e+00 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 3/10:  11%|█▏        | 565/5000 [00:01<00:11, 378.72it/s]

Epoch   500 | λ=0.432166 | Eig=2.17e-05 | Norm=2.88e-02 | Ortho=2.02e-02 | Order=0.00e+00 | EMA slope=5.69e-07


Eigen 3/10:  21%|██        | 1061/5000 [00:02<00:10, 374.44it/s]

Epoch  1000 | λ=0.358600 | Eig=5.60e-05 | Norm=2.32e-02 | Ortho=8.21e-03 | Order=0.00e+00 | EMA slope=2.73e-08


Eigen 3/10:  32%|███▏      | 1578/5000 [00:04<00:08, 394.90it/s]

Epoch  1500 | λ=0.314720 | Eig=2.50e-05 | Norm=2.08e-01 | Ortho=1.30e-03 | Order=4.80e-05 | EMA slope=3.47e-07


Eigen 3/10:  41%|████      | 2049/5000 [00:05<00:07, 378.98it/s]

Epoch  2000 | λ=0.308613 | Eig=1.03e-05 | Norm=4.34e-01 | Ortho=4.40e-03 | Order=1.70e-04 | EMA slope=2.14e-08


Eigen 3/10:  51%|█████     | 2544/5000 [00:06<00:06, 369.98it/s]

Epoch  2500 | λ=0.319538 | Eig=4.50e-05 | Norm=3.33e-02 | Ortho=1.29e-03 | Order=4.44e-06 | EMA slope=4.41e-07


Eigen 3/10:  61%|██████    | 3043/5000 [00:08<00:05, 380.68it/s]

Epoch  3000 | λ=0.328583 | Eig=2.71e-05 | Norm=2.24e-06 | Ortho=9.29e-04 | Order=0.00e+00 | EMA slope=7.84e-08


Eigen 3/10:  71%|███████   | 3549/5000 [00:09<00:03, 374.34it/s]

Epoch  3500 | λ=0.333295 | Eig=2.17e-05 | Norm=7.11e-02 | Ortho=7.83e-04 | Order=0.00e+00 | EMA slope=1.01e-08


Eigen 3/10:  81%|████████  | 4042/5000 [00:10<00:02, 374.82it/s]

Epoch  4000 | λ=0.336227 | Eig=3.11e-05 | Norm=8.58e-03 | Ortho=2.90e-04 | Order=0.00e+00 | EMA slope=4.78e-08


Eigen 3/10:  91%|█████████ | 4539/5000 [00:12<00:01, 375.64it/s]

Epoch  4500 | λ=0.338025 | Eig=4.68e-05 | Norm=1.56e-02 | Ortho=5.01e-04 | Order=0.00e+00 | EMA slope=3.63e-08


Eigen 3/10: 100%|██████████| 5000/5000 [00:13<00:00, 375.16it/s]



Found eigenvalue: λ_2 = 0.581464

Finding eigenfunction 4/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 4/10:   1%|          | 33/5000 [00:00<00:15, 329.25it/s]

Epoch     0 | λ=0.881464 | Eig=1.68e-04 | Norm=2.39e+01 | Ortho=5.03e+00 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 4/10:  11%|█         | 555/5000 [00:01<00:12, 366.08it/s]

Epoch   500 | λ=0.670922 | Eig=3.75e-05 | Norm=3.29e-02 | Ortho=3.99e-02 | Order=0.00e+00 | EMA slope=9.51e-07


Eigen 4/10:  21%|██        | 1061/5000 [00:02<00:11, 357.49it/s]

Epoch  1000 | λ=0.579472 | Eig=1.18e-05 | Norm=3.88e-01 | Ortho=1.02e-02 | Order=8.96e-06 | EMA slope=2.14e-07


Eigen 4/10:  31%|███       | 1541/5000 [00:04<00:10, 342.77it/s]

Epoch  1500 | λ=0.556698 | Eig=1.99e-06 | Norm=9.05e-01 | Ortho=5.20e-03 | Order=6.64e-04 | EMA slope=4.98e-08


Eigen 4/10:  41%|████      | 2047/5000 [00:05<00:08, 354.24it/s]

Epoch  2000 | λ=0.590127 | Eig=4.80e-05 | Norm=7.97e-02 | Ortho=2.50e-03 | Order=0.00e+00 | EMA slope=5.73e-07


Eigen 4/10:  51%|█████     | 2551/5000 [00:07<00:06, 354.86it/s]

Epoch  2500 | λ=0.617444 | Eig=2.07e-05 | Norm=2.25e-01 | Ortho=2.61e-04 | Order=0.00e+00 | EMA slope=2.29e-08


Eigen 4/10:  61%|██████    | 3058/5000 [00:08<00:05, 356.05it/s]

Epoch  3000 | λ=0.634817 | Eig=4.41e-05 | Norm=1.99e-02 | Ortho=1.28e-04 | Order=0.00e+00 | EMA slope=9.05e-08


Eigen 4/10:  71%|███████▏  | 3568/5000 [00:10<00:03, 359.30it/s]

Epoch  3500 | λ=0.645841 | Eig=3.89e-05 | Norm=8.15e-04 | Ortho=1.01e-03 | Order=0.00e+00 | EMA slope=7.37e-08


Eigen 4/10:  81%|████████  | 4036/5000 [00:11<00:02, 347.06it/s]

Epoch  4000 | λ=0.652977 | Eig=4.65e-05 | Norm=3.92e-02 | Ortho=2.35e-04 | Order=0.00e+00 | EMA slope=3.46e-08


Eigen 4/10:  91%|█████████ | 4537/5000 [00:12<00:01, 353.06it/s]

Epoch  4500 | λ=0.657551 | Eig=3.91e-05 | Norm=7.35e-05 | Ortho=3.23e-04 | Order=0.00e+00 | EMA slope=6.34e-08


Eigen 4/10: 100%|██████████| 5000/5000 [00:14<00:00, 353.73it/s]



Found eigenvalue: λ_3 = 0.786167

Finding eigenfunction 5/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 5/10:   1%|          | 31/5000 [00:00<00:16, 309.54it/s]

Epoch     0 | λ=1.086167 | Eig=6.83e-05 | Norm=2.37e+00 | Ortho=2.45e+00 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 5/10:  11%|█         | 545/5000 [00:01<00:13, 334.42it/s]

Epoch   500 | λ=0.932443 | Eig=1.38e-05 | Norm=3.50e-01 | Ortho=1.32e-01 | Order=0.00e+00 | EMA slope=1.84e-07


Eigen 5/10:  21%|██        | 1054/5000 [00:03<00:11, 331.55it/s]

Epoch  1000 | λ=0.861854 | Eig=2.89e-06 | Norm=8.58e-01 | Ortho=5.47e-02 | Order=0.00e+00 | EMA slope=4.73e-09


Eigen 5/10:  31%|███       | 1541/5000 [00:04<00:10, 342.19it/s]

Epoch  1500 | λ=0.817879 | Eig=2.00e-06 | Norm=9.34e-01 | Ortho=2.43e-02 | Order=0.00e+00 | EMA slope=1.06e-08


Eigen 5/10:  41%|████      | 2060/5000 [00:06<00:08, 340.55it/s]

Epoch  2000 | λ=0.789382 | Eig=3.72e-07 | Norm=9.86e-01 | Ortho=7.04e-03 | Order=0.00e+00 | EMA slope=4.30e-09


Eigen 5/10:  51%|█████     | 2558/5000 [00:07<00:07, 336.81it/s]

Epoch  2500 | λ=0.773662 | Eig=1.29e-07 | Norm=9.97e-01 | Ortho=1.38e-03 | Order=1.82e-04 | EMA slope=1.91e-09


Eigen 5/10:  61%|██████    | 3046/5000 [00:09<00:05, 343.09it/s]

Epoch  3000 | λ=0.777834 | Eig=4.10e-07 | Norm=9.86e-01 | Ortho=1.05e-03 | Order=8.71e-05 | EMA slope=1.97e-10


Eigen 5/10:  71%|███████▏  | 3565/5000 [00:10<00:04, 341.62it/s]

Epoch  3500 | λ=0.791008 | Eig=1.57e-06 | Norm=9.14e-01 | Ortho=1.59e-03 | Order=0.00e+00 | EMA slope=2.94e-09


Eigen 5/10:  81%|████████  | 4049/5000 [00:12<00:02, 336.04it/s]

Epoch  4000 | λ=0.800370 | Eig=6.34e-06 | Norm=7.10e-01 | Ortho=3.86e-03 | Order=0.00e+00 | EMA slope=7.99e-09


Eigen 5/10:  91%|█████████ | 4557/5000 [00:13<00:01, 334.56it/s]

Epoch  4500 | λ=0.806367 | Eig=2.34e-06 | Norm=8.83e-01 | Ortho=1.50e-03 | Order=0.00e+00 | EMA slope=3.77e-09


Eigen 5/10: 100%|██████████| 5000/5000 [00:14<00:00, 335.83it/s]



Found eigenvalue: λ_4 = 0.849548

Finding eigenfunction 6/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 6/10:   1%|          | 26/5000 [00:00<00:19, 256.99it/s]

Epoch     0 | λ=1.149548 | Eig=1.29e-05 | Norm=3.43e-01 | Ortho=1.59e-01 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 6/10:  11%|█         | 552/5000 [00:01<00:13, 321.63it/s]

Epoch   500 | λ=1.005289 | Eig=1.41e-05 | Norm=2.97e-01 | Ortho=1.15e-01 | Order=0.00e+00 | EMA slope=7.74e-07


Eigen 6/10:  21%|██        | 1037/5000 [00:03<00:12, 314.17it/s]

Epoch  1000 | λ=0.916741 | Eig=4.20e-06 | Norm=7.80e-01 | Ortho=4.29e-02 | Order=0.00e+00 | EMA slope=1.18e-07


Eigen 6/10:  31%|███▏      | 1564/5000 [00:04<00:10, 326.30it/s]

Epoch  1500 | λ=0.850394 | Eig=3.74e-06 | Norm=7.60e-01 | Ortho=3.80e-02 | Order=2.38e-08 | EMA slope=1.38e-09


Eigen 6/10:  41%|████      | 2050/5000 [00:06<00:09, 317.47it/s]

Epoch  2000 | λ=0.821007 | Eig=1.58e-06 | Norm=8.93e-01 | Ortho=1.31e-02 | Order=8.73e-04 | EMA slope=4.89e-09


Eigen 6/10:  51%|█████     | 2537/5000 [00:07<00:07, 326.49it/s]

Epoch  2500 | λ=0.851779 | Eig=8.66e-07 | Norm=9.37e-01 | Ortho=8.47e-03 | Order=0.00e+00 | EMA slope=4.31e-09


Eigen 6/10:  61%|██████    | 3042/5000 [00:09<00:05, 330.65it/s]

Epoch  3000 | λ=0.882266 | Eig=3.18e-07 | Norm=9.78e-01 | Ortho=3.04e-03 | Order=0.00e+00 | EMA slope=1.11e-09


Eigen 6/10:  71%|███████   | 3539/5000 [00:11<00:04, 316.47it/s]

Epoch  3500 | λ=0.901832 | Eig=1.47e-08 | Norm=9.99e-01 | Ortho=5.52e-04 | Order=0.00e+00 | EMA slope=4.75e-10


Eigen 6/10:  81%|████████  | 4041/5000 [00:12<00:02, 322.81it/s]

Epoch  4000 | λ=0.914446 | Eig=1.40e-07 | Norm=9.89e-01 | Ortho=1.22e-03 | Order=0.00e+00 | EMA slope=2.95e-10


Eigen 6/10:  91%|█████████ | 4536/5000 [00:14<00:01, 320.68it/s]

Epoch  4500 | λ=0.922547 | Eig=3.65e-08 | Norm=9.99e-01 | Ortho=8.60e-05 | Order=0.00e+00 | EMA slope=3.54e-10


Eigen 6/10: 100%|██████████| 5000/5000 [00:15<00:00, 320.43it/s]



Found eigenvalue: λ_5 = 0.901922

Finding eigenfunction 7/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 7/10:   0%|          | 24/5000 [00:00<00:21, 235.94it/s]

Epoch     0 | λ=1.201922 | Eig=4.07e-05 | Norm=3.68e-01 | Ortho=2.98e+00 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 7/10:  11%|█         | 537/5000 [00:01<00:14, 310.56it/s]

Epoch   500 | λ=1.206693 | Eig=7.69e-06 | Norm=5.58e-01 | Ortho=9.96e-02 | Order=0.00e+00 | EMA slope=1.10e-08


Eigen 7/10:  21%|██        | 1053/5000 [00:03<00:12, 317.98it/s]

Epoch  1000 | λ=1.207000 | Eig=4.85e-06 | Norm=7.73e-01 | Ortho=4.05e-02 | Order=0.00e+00 | EMA slope=1.26e-08


Eigen 7/10:  31%|███       | 1532/5000 [00:04<00:11, 309.11it/s]

Epoch  1500 | λ=1.211822 | Eig=1.53e-06 | Norm=9.24e-01 | Ortho=2.14e-02 | Order=0.00e+00 | EMA slope=2.32e-08


Eigen 7/10:  41%|████      | 2058/5000 [00:06<00:09, 304.81it/s]

Epoch  2000 | λ=1.219957 | Eig=1.20e-06 | Norm=9.40e-01 | Ortho=1.23e-02 | Order=0.00e+00 | EMA slope=1.73e-08


Eigen 7/10:  51%|█████     | 2555/5000 [00:08<00:07, 306.96it/s]

Epoch  2500 | λ=1.225442 | Eig=1.60e-07 | Norm=9.95e-01 | Ortho=1.90e-03 | Order=0.00e+00 | EMA slope=2.42e-09


Eigen 7/10:  61%|██████    | 3054/5000 [00:09<00:06, 304.92it/s]

Epoch  3000 | λ=1.229207 | Eig=1.98e-07 | Norm=9.93e-01 | Ortho=2.88e-03 | Order=0.00e+00 | EMA slope=2.28e-09


Eigen 7/10:  71%|███████   | 3533/5000 [00:11<00:04, 313.27it/s]

Epoch  3500 | λ=1.232675 | Eig=4.83e-08 | Norm=9.98e-01 | Ortho=9.86e-04 | Order=0.00e+00 | EMA slope=8.34e-10


Eigen 7/10:  81%|████████  | 4051/5000 [00:13<00:03, 313.71it/s]

Epoch  4000 | λ=1.235281 | Eig=4.56e-09 | Norm=1.00e+00 | Ortho=4.54e-04 | Order=0.00e+00 | EMA slope=2.13e-10


Eigen 7/10:  91%|█████████ | 4540/5000 [00:14<00:01, 321.34it/s]

Epoch  4500 | λ=1.236456 | Eig=3.90e-08 | Norm=9.98e-01 | Ortho=5.13e-04 | Order=0.00e+00 | EMA slope=3.39e-10


Eigen 7/10: 100%|██████████| 5000/5000 [00:16<00:00, 309.86it/s]



Found eigenvalue: λ_6 = 1.235297

Finding eigenfunction 8/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 8/10:   0%|          | 23/5000 [00:00<00:22, 223.52it/s]

Epoch     0 | λ=1.535297 | Eig=5.13e-05 | Norm=2.26e-01 | Ortho=3.31e+00 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 8/10:  11%|█         | 542/5000 [00:01<00:14, 309.71it/s]

Epoch   500 | λ=1.354010 | Eig=1.02e-05 | Norm=5.15e-01 | Ortho=2.02e-01 | Order=0.00e+00 | EMA slope=5.64e-07


Eigen 8/10:  21%|██        | 1043/5000 [00:03<00:13, 296.90it/s]

Epoch  1000 | λ=1.236750 | Eig=1.13e-05 | Norm=5.65e-01 | Ortho=1.12e-01 | Order=0.00e+00 | EMA slope=1.89e-07


Eigen 8/10:  31%|███       | 1558/5000 [00:05<00:11, 297.47it/s]

Epoch  1500 | λ=1.176942 | Eig=3.69e-06 | Norm=8.28e-01 | Ortho=5.09e-02 | Order=3.52e-03 | EMA slope=6.52e-08


Eigen 8/10:  41%|████      | 2035/5000 [00:06<00:10, 295.85it/s]

Epoch  2000 | λ=1.211298 | Eig=1.33e-06 | Norm=9.39e-01 | Ortho=2.10e-02 | Order=6.25e-04 | EMA slope=2.54e-08


Eigen 8/10:  51%|█████     | 2544/5000 [00:08<00:08, 293.38it/s]

Epoch  2500 | λ=1.268834 | Eig=1.43e-07 | Norm=9.94e-01 | Ortho=2.47e-03 | Order=0.00e+00 | EMA slope=2.01e-09


Eigen 8/10:  61%|██████    | 3056/5000 [00:10<00:06, 295.04it/s]

Epoch  3000 | λ=1.306471 | Eig=1.53e-07 | Norm=9.95e-01 | Ortho=4.01e-03 | Order=0.00e+00 | EMA slope=7.27e-10


Eigen 8/10:  71%|███████   | 3531/5000 [00:12<00:05, 284.44it/s]

Epoch  3500 | λ=1.330823 | Eig=8.89e-08 | Norm=9.97e-01 | Ortho=1.04e-03 | Order=0.00e+00 | EMA slope=3.21e-09


Eigen 8/10:  81%|████████  | 4042/5000 [00:13<00:03, 294.47it/s]

Epoch  4000 | λ=1.346422 | Eig=8.99e-08 | Norm=9.95e-01 | Ortho=2.18e-03 | Order=0.00e+00 | EMA slope=7.68e-10


Eigen 8/10:  91%|█████████ | 4552/5000 [00:15<00:01, 291.43it/s]

Epoch  4500 | λ=1.356502 | Eig=1.34e-07 | Norm=9.95e-01 | Ortho=9.32e-04 | Order=0.00e+00 | EMA slope=1.26e-10


Eigen 8/10: 100%|██████████| 5000/5000 [00:17<00:00, 292.22it/s]



Found eigenvalue: λ_7 = 1.333021

Finding eigenfunction 9/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 9/10:   1%|          | 28/5000 [00:00<00:18, 272.92it/s]

Epoch     0 | λ=1.633021 | Eig=1.41e-04 | Norm=1.98e+00 | Ortho=2.89e+00 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 9/10:  11%|█         | 549/5000 [00:01<00:15, 286.79it/s]

Epoch   500 | λ=1.400569 | Eig=2.69e-05 | Norm=2.18e-01 | Ortho=3.87e-01 | Order=0.00e+00 | EMA slope=2.57e-07


Eigen 9/10:  21%|██        | 1037/5000 [00:03<00:14, 275.34it/s]

Epoch  1000 | λ=1.285214 | Eig=3.99e-06 | Norm=8.72e-01 | Ortho=1.13e-01 | Order=2.38e-03 | EMA slope=4.83e-08


Eigen 9/10:  31%|███       | 1530/5000 [00:05<00:12, 288.05it/s]

Epoch  1500 | λ=1.225106 | Eig=9.65e-06 | Norm=6.98e-01 | Ortho=6.35e-02 | Order=1.19e-02 | EMA slope=1.76e-07


Eigen 9/10:  41%|████      | 2049/5000 [00:07<00:10, 279.49it/s]

Epoch  2000 | λ=1.207513 | Eig=4.76e-06 | Norm=8.07e-01 | Ortho=5.35e-02 | Order=1.60e-02 | EMA slope=1.48e-08


Eigen 9/10:  51%|█████     | 2538/5000 [00:09<00:08, 279.27it/s]

Epoch  2500 | λ=1.220610 | Eig=1.16e-06 | Norm=9.50e-01 | Ortho=2.98e-02 | Order=1.29e-02 | EMA slope=4.08e-09


Eigen 9/10:  61%|██████    | 3058/5000 [00:10<00:06, 281.02it/s]

Epoch  3000 | λ=1.251642 | Eig=1.57e-07 | Norm=9.93e-01 | Ortho=6.40e-03 | Order=6.79e-03 | EMA slope=2.04e-09


Eigen 9/10:  71%|███████   | 3553/5000 [00:12<00:05, 281.35it/s]

Epoch  3500 | λ=1.289276 | Eig=2.43e-07 | Norm=9.89e-01 | Ortho=7.49e-03 | Order=2.00e-03 | EMA slope=3.18e-10


Eigen 9/10:  81%|████████  | 4033/5000 [00:14<00:03, 275.67it/s]

Epoch  4000 | λ=1.323997 | Eig=1.75e-07 | Norm=9.95e-01 | Ortho=3.00e-03 | Order=1.00e-04 | EMA slope=6.54e-10


Eigen 9/10:  91%|█████████ | 4554/5000 [00:16<00:01, 282.34it/s]

Epoch  4500 | λ=1.349166 | Eig=1.70e-07 | Norm=9.92e-01 | Ortho=2.95e-03 | Order=0.00e+00 | EMA slope=2.40e-10


Eigen 9/10: 100%|██████████| 5000/5000 [00:17<00:00, 279.41it/s]



Found eigenvalue: λ_8 = 1.361330

Finding eigenfunction 10/10
Hyperparams: λ_step=0.30, ortho_w=100.0, order_w=100.0


Eigen 10/10:   0%|          | 21/5000 [00:00<00:24, 206.37it/s]

Epoch     0 | λ=1.661330 | Eig=3.96e-04 | Norm=3.90e+01 | Ortho=1.48e+01 | Order=0.00e+00 | EMA slope=1.00e+00


Eigen 10/10:  11%|█         | 537/5000 [00:02<00:16, 268.07it/s]

Epoch   500 | λ=1.466294 | Eig=4.22e-05 | Norm=8.17e-02 | Ortho=6.00e-01 | Order=0.00e+00 | EMA slope=2.36e-07


Eigen 10/10:  21%|██        | 1037/5000 [00:03<00:14, 270.92it/s]

Epoch  1000 | λ=1.324839 | Eig=1.49e-06 | Norm=9.00e-01 | Ortho=1.34e-01 | Order=1.41e-03 | EMA slope=9.59e-09


Eigen 10/10:  31%|███       | 1537/5000 [00:05<00:12, 269.16it/s]

Epoch  1500 | λ=1.270736 | Eig=3.62e-06 | Norm=8.21e-01 | Ortho=5.40e-02 | Order=8.39e-03 | EMA slope=2.57e-08


Eigen 10/10:  41%|████      | 2035/5000 [00:07<00:10, 277.03it/s]

Epoch  2000 | λ=1.324574 | Eig=6.02e-07 | Norm=9.77e-01 | Ortho=1.72e-02 | Order=1.43e-03 | EMA slope=1.24e-08


Eigen 10/10:  51%|█████     | 2531/5000 [00:09<00:09, 271.49it/s]

Epoch  2500 | λ=1.399446 | Eig=3.28e-07 | Norm=9.85e-01 | Ortho=5.69e-03 | Order=0.00e+00 | EMA slope=3.05e-09


Eigen 10/10:  61%|██████    | 3055/5000 [00:11<00:07, 264.87it/s]

Epoch  3000 | λ=1.447892 | Eig=1.26e-07 | Norm=9.95e-01 | Ortho=7.71e-03 | Order=0.00e+00 | EMA slope=6.97e-10


Eigen 10/10:  71%|███████   | 3547/5000 [00:13<00:05, 255.08it/s]

Epoch  3500 | λ=1.478621 | Eig=1.29e-07 | Norm=9.96e-01 | Ortho=2.97e-03 | Order=0.00e+00 | EMA slope=1.66e-09


Eigen 10/10:  81%|████████  | 4041/5000 [00:15<00:03, 269.69it/s]

Epoch  4000 | λ=1.498295 | Eig=2.15e-07 | Norm=9.93e-01 | Ortho=2.58e-03 | Order=0.00e+00 | EMA slope=2.13e-09


Eigen 10/10:  91%|█████████ | 4534/5000 [00:17<00:01, 274.18it/s]

Epoch  4500 | λ=1.511043 | Eig=6.61e-08 | Norm=9.97e-01 | Ortho=1.59e-03 | Order=0.00e+00 | EMA slope=4.15e-10


Eigen 10/10: 100%|██████████| 5000/5000 [00:18<00:00, 266.28it/s]


Found eigenvalue: λ_9 = 1.508162

RESULTS
λ_0 = 0.000000
λ_1 = 0.320646
λ_2 = 0.581464
λ_3 = 0.786167
λ_4 = 0.849548
λ_5 = 0.901922
λ_6 = 1.235297
λ_7 = 1.333021
λ_8 = 1.361330
λ_9 = 1.508162





In [124]:
print(f"The true Eigenvalues are: {np.array2string(eigvals[:10], formatter={'float': lambda x: f'{x:.3f}'})}")
print(f"The pred Eigenvalues are: {np.array2string(np.array(eigenvalues), formatter={'float': lambda x: f'{x:.3f}'})}")

The true Eigenvalues are: [0.000 0.160 0.425 0.438 0.538 0.612 0.896 1.274 1.496 1.643]
The pred Eigenvalues are: [0.000 0.148 0.297 0.394 0.540 0.596 0.743 0.874 0.972 1.116]


In [127]:
print(f"The true Eigenvalues are: {np.array2string(eigvals[:10], formatter={'float': lambda x: f'{x:.3f}'})}")
print(f"The pred Eigenvalues are: {np.array2string(np.array(eigenvalues), formatter={'float': lambda x: f'{x:.3f}'})}")

The true Eigenvalues are: [0.000 0.160 0.425 0.438 0.538 0.612 0.896 1.274 1.496 1.643]
The pred Eigenvalues are: [0.000 0.321 0.581 0.786 0.850 0.902 1.235 1.333 1.361 1.508]


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import scipy.sparse as sparse
from tqdm import trange
import copy

# ============ DEVICE SETUP ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ============ NETWORK ARCHITECTURE ============
# **CHANGE 1: Increased Default Hidden Dimension to 128**
class Sin(nn.Module):
    """Sine activation function"""
    def forward(self, x):
        return torch.sin(x)


class EigenfunctionNN(nn.Module):
    """
    Neural network to learn eigenfunctions on point clouds.
    Input: 3D coordinates (x, y, z)
    Output: eigenfunction value u(x,y,z) and eigenvalue λ
    """
    def __init__(self, hidden_dim=128, input_dim=3, initial_eigenvalue=0.0):
        super().__init__()
        self.activation = Sin()
        
        # Learnable eigenvalue with better initialization
        self.eigenvalue_layer = nn.Linear(1, 1, bias=False)
        with torch.no_grad():
            self.eigenvalue_layer.weight.fill_(initial_eigenvalue)
        
        # Network layers - ONLY concatenate eigenvalue at the first layer
        self.fc1 = nn.Linear(input_dim + 1, hidden_dim)  # +1 for eigenvalue
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        """
        Args:
            x: (N, 3) point cloud coordinates
        Returns:
            u: (N, 1) eigenfunction values
            eigenvalue: scalar learnable eigenvalue
        """
        # Use torch.clamp(..., min=0.0) to enforce non-negativity
        eigenvalue = torch.clamp(self.eigenvalue_layer(torch.ones(1, 1).to(x.device)), min=0.0)
        eigenvalue_expanded = eigenvalue.expand(x.shape[0], 1)  # (N, 1)
        
        # Forward pass - concatenate eigenvalue ONLY at the input
        h = torch.cat([x, eigenvalue_expanded], dim=1)  # (N, input_dim+1)
        h = self.activation(self.fc1(h))
        
        # No eigenvalue concatenation in hidden layers
        h = self.activation(self.fc2(h))
        h = self.activation(self.fc3(h))
        
        u = self.fc4(h)
        
        return u, eigenvalue


# ============ LOSS COMPUTATION ============
def compute_eigenvalue_loss(u, eigenvalue, L_torch, M_torch):
    """
    Compute residual for Lu = λMu using discrete operators.
    Returns:
        loss: MSE of residual ||Lu - λMu||²
    """
    u_flat = u.squeeze()  # (N,)
    
    # Compute Lu and λMu
    Lu = torch.sparse.mm(L_torch, u_flat.unsqueeze(1)).squeeze()
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    residual = Lu - eigenvalue * Mu
    
    return torch.mean(residual ** 2), Lu, Mu


def compute_normalization_loss(u, M_torch):
    """
    Enforce u^T M u = 1 (mass-matrix normalization).
    """
    u_flat = u.squeeze()    
    Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
    norm_squared = torch.dot(u_flat, Mu)

    return (norm_squared - 1.0) ** 2


def compute_orthogonality_loss(u, previous_eigenfunctions, M_torch):
    """
    Enforce u ⊥ u_i for all previously found eigenfunctions.
    Uses M-orthogonality: u^T M u_i = 0
    """
    if not previous_eigenfunctions:
        return torch.tensor(0.0, device=M_torch.device)
    
    u_flat = u.squeeze()
    ortho_loss = torch.tensor(0.0, device=M_torch.device)
    
    
    for u_prev in previous_eigenfunctions:
        u_prev_flat = u_prev.squeeze()
        # Compute u^T M u_prev
        Mu_prev = torch.sparse.mm(M_torch, u_prev_flat.unsqueeze(1)).squeeze()
        overlap = torch.dot(u_flat, Mu_prev)
        ortho_loss += overlap ** 2
    
    return ortho_loss


# **CHANGE 2: Reduced Margin in Ordering Loss**
def compute_ordering_loss(eigenvalue, previous_eigenvalue, margin=1e-4):
    """
    Enforce λ_i > λ_{i-1} + margin.
    """
    if previous_eigenvalue is None:
        return torch.tensor(0.0, device=eigenvalue.device)
    
    # max(0, λ_{i-1} + margin - λ_i)^2
    loss = torch.max(
        torch.tensor(0.0, device=eigenvalue.device), 
        previous_eigenvalue.detach() + margin - eigenvalue
    ) ** 2
    return loss.squeeze()


# ============ UTILITY FUNCTIONS ============
def sparse_to_torch(sparse_matrix, device):
    """Convert scipy sparse matrix to torch sparse tensor."""
    # Handle any scipy sparse format
    if sparse.issparse(sparse_matrix):
        coo = sparse_matrix.tocoo()
    else:
        raise ValueError(f"Expected scipy sparse matrix, got {type(sparse_matrix)}")
    
    indices = torch.LongTensor(np.vstack((coo.row, coo.col)))
    values = torch.FloatTensor(coo.data)
    shape = coo.shape
    return torch.sparse_coo_tensor(indices, values, shape).to(device)


def initialize_weights(m):
    """Reinitialize network weights for finding next eigenfunction."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


# ============ MAIN TRAINING FUNCTION ============
def train_eigenvalue_pinn(X, L, M, hidden_dim=128, epochs=20000, 
                          lr=1e-3, num_eigenfunctions=10, 
                          convergence_threshold=1e-7,
                          ortho_weight=100.0,  
                          order_weight=10.0,   # **CHANGE 3: Reduced Order Weight from 100.0 to 10.0**
                          initial_lambda_step=0.2): # **CHANGE 4: Reduced Initial Lambda Step from 0.3 to 0.2**
    """
    Train PINN to solve Lu = λMu eigenvalue problem.
    """
    
    # Prepare inputs
    X_torch = torch.as_tensor(X, dtype=torch.float32, device=device)
    X_torch.requires_grad = True

    # Pre-convert sparse matrices
    L_torch = sparse_to_torch(L, device) if sparse.issparse(L) else L.to(device)
    M_torch = sparse_to_torch(M, device) if sparse.issparse(M) else M.to(device)
    
    # Storage for results
    eigenvalues = []
    eigenfunctions = []
    all_models = []
    loss_history = {'total': [], 'eig': [], 'norm': [], 'ortho': [], 'order': []}
    
    print(f"Training on device: {device}")
    print(f"Point cloud size: {X.shape[0]} points")
    
    # ============ ITERATIVE EIGENFUNCTION DISCOVERY ============
    for eig_idx in range(num_eigenfunctions):
        print(f"\n{'='*60}")
        print(f"Finding eigenfunction {eig_idx + 1}/{num_eigenfunctions}")
        print(f"Hyperparams: λ_step={initial_lambda_step:.2f}, hidden_dim={hidden_dim}, ortho_w={ortho_weight:.1f}, order_w={order_weight:.1f}")
        print(f"{'='*60}")
        
        # Determine previous eigenvalue and initial guess
        previous_eigenvalue = None
        if eig_idx == 0:
            initial_eigenvalue = 0.0
        else:
            previous_eigenvalue = torch.tensor(eigenvalues[-1], dtype=torch.float32, device=device)
            # Use the specified step size
            initial_eigenvalue = eigenvalues[-1] + initial_lambda_step 
            
        
        # Initialize network
        model = EigenfunctionNN(hidden_dim=hidden_dim, input_dim=X.shape[1], 
                               initial_eigenvalue=initial_eigenvalue).to(device)
        
        # Reinitialize weights and biases
        model.apply(initialize_weights) 
        
        # Re-set initial eigenvalue guess, as initialize_weights overwrites it
        if previous_eigenvalue is not None:
             with torch.no_grad():
                model.eigenvalue_layer.weight.fill_(initial_eigenvalue)


        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.999, 0.9999))
        
        # --- Convergence Tracking Setup ---
        best_model = None
        best_eig_loss = float('inf')
        ema_slope = 1.0
        prev_loss = None
        
        # --- Training Loop ---
        for epoch in trange(epochs, desc=f"Eigen {eig_idx+1}/{num_eigenfunctions}"):
            optimizer.zero_grad()
            
            # Forward pass
            u, eigenvalue = model(X_torch)
            
            # Compute losses
            eig_loss, _, _ = compute_eigenvalue_loss(u, eigenvalue, L_torch, M_torch)
            norm_loss = compute_normalization_loss(u, M_torch)
            ortho_loss = compute_orthogonality_loss(u, eigenfunctions, M_torch)
            order_loss = compute_ordering_loss(eigenvalue, previous_eigenvalue)
            
            # Total loss with weighting
            total_loss = eig_loss + norm_loss + \
                         ortho_weight * ortho_loss + \
                         order_weight * order_loss
            
            # Backward pass
            total_loss.backward()
            optimizer.step()
            
            # Convergence tracking (Use Eig loss for convergence)
            if prev_loss is not None:
                ema_slope = 0.75 * ema_slope + 0.25 * abs(prev_loss - eig_loss.item())
            prev_loss = eig_loss.item()

            # **CHANGE 5: Start convergence check later to ensure stability**
            if ema_slope < convergence_threshold * 10.0 and epoch > 8000: 
                print(f"Converged at epoch {epoch}")
                break

            # Save best model based on Eig loss
            if eig_loss.item() < best_eig_loss:
                best_eig_loss = eig_loss.item()
                best_model = copy.deepcopy(model)

            # Logging every 500 epochs
            if epoch % 500 == 0:
                print(f"Epoch {epoch:5d} | λ={eigenvalue.item():.6f} | "
                      f"Eig={eig_loss.item():.2e} | Norm={norm_loss.item():.2e} | "
                      f"Ortho={ortho_loss.item():.2e} | Order={order_loss.item():.2e} | "
                      f"EMA slope={ema_slope:.2e}")

            # Lightweight history
            if epoch % 100 == 0:
                loss_history['total'].append(total_loss.item())
                loss_history['eig'].append(eig_loss.item())
                loss_history['norm'].append(norm_loss.item())
                loss_history['ortho'].append(ortho_loss.item())
                loss_history['order'].append(order_loss.item())

        # Store results
        if best_model is None:
             print("Warning: No best model found (first epoch was best, using final model).")
             best_model = model
             
        with torch.no_grad():
            u_final, eigenvalue_final = best_model(X_torch)
            eigenvalues.append(eigenvalue_final.item())
            
            # Re-normalize the eigenfunction after training (u^T M u = 1)
            u_flat = u_final.squeeze()
            Mu = torch.sparse.mm(M_torch, u_flat.unsqueeze(1)).squeeze()
            norm_squared = torch.dot(u_flat, Mu)
            u_normalized = u_final / torch.sqrt(norm_squared)

            eigenfunctions.append(u_normalized.detach())
            all_models.append(best_model)

        print(f"\nFound eigenvalue: λ_{eig_idx} = {eigenvalue_final.item():.6f}")
    
    return eigenvalues, eigenfunctions, all_models, loss_history