In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import linalg
from Mesh import Mesh
import matplotlib.pyplot as plt
from pathlib import Path

# ============================================================================
# CONFIGURATION
# ============================================================================

torch.set_default_dtype(torch.double)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CONFIG = {
    'k': 50,
    'max_epochs': 150_000,
    'stage1_epochs': 50_000,
    'lr_stage1': 0.01,
    'lr_stage2': 0.001,
    'lr_min': 0.0001,
    'print_every': 2000,
    'checkpoint_every': 5000,
    'grad_clip': 1.0,
    'early_stopping_patience': 30_000,
    'early_stopping_threshold': 0.999,
}

# ============================================================================
# DATA LOADING AND PREPROCESSING
# ============================================================================

print("Loading mesh...")
m = Mesh('data/coil_1.2_MM.obj')

# Normalize vertices
centroid = m.verts.mean(0)
std_max = m.verts.std(0).max()
verts_new = (m.verts - centroid) / std_max
m = Mesh(verts=verts_new, connectivity=m.connectivity)

print('Computing Laplacian...')
K, M = m.computeLaplacian()

print('Computing reference eigenvalues...')
eigvals, eigvecs = linalg.eigh(K, M)

# Convert to torch
K = torch.from_numpy(K).to(device)
M = torch.from_numpy(M).to(device)
X = torch.from_numpy(m.verts).to(device)

k = CONFIG['k']
N = X.shape[0]

# ============================================================================
# MATRIX CONDITIONING (Your careful approach)
# ============================================================================

print("\n=== Matrix Diagnostics ===")
print(f"N = {N}, k = {k}, ratio = {N/k:.1f}")
print(f"Condition number of K: {torch.linalg.cond(K).item():.2e}")
print(f"Condition number of M: {torch.linalg.cond(M).item():.2e}")
print(f"Target eigenvalue range: [{eigvals[0]:.6f}, {eigvals[k-1]:.6f}]")

if eigvals[k-1] / (eigvals[1] + 1e-10) > 100:
    print("⚠ Large eigenvalue spread detected")

# Regularization
epsilon = 1e-4
K_reg = K + epsilon * torch.eye(N, device=device)
print(f"\nRegularization: ε={epsilon}")
print(f"Condition number after reg: {torch.linalg.cond(K_reg).item():.2e}")

# Separate normalization (your approach - more careful)
K_scale = torch.norm(K_reg, p='fro')
M_scale = torch.norm(M, p='fro')

K = K_reg / K_scale
M = M / M_scale

print(f"\nNormalization:")
print(f"  K_scale = {K_scale.item():.2e}")
print(f"  M_scale = {M_scale.item():.2e}")
print(f"  ||K||_F = {torch.norm(K, p='fro').item():.4f}")
print(f"  ||M||_F = {torch.norm(M, p='fro').item():.4f}")

# Sanity checks
print(f"\nSanity checks:")
print(f"  K symmetric: {torch.allclose(K, K.T, atol=1e-6)}")
print(f"  M positive definite: {torch.all(torch.linalg.eigvalsh(M) > 0)}")
print(f"  No NaN/Inf: {not (torch.isnan(K).any() or torch.isinf(K).any())}")

print("\n" + "="*80)

# ============================================================================
# MODEL DEFINITION (My improved architecture)
# ============================================================================

class HybridMLP(nn.Module):
    """
    Combines:
    - Fourier features for high-frequency modes
    - LayerNorm for training stability
    - Residual connections for gradient flow
    - Wider architecture for capacity
    """
    def __init__(self, in_dim=3, out_dim=50, hidden=[256, 256, 128, 128],
                 use_fourier=True, n_fourier_features=3):
        super().__init__()
        
        self.use_fourier = use_fourier
        self.n_fourier = n_fourier_features
        
        # Learnable Fourier frequencies
        if use_fourier:
            self.freq_scale = nn.Parameter(torch.ones(in_dim) * 10.0)
            input_dim = in_dim * (1 + 2 * n_fourier_features)
        else:
            input_dim = in_dim
        
        # Network layers
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        
        last_dim = input_dim
        for h in hidden:
            self.layers.append(nn.Linear(last_dim, h))
            self.norms.append(nn.LayerNorm(h))
            last_dim = h
        
        self.output = nn.Linear(last_dim, out_dim)
        
    def forward(self, x):
        # Fourier feature mapping
        if self.use_fourier:
            features = [x]
            for i in range(self.n_fourier):
                freq = self.freq_scale * (2 ** i)
                features.append(torch.sin(freq * x))
                features.append(torch.cos(freq * x))
            h = torch.cat(features, dim=-1)
        else:
            h = x
        
        # Forward with residuals and normalization
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            h_new = layer(h)
            h_new = norm(h_new)
            h_new = torch.nn.functional.silu(h_new)
            
            # Residual connection (scaled)
            if i > 0 and h.shape[-1] == h_new.shape[-1]:
                h = h_new + 0.1 * h
            else:
                h = h_new
        
        return self.output(h)

# ============================================================================
# MODEL INITIALIZATION
# ============================================================================

print("=== Initializing Model ===")
model = HybridMLP(in_dim=3, out_dim=k, use_fourier=True).double().to(device)

# Careful initialization
for name, p in model.named_parameters():
    if 'weight' in name:
        if 'output' in name:
            nn.init.normal_(p.data, std=1e-4)
        elif 'freq_scale' not in name and p.ndim >= 2:
            nn.init.xavier_uniform_(p.data)
    elif 'bias' in name:
        nn.init.zeros_(p.data)

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Architecture: {[l.out_features for l in model.layers]} -> {k}")
print(f"Fourier features: Enabled ({model.n_fourier} frequencies)")
print()

# ============================================================================
# LOSS COMPUTATION (Hybrid approach)
# ============================================================================

def compute_loss(U, U_orth, K, M, rayleigh_matrix, B, S, eigvals_ref, epoch, config):
    """
    Hybrid loss combining:
    - Your SVD stability monitoring
    - My zero eigenvalue constraint
    - Adaptive weighting for two-stage training
    """
    k = U.shape[1]
    identity_k = torch.eye(k, device=device, dtype=torch.float64)
    
    # Extract eigenvalues
    eigenvalues_approx = torch.diag(rayleigh_matrix)
    sorted_eigs, sort_idx = torch.sort(eigenvalues_approx)
    
    # ========================================================================
    # 1. ZERO EIGENVALUE CONSTRAINT (Critical for first mode)
    # ========================================================================
    zero_eig_loss = sorted_eigs[0] ** 2
    
    # ========================================================================
    # 2. TRACE LOSS (Minimize sum, excluding first eigenvalue)
    # ========================================================================
    trace_loss = torch.sum(sorted_eigs[1:]) / (k - 1)
    
    # ========================================================================
    # 3. DIVERSITY LOSS (Adaptive gap)
    # ========================================================================
    gaps = sorted_eigs[1:] - sorted_eigs[:-1]
    current_spread = sorted_eigs[-1] - sorted_eigs[1] + 1e-8
    min_gap = current_spread * 0.005 / k  # 0.5% of spread
    diversity_loss = torch.sum(torch.relu(min_gap - gaps)) / (k - 1)
    
    # ========================================================================
    # 4. OFF-DIAGONAL PENALTY (Enforce diagonalization)
    # ========================================================================
    off_diag_mask = 1 - identity_k
    off_diag_loss = torch.sum((rayleigh_matrix * off_diag_mask) ** 2) / (k * (k-1))
    
    # ========================================================================
    # 5. ORTHOGONALITY RESIDUAL (Should be small due to SVD)
    # ========================================================================
    B_orth = U_orth.T @ (M @ U_orth)
    orth_loss = torch.norm(B_orth - identity_k, p='fro') ** 2
    
    # ========================================================================
    # 6. ORDERING LOSS (Maintain sorted eigenvalues)
    # ========================================================================
    ordering_loss = torch.sum(torch.relu(sorted_eigs[:-1] - sorted_eigs[1:])) / k
    
    # ========================================================================
    # 7. STABILITY LOSS (Your SVD monitoring)
    # ========================================================================
    S_ratio = S.max() / (S.min() + 1e-10)
    stability_loss = torch.relu(S_ratio - 1e3) / 1e3
    
    # ========================================================================
    # 8. SMOOTHNESS REGULARIZATION (Stage 2 only)
    # ========================================================================
    if epoch > config['stage1_epochs']:
        U_sorted = U_orth[:, sort_idx]
        smoothness_loss = torch.mean(torch.sum((U_sorted[:, 1:] - U_sorted[:, :-1]) ** 2, dim=0))
    else:
        smoothness_loss = torch.tensor(0.0, device=device)
    
    # ========================================================================
    # ADAPTIVE WEIGHTING (Two-stage)
    # ========================================================================
    if epoch <= config['stage1_epochs']:
        # Stage 1: Establish orthogonality, find low modes, force λ₁=0
        w_zero = 100.0
        w_trace = 5.0
        w_div = 2.0
        w_offdiag = 10.0
        w_orth = 5.0
        w_order = 0.5
        w_stability = 0.1
        w_smooth = 0.0
    else:
        # Stage 2: Fine-tune all eigenvalues, maintain constraints
        w_zero = 50.0
        w_trace = 2.0
        w_div = 1.0
        w_offdiag = 15.0
        w_orth = 2.0
        w_order = 0.2
        w_stability = 0.1
        w_smooth = 0.1
    
    # Total loss
    total_loss = (w_zero * zero_eig_loss +
                  w_trace * trace_loss +
                  w_div * diversity_loss +
                  w_offdiag * off_diag_loss +
                  w_orth * orth_loss +
                  w_order * ordering_loss +
                  w_stability * stability_loss +
                  w_smooth * smoothness_loss)
    
    # Components for logging
    components = {
        'zero_eig': zero_eig_loss.item(),
        'trace': trace_loss.item(),
        'diversity': diversity_loss.item(),
        'offdiag': off_diag_loss.item(),
        'orth': orth_loss.item(),
        'order': ordering_loss.item(),
        'stability': stability_loss.item(),
        'smooth': smoothness_loss.item(),
        'svd_cond': S_ratio.item(),
    }
    
    return total_loss, components

# ============================================================================
# TRAINING SETUP
# ============================================================================

optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr_stage1'], weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10000, T_mult=2, eta_min=CONFIG['lr_min']
)

loss_history = []
best_loss = float('inf')
no_improve_count = 0
checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

# ============================================================================
# TRAINING LOOP (Your SVD orthogonalization + My improvements)
# ============================================================================

print("=" * 80)
print("HYBRID TRAINING START")
print("=" * 80)
print(f"Strategy: SVD orthogonalization + Fourier features + Two-stage training")
print(f"Stage 1: 0-{CONFIG['stage1_epochs']:,} | Stage 2: {CONFIG['stage1_epochs']:,}-{CONFIG['max_epochs']:,}")
print("=" * 80 + "\n")

for epoch in range(1, CONFIG['max_epochs'] + 1):
    
    # ========================================================================
    # STAGE TRANSITION
    # ========================================================================
    if epoch == CONFIG['stage1_epochs'] + 1:
        print("\n" + "=" * 80)
        print("STAGE 2: Fine-tuning all eigenvalues")
        print("=" * 80 + "\n")
        
        for param_group in optimizer.param_groups:
            param_group['lr'] = CONFIG['lr_stage2']
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=CONFIG['max_epochs'] - CONFIG['stage1_epochs'],
            eta_min=CONFIG['lr_min']
        )
    
    # ========================================================================
    # FORWARD PASS WITH SVD ORTHOGONALIZATION (Your approach)
    # ========================================================================
    model.train()
    
    U = model(X)  # (N, k)
    
    # SVD-based M-orthogonalization
    B = U.T @ (M @ U)  # Gram matrix
    V, S, _ = torch.linalg.svd(B)
    
    # Compute B^(-1/2)
    S_inv_sqrt = torch.diag_embed(1.0 / torch.sqrt(torch.clamp(S, min=1e-7)))
    B_inv_sqrt = V @ S_inv_sqrt @ V.T
    U_orth = U @ B_inv_sqrt  # M-orthonormal basis
    
    # Rayleigh quotient
    rayleigh_matrix = U_orth.T @ (K @ U_orth)
    
    # ========================================================================
    # LOSS COMPUTATION
    # ========================================================================
    optimizer.zero_grad()
    loss, loss_components = compute_loss(
        U, U_orth, K, M, rayleigh_matrix, B, S, eigvals, epoch, CONFIG
    )
    
    # ========================================================================
    # BACKWARD PASS
    # ========================================================================
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
    optimizer.step()
    scheduler.step()
    
    loss_history.append(loss.item())
    
    # ========================================================================
    # LOGGING
    # ========================================================================
    if epoch % CONFIG['print_every'] == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            # Get eigenvalues in original scale
            approx_eigs = torch.diag(rayleigh_matrix).cpu().numpy()
            approx_eigs.sort()
            approx_eigs_original = approx_eigs * (K_scale / M_scale).cpu().numpy()
            
            # Compute errors
            abs_error = np.abs(approx_eigs_original[:k] - eigvals[:k])
            rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)
            mean_rel_error = np.mean(rel_error)
            median_rel_error = np.median(rel_error)
            
            # Orthogonality check
            B_orth = U_orth.T @ (M @ U_orth)
            orth_residual = torch.norm(B_orth - torch.eye(k, device=device, dtype=torch.float64), p='fro').item()
            
            current_lr = optimizer.param_groups[0]['lr']
            stage = "Stage1" if epoch <= CONFIG['stage1_epochs'] else "Stage2"
        
        print(f"[{stage}] Epoch {epoch:>6} | LR={current_lr:.6f}")
        print(f"  Loss={loss.item():.6f} | λ₁²={loss_components['zero_eig']:.2e} | Trace={loss_components['trace']:.3f}")
        print(f"  Orth={orth_residual:.2e} | OffDiag={loss_components['offdiag']:.2e} | SVD_cond={loss_components['svd_cond']:.2e}")
        print(f"  MeanRelErr={mean_rel_error:.4%} | MedianRelErr={median_rel_error:.4%}")
        print(f"  λ∈[{approx_eigs_original[0]:.6f}, {approx_eigs_original[-1]:.6f}]")
        
        if epoch % (CONFIG['print_every'] * 5) == 0:
            print(f"  First 5 - Pred: {approx_eigs_original[:5]}")
            print(f"  First 5 - True: {eigvals[:5]}")
            print(f"  Last  5 - Pred: {approx_eigs_original[-5:]}")
            print(f"  Last  5 - True: {eigvals[k-5:k]}")
        print()
    
    # ========================================================================
    # CHECKPOINTING & EARLY STOPPING
    # ========================================================================
    if epoch % CONFIG['checkpoint_every'] == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'loss_history': loss_history,
        }
        torch.save(checkpoint, checkpoint_dir / f'checkpoint_epoch_{epoch}.pt')
        
        if loss.item() < best_loss * CONFIG['early_stopping_threshold']:
            best_loss = loss.item()
            no_improve_count = 0
            torch.save(checkpoint, checkpoint_dir / 'best_model.pt')
            print(f"  ✓ New best model (loss={best_loss:.6f})")
        else:
            no_improve_count += CONFIG['checkpoint_every']
        
        if no_improve_count >= CONFIG['early_stopping_patience']:
            print(f"\n⚠ Early stopping at epoch {epoch}")
            break

print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80 + "\n")

# ============================================================================
# FINAL EVALUATION
# ============================================================================

print("Loading best model...")
best_checkpoint = torch.load(checkpoint_dir / 'best_model.pt')
model.load_state_dict(best_checkpoint['model_state_dict'])

model.eval()
with torch.no_grad():
    U_final = model(X)
    
    # Final SVD orthogonalization
    B_final = U_final.T @ (M @ U_final)
    V_final, S_final, _ = torch.linalg.svd(B_final)
    S_inv_sqrt_final = torch.diag_embed(1.0 / torch.sqrt(torch.clamp(S_final, min=1e-7)))
    B_inv_sqrt_final = V_final @ S_inv_sqrt_final @ V_final.T
    U_orth_final = U_final @ B_inv_sqrt_final
    
    # Final matrices
    final_rayleigh = U_orth_final.T @ (K @ U_orth_final)
    final_ortho = U_orth_final.T @ (M @ U_orth_final)
    
    # Eigenvalues
    final_eigs = torch.diag(final_rayleigh).cpu().numpy()
    final_eigs.sort()
    final_eigs_original = final_eigs * (K_scale / M_scale).cpu().numpy()
    
    # Errors
    abs_error = np.abs(final_eigs_original - eigvals[:k])
    rel_error = abs_error / (np.abs(eigvals[:k]) + 1e-10)
    
    print("=" * 80)
    print("FINAL RESULTS")
    print("=" * 80)
    
    # Orthogonality
    identity_k = torch.eye(k, device=device, dtype=torch.float64)
    orth_residual = torch.norm(final_ortho - identity_k, p='fro').item()
    orth_diag = torch.diag(final_ortho).cpu().numpy()
    
    print(f"\n Orthogonality Quality:")
    print(f"  ||U^T M U - I||_F = {orth_residual:.2e}")
    print(f"  Diagonal range: [{orth_diag.min():.6f}, {orth_diag.max():.6f}]")
    
    # Rayleigh matrix
    rayleigh_diag = torch.diag(final_rayleigh).cpu().numpy()
    rayleigh_offdiag = (final_rayleigh - torch.diag(torch.diag(final_rayleigh))).cpu().numpy()
    
    print(f"\n Rayleigh Matrix:")
    print(f"  Diagonal norm: {np.linalg.norm(rayleigh_diag):.6f}")
    print(f"  Off-diagonal norm: {np.linalg.norm(rayleigh_offdiag, 'fro'):.2e}")
    
    # SVD condition
    print(f"\n SVD Stability:")
    print(f"  Condition number: {S_final.max().item() / S_final.min().item():.2e}")
    
    # Eigenvalue comparison
    print(f"\n Eigenvalue Comparison (First 10):")
    print(f"{'Mode':<6} {'Predicted':<14} {'Reference':<14} {'Abs Error':<14} {'Rel Error':<12}")
    print("-" * 70)
    for i in range(min(10, k)):
        print(f"{i+1:<6} {final_eigs_original[i]:<14.8f} {eigvals[i]:<14.8f} "
              f"{abs_error[i]:<14.8f} {rel_error[i]:<12.4%}")
    
    print(f"\n Eigenvalue Comparison (Last 10):")
    print(f"{'Mode':<6} {'Predicted':<14} {'Reference':<14} {'Abs Error':<14} {'Rel Error':<12}")
    print("-" * 70)
    for i in range(max(0, k-10), k):
        print(f"{i+1:<6} {final_eigs_original[i]:<14.8f} {eigvals[i]:<14.8f} "
              f"{abs_error[i]:<14.8f} {rel_error[i]:<12.4%}")
    
    # Statistics
    print(f"\n Overall Statistics ({k} modes):")
    print(f"  Mean Absolute Error:     {np.mean(abs_error):.8f}")
    print(f"  Mean Relative Error:     {np.mean(rel_error):.4%}")
    print(f"  Median Relative Error:   {np.median(rel_error):.4%}")
    print(f"  Max Relative Error:      {np.max(rel_error):.4%}")
    print(f"  Modes with <1% error:    {np.sum(rel_error < 0.01)}/{k}")
    print(f"  Modes with <5% error:    {np.sum(rel_error < 0.05)}/{k}")
    print(f"  Modes with <10% error:   {np.sum(rel_error < 0.10)}/{k}")
    
    print("\n" + "=" * 80)

# Save final model
torch.save({
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'final_eigenvalues': final_eigs_original,
    'reference_eigenvalues': eigvals[:k],
    'errors': {'absolute': abs_error, 'relative': rel_error},
    'normalization': {'K_scale': K_scale.item(), 'M_scale': M_scale.item()},
}, 'hybrid_model_final.pt')

print("\n✓ Model saved to 'hybrid_model_final.pt'")
print("✓ Checkpoints in 'checkpoints/' directory")

Loading mesh...
Computing Laplacian...
Computing reference eigenvalues...

=== Matrix Diagnostics ===
N = 1546, k = 50, ratio = 30.9
Condition number of K: 2.56e+16
Condition number of M: 3.70e+02
Target eigenvalue range: [0.000000, 7.834566]
⚠ Large eigenvalue spread detected

Regularization: ε=0.0001
Condition number after reg: 1.47e+05

Normalization:
  K_scale = 1.58e+02
  M_scale = 1.22e+00
  ||K||_F = 1.0000
  ||M||_F = 1.0000

Sanity checks:
  K symmetric: True
  M positive definite: True
  No NaN/Inf: True

=== Initializing Model ===
Parameters: 128,821
Architecture: [256, 256, 128, 128] -> 50
Fourier features: Enabled (3 frequencies)

HYBRID TRAINING START
Strategy: SVD orthogonalization + Fourier features + Two-stage training
Stage 1: 0-50,000 | Stage 2: 50,000-150,000

[Stage1] Epoch      1 | LR=0.010000
  Loss=54.752092 | λ₁²=5.06e-01 | Trace=0.820
  Orth=1.71e-13 | OffDiag=2.02e-03 | SVD_cond=1.80e+03
  MeanRelErr=1840041252546.8276% | MedianRelErr=2295.1240%
  λ∈[92.32435