In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh
import matplotlib.pyplot as plt
from pathlib import Path

# ============================================================================
# CONFIGURATION
# ============================================================================

torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training configuration
CONFIG = {
    'k': 50,  # Number of eigenmodes
    'max_epochs': 150_000,  # Reduced from 300k due to better architecture
    'stage1_epochs': 50_000,  # Focus on low modes + orthogonality
    'stage2_epochs': 100_000,  # Fine-tune all modes
    'lr_stage1': 0.01,
    'lr_stage2': 0.001,
    'lr_min': 0.0001,
    'print_every': 2000,
    'checkpoint_every': 5000,
    'grad_clip': 1.0,
    'accumulation_steps': 1,  # Set to 4 if you have memory issues
    'early_stopping_patience': 50_000,
    'early_stopping_threshold': 0.999,  # 0.1% improvement needed
}

# ============================================================================
# DATA LOADING AND PREPROCESSING
# ============================================================================

print("Loading mesh...")
m = Mesh('data/coil_1.2_MM.obj')

# Normalize vertices
centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()
verts_new = (m.verts - centroid) / std_max
m = Mesh(verts=verts_new, connectivity=m.connectivity)

print('Computing Laplacian...')
K, M = m.computeLaplacian()

print('Computing eigenvalues (reference)...')
eigvals, eigvecs = linalg.eigh(K, M)

# Convert to torch tensors
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)
eigvals_torch = torch.from_numpy(eigvals[:CONFIG['k']]).to(device)
eigvecs_torch = torch.from_numpy(eigvecs[:, :CONFIG['k']]).to(device)

k = CONFIG['k']
N = X.shape[0]

# ============================================================================
# MATRIX CONDITIONING
# ============================================================================

print("\n=== Matrix Conditioning ===")
epsilon = 1e-4
K_reg = K + epsilon * torch.eye(N, device=device)
K_scale = torch.norm(K_reg, p='fro')

# Normalize both matrices by same scale
K = K_reg / K_scale
M = M / K_scale

print(f"Regularization: ε={epsilon}")
print(f"Condition number: {torch.linalg.cond(K).item():.2e}")
print(f"Scaling factor: {K_scale:.2e}")

# ============================================================================
# MODEL DEFINITION
# ============================================================================

class ImprovedMLP(nn.Module):
    """
    Enhanced MLP with:
    - Fourier feature encoding
    - Layer normalization
    - Residual connections
    - Increased capacity
    """
    def __init__(self, in_dim=3, out_dim=50, hidden=[256, 256, 128, 128], 
                 use_fourier=True, n_fourier_features=3):
        super().__init__()
        
        self.use_fourier = use_fourier
        self.n_fourier = n_fourier_features
        
        # Learnable frequency scales for Fourier features
        if use_fourier:
            self.freq_scale = nn.Parameter(torch.ones(in_dim) * 10.0)
            input_dim = in_dim * (1 + 2 * n_fourier_features)  # original + sin + cos for each freq
        else:
            input_dim = in_dim
        
        # Build network with normalization and residual capability
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        
        last_dim = input_dim
        for h in hidden:
            self.layers.append(nn.Linear(last_dim, h))
            self.norms.append(nn.LayerNorm(h))
            last_dim = h
        
        # Output layer
        self.output = nn.Linear(last_dim, out_dim)
        
    def forward(self, x):
        # Fourier feature mapping
        if self.use_fourier:
            features = [x]
            for i in range(self.n_fourier):
                freq = self.freq_scale * (2 ** i)  # Multiple frequency scales
                features.append(torch.sin(freq * x))
                features.append(torch.cos(freq * x))
            h = torch.cat(features, dim=-1)
        else:
            h = x
        
        # Forward pass with residual connections
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            h_new = layer(h)
            h_new = norm(h_new)
            h_new = torch.nn.functional.silu(h_new)
            
            # Residual connection (if dimensions match)
            if i > 0 and h.shape[-1] == h_new.shape[-1]:
                h = h_new + 0.1 * h  # Scaled residual
            else:
                h = h_new
        
        return self.output(h)

# ============================================================================
# ORTHOGONALIZATION UTILITIES
# ============================================================================

def gram_schmidt_batch(U, M, num_steps=1):
    """
    M-orthogonalize columns of U using Gram-Schmidt process
    
    Args:
        U: (N, k) matrix to orthogonalize
        M: (N, N) mass matrix
        num_steps: number of GS iterations (>1 for stability)
    """
    Q = U.clone()
    
    for step in range(num_steps):
        for i in range(Q.shape[1]):
            # Orthogonalize against all previous vectors
            for j in range(i):
                # M-inner product
                numerator = Q[:, j] @ (M @ Q[:, i])
                denominator = Q[:, j] @ (M @ Q[:, j]) + 1e-10
                proj = numerator / denominator
                Q[:, i] = Q[:, i] - proj * Q[:, j]
            
            # M-normalize
            norm = torch.sqrt(Q[:, i] @ (M @ Q[:, i]) + 1e-10)
            Q[:, i] = Q[:, i] / norm
    
    return Q

# ============================================================================
# LOSS COMPUTATION
# ============================================================================

def compute_loss(U, K, M, eigvals_ref, epoch, config):
    """
    Comprehensive loss function with multiple objectives
    UNSUPERVISED: eigvals_ref only used for logging/comparison, NOT for training!
    """
    # Compute Rayleigh quotients
    UMU = U.T @ (M @ U)
    UKU = U.T @ (K @ U)
    
    # Extract diagonal (eigenvalue estimates)
    eigenvalues_approx = torch.diag(UKU)
    sorted_eigs, sort_idx = torch.sort(eigenvalues_approx)
    
    # ========================================================================
    # 1. EIGENVALUE LOSSES (UNSUPERVISED)
    # ========================================================================
    
    # 1a. First eigenvalue should be ZERO (rigid body mode)
    #     This is a physical constraint for Laplacian operators
    zero_eig_loss = sorted_eigs[0] ** 2
    
    # 1b. Trace loss (minimize sum - finds smallest modes)
    #     But exclude first mode since we're already constraining it to zero
    trace_loss = torch.sum(sorted_eigs[1:])
    
    # 1c. Diversity loss (encourage eigenvalue separation)
    # Use adaptive gaps based on current eigenvalue scale
    gaps = sorted_eigs[1:] - sorted_eigs[:-1]
    # Target minimum gap: 1% of current eigenvalue spread
    current_spread = sorted_eigs[-1] - sorted_eigs[1] + 1e-8  # Exclude first eigenvalue
    min_gap = current_spread * 0.01 / k  # Adaptive to current scale
    diversity_loss = torch.sum(torch.relu(min_gap - gaps))
    
    # 1d. Off-diagonal penalty (enforce diagonalization)
    off_diag_mask = 1 - torch.eye(k, device=device, dtype=torch.float64)
    off_diag_loss = torch.sum((UKU * off_diag_mask) ** 2)
    
    # ========================================================================
    # 2. ORTHOGONALITY LOSS
    # ========================================================================
    
    identity_k = torch.eye(k, device=device, dtype=torch.float64)
    orth_loss = torch.norm(UMU - identity_k, p='fro') ** 2
    
    # ========================================================================
    # 3. ORDERING LOSS (maintain λ₁ ≤ λ₂ ≤ ... ≤ λₖ)
    # ========================================================================
    
    ordering_loss = torch.sum(torch.relu(sorted_eigs[:-1] - sorted_eigs[1:])) / k
    
    # ========================================================================
    # 4. SMOOTHNESS REGULARIZATION (optional, for later stages)
    # ========================================================================
    
    if epoch > config['stage1_epochs']:
        U_sorted = U[:, sort_idx]
        # Penalize large jumps between consecutive modes
        smoothness_loss = torch.mean(torch.sum((U_sorted[:, 1:] - U_sorted[:, :-1]) ** 2, dim=0))
    else:
        smoothness_loss = torch.tensor(0.0, device=device)
    
    # ========================================================================
    # ADAPTIVE WEIGHTING (FULLY UNSUPERVISED)
    # ========================================================================
    
    # Stage 1: Focus on orthogonality and finding low modes
    if epoch <= config['stage1_epochs']:
        w_zero = 100.0       # CRITICAL: Force first eigenvalue to zero
        w_trace = 5.0        # Strong push to find smallest eigenvalues
        w_div = 2.0          # Encourage separation
        w_offdiag = 10.0     # Strong diagonalization
        w_orth = 10.0        # Strong orthogonality
        w_order = 0.5        # Maintain ordering
        w_smooth = 0.0
    # Stage 2: Fine-tune, relax some constraints
    else:
        w_zero = 50.0        # Still enforce, but can relax a bit
        w_trace = 2.0        # Still minimize, but less aggressive
        w_div = 1.0          # Maintain separation
        w_offdiag = 15.0     # Even stronger diagonalization
        w_orth = 5.0         # Maintain orthogonality
        w_order = 0.2        # Maintain ordering
        w_smooth = 0.1       # Add smoothness
    
    # Total loss (NO SUPERVISION)
    total_loss = (w_zero * zero_eig_loss +
                  w_trace * trace_loss +
                  w_div * diversity_loss +
                  w_offdiag * off_diag_loss +
                  w_orth * orth_loss +
                  w_order * ordering_loss +
                  w_smooth * smoothness_loss)
    
    # Return loss and components for logging
    components = {
        'zero_eig': zero_eig_loss.item(),
        'trace': trace_loss.item(),
        'diversity': diversity_loss.item(),
        'offdiag': off_diag_loss.item(),
        'orth': orth_loss.item(),
        'order': ordering_loss.item(),
        'smooth': smoothness_loss.item(),
    }
    
    return total_loss, components

# ============================================================================
# MODEL INITIALIZATION
# ============================================================================

print("\n=== Initializing Model ===")
model = ImprovedMLP(in_dim=3, out_dim=k, use_fourier=True).double().to(device)

# Xavier initialization for hidden layers, small weights for output
for name, p in model.named_parameters():
    if 'weight' in name:
        if 'output' in name:
            nn.init.normal_(p.data, std=1e-4)
        elif 'freq_scale' not in name and p.ndim >= 2:  # Only for 2D+ tensors
            nn.init.xavier_uniform_(p.data)
    elif 'bias' in name:
        nn.init.zeros_(p.data)

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Fourier features: {model.use_fourier}")
print(f"Architecture: {[l.out_features for l in model.layers]} -> {k}")

# ============================================================================
# NO PRE-TRAINING - Pure unsupervised learning!
# ============================================================================

print("\n=== Unsupervised Mode: No pre-training ===")
print("Model will discover eigenmodes from scratch using only K and M matrices\n")

# ============================================================================
# TRAINING SETUP
# ============================================================================

# Stage 1: High learning rate, focus on orthogonality
optimizer = optim.AdamW(model.parameters(), 
                        lr=CONFIG['lr_stage1'], 
                        weight_decay=1e-5)

# Cosine annealing with restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10000, T_mult=2, eta_min=CONFIG['lr_min']
)

# Tracking
loss_history = []
best_loss = float('inf')
no_improve_count = 0
checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

# ============================================================================
# MAIN TRAINING LOOP
# ============================================================================

print("=" * 80)
print("TRAINING START")
print("=" * 80)
print(f"Epochs: {CONFIG['max_epochs']:,} | Device: {device}")
print(f"Stage 1: 0-{CONFIG['stage1_epochs']:,} (orthogonality focus)")
print(f"Stage 2: {CONFIG['stage1_epochs']:,}-{CONFIG['max_epochs']:,} (eigenvalue tuning)")
print("=" * 80 + "\n")

for epoch in range(1, CONFIG['max_epochs'] + 1):
    
    # ========================================================================
    # STAGE TRANSITION
    # ========================================================================
    
    if epoch == CONFIG['stage1_epochs'] + 1:
        print("\n" + "=" * 80)
        print("ENTERING STAGE 2: Fine-tuning eigenvalues")
        print("=" * 80 + "\n")
        
        # Reduce learning rate for stage 2
        for param_group in optimizer.param_groups:
            param_group['lr'] = CONFIG['lr_stage2']
        
        # Reset scheduler for stage 2
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=CONFIG['stage2_epochs'],
            eta_min=CONFIG['lr_min']
        )
    
    # ========================================================================
    # TRAINING STEP
    # ========================================================================
    
    model.train()
    
    # Forward pass
    U = model(X)
    
    # Periodic explicit orthogonalization (every 1000 steps in stage 1)
    if epoch <= CONFIG['stage1_epochs'] and epoch % 1000 == 0:
        with torch.no_grad():
            U = gram_schmidt_batch(U, M, num_steps=2)
        
        # Quick fine-tune to match orthogonalized output
        for _ in range(5):
            optimizer.zero_grad()
            U_pred = model(X)
            loss_proj = torch.mean((U_pred - U) ** 2)
            loss_proj.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            optimizer.step()
    
    # Compute loss
    optimizer.zero_grad()
    U = model(X)
    loss, loss_components = compute_loss(U, K, M, eigvals_torch, epoch, CONFIG)
    
    # Backward pass
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
    
    # Optimizer step (with optional gradient accumulation)
    if epoch % CONFIG['accumulation_steps'] == 0:
        optimizer.step()
        scheduler.step()
    
    loss_history.append(loss.item())
    
    # ========================================================================
    # LOGGING
    # ========================================================================
    
    if epoch % CONFIG['print_every'] == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            U = model(X)
            UKU = U.T @ (K @ U)
            UMU = U.T @ (M @ U)
            
            approx_eigs = torch.diag(UKU).cpu().numpy()
            approx_eigs.sort()
            
            # Compute errors
            abs_error = np.abs(approx_eigs - eigvals[:k])
            rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)
            mean_rel_error = np.mean(rel_error)
            median_rel_error = np.median(rel_error)
            
            # Orthogonality check
            orth_residual = torch.norm(UMU - torch.eye(k, device=device, dtype=torch.float64), p='fro').item()
            
            current_lr = optimizer.param_groups[0]['lr']
            stage = "Stage1" if epoch <= CONFIG['stage1_epochs'] else "Stage2"
            
        print(f"[{stage}] Epoch {epoch:>6} | LR={current_lr:.6f}")
        print(f"  Loss={loss.item():.6f} | λ₁={loss_components['zero_eig']:.2e} | Trace={loss_components['trace']:.3f}")
        print(f"  Orth={orth_residual:.2e} | OffDiag={loss_components['offdiag']:.2e}")
        print(f"  MeanRelErr={mean_rel_error:.4%} | MedianRelErr={median_rel_error:.4%}")
        print(f"  λ∈[{approx_eigs[0]:.6f}, {approx_eigs[-1]:.6f}] | Spread={approx_eigs[-1]-approx_eigs[0]:.6f}")
        
        # Detailed comparison every 10k epochs
        if epoch % (CONFIG['print_every'] * 5) == 0:
            print(f"  First 5 - Pred: {approx_eigs[:5].round(6)}")
            print(f"  First 5 - True: {eigvals[:5].round(6)}")
            print(f"  Last  5 - Pred: {approx_eigs[-5:].round(6)}")
            print(f"  Last  5 - True: {eigvals[k-5:k].round(6)}")
        print()
    
    # ========================================================================
    # CHECKPOINTING & EARLY STOPPING
    # ========================================================================
    
    if epoch % CONFIG['checkpoint_every'] == 0:
        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'loss_history': loss_history,
        }
        torch.save(checkpoint, checkpoint_dir / f'checkpoint_epoch_{epoch}.pt')
        
        # Check for improvement
        if loss.item() < best_loss * CONFIG['early_stopping_threshold']:
            best_loss = loss.item()
            no_improve_count = 0
            torch.save(checkpoint, checkpoint_dir / 'best_model.pt')
            print(f"  ✓ New best model saved (loss={best_loss:.6f})")
        else:
            no_improve_count += CONFIG['checkpoint_every']
        
        # Early stopping
        if no_improve_count >= CONFIG['early_stopping_patience']:
            print(f"\n⚠ Early stopping at epoch {epoch} (no improvement for {no_improve_count} epochs)")
            break

print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80 + "\n")

# ============================================================================
# FINAL EVALUATION
# ============================================================================

print("Loading best model for evaluation...")
best_checkpoint = torch.load(checkpoint_dir / 'best_model.pt')
model.load_state_dict(best_checkpoint['model_state_dict'])

model.eval()
with torch.no_grad():
    U = model(X)
    
    # Final matrices
    final_UKU = U.T @ (K @ U)
    final_UMU = U.T @ (M @ U)
    
    # Eigenvalues
    final_eigs = torch.diag(final_UKU).cpu().numpy()
    final_eigs.sort()
    
    # Errors
    abs_error = np.abs(final_eigs - eigvals[:k])
    rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)
    
    print("=" * 80)
    print("FINAL RESULTS")
    print("=" * 80)
    
    # Orthogonality
    identity_k = torch.eye(k, device=device, dtype=torch.float64)
    orth_residual = torch.norm(final_UMU - identity_k, p='fro').item()
    orth_diag = torch.diag(final_UMU).cpu().numpy()
    
    print(f"\n Orthogonality Quality:")
    print(f"  ||U^T M U - I||_F = {orth_residual:.2e}")
    print(f"  Diagonal range: [{orth_diag.min():.6f}, {orth_diag.max():.6f}] (target: 1.0)")
    
    # Rayleigh matrix
    rayleigh_diag = torch.diag(final_UKU).cpu().numpy()
    rayleigh_offdiag = (final_UKU - torch.diag(torch.diag(final_UKU))).cpu().numpy()
    
    print(f"\n Rayleigh Quotient Matrix:")
    print(f"  Diagonal norm: {np.linalg.norm(rayleigh_diag):.6f}")
    print(f"  Off-diagonal norm: {np.linalg.norm(rayleigh_offdiag, 'fro'):.2e} (should be ≈0)")
    
    # Eigenvalue comparison
    print(f"\n Eigenvalue Comparison (First 10):")
    print(f"{'Mode':<6} {'Predicted':<14} {'Reference':<14} {'Abs Error':<14} {'Rel Error':<12}")
    print("-" * 70)
    for i in range(min(10, k)):
        print(f"{i+1:<6} {final_eigs[i]:<14.8f} {eigvals[i]:<14.8f} "
              f"{abs_error[i]:<14.8f} {rel_error[i]:<12.4%}")
    
    print(f"\n Eigenvalue Comparison (Last 10):")
    print(f"{'Mode':<6} {'Predicted':<14} {'Reference':<14} {'Abs Error':<14} {'Rel Error':<12}")
    print("-" * 70)
    for i in range(max(0, k-10), k):
        print(f"{i+1:<6} {final_eigs[i]:<14.8f} {eigvals[i]:<14.8f} "
              f"{abs_error[i]:<14.8f} {rel_error[i]:<12.4%}")
    
    # Statistics
    print(f"\n Overall Statistics ({k} modes):")
    print(f"  Mean Absolute Error:     {np.mean(abs_error):.8f}")
    print(f"  Mean Relative Error:     {np.mean(rel_error):.4%}")
    print(f"  Median Relative Error:   {np.median(rel_error):.4%}")
    print(f"  Max Relative Error:      {np.max(rel_error):.4%}")
    print(f"  Modes with <1% error:    {np.sum(rel_error < 0.01)}/{k}")
    print(f"  Modes with <5% error:    {np.sum(rel_error < 0.05)}/{k}")
    print(f"  Modes with <10% error:   {np.sum(rel_error < 0.10)}/{k}")
    
    print("\n" + "=" * 80)

# Save final model
torch.save({
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'final_eigenvalues': final_eigs,
    'reference_eigenvalues': eigvals[:k],
    'errors': {'absolute': abs_error, 'relative': rel_error},
}, 'final_model.pt')

print("\n✓ Model saved to 'final_model.pt'")
print("✓ Checkpoints saved to 'checkpoints/' directory")

Loading mesh...
Computing Laplacian...
Computing eigenvalues (reference)...

=== Matrix Conditioning ===
Regularization: ε=0.0001
Condition number: 1.47e+05
Scaling factor: 1.58e+02

=== Initializing Model ===
Parameters: 128,821
Fourier features: True
Architecture: [256, 256, 128, 128] -> 50

=== Unsupervised Mode: No pre-training ===
Model will discover eigenmodes from scratch using only K and M matrices

TRAINING START
Epochs: 150,000 | Device: cuda
Stage 1: 0-50,000 (orthogonality focus)
Stage 2: 50,000-150,000 (eigenvalue tuning)

[Stage1] Epoch      1 | LR=0.010000
  Loss=500.001621 | λ₁=1.55e-11 | Trace=0.000
  Orth=7.01e+00 | OffDiag=4.61e-09
  MeanRelErr=31077915.3729% | MedianRelErr=99.5882%
  λ∈[0.001559, 0.036197] | Spread=0.034638

[Stage1] Epoch   2000 | LR=0.009055
  Loss=377.213065 | λ₁=1.75e-06 | Trace=11.711
  Orth=5.63e+00 | OffDiag=8.37e-02
  MeanRelErr=26164066.2776% | MedianRelErr=99.3287%
  λ∈[0.001313, 0.949928] | Spread=0.948616

[Stage1] Epoch   4000 | LR=0.00