In [None]:
# === Setup ===
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import sys
import time

sys.path.insert(0, str(Path.cwd().parent.parent.parent))

from modules._import_helper import safe_import_from

set_seed = safe_import_from('00_repo_standards.src.mlphys_core', 'set_seed')
(HarmonicOscillatorConfig, HarmonicOscillatorPINN,
 analytical_harmonic_oscillator) = safe_import_from(
    '07_physics_informed_ml.src.ode_pinn',
    'HarmonicOscillatorConfig', 'HarmonicOscillatorPINN',
    'analytical_harmonic_oscillator'
)
(HeatEquationConfig, HeatEquationPINN,
 solve_heat_equation_finite_difference) = safe_import_from(
    '07_physics_informed_ml.src.pde_pinn',
    'HeatEquationConfig', 'HeatEquationPINN',
    'solve_heat_equation_finite_difference'
)

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

# All figures and the markdown report of this notebook are written here.
reports_dir = Path.cwd().parent / 'reports'
reports_dir.mkdir(parents=True, exist_ok=True)

SEED = 42
set_seed(SEED)
# Seed torch directly as well (the PINNs below initialise torch modules);
# harmless if set_seed already does this.
torch.manual_seed(SEED)

print("✓ Setup complete")

---
## 1. Failure Mode #1: High-Frequency Solutions

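**Problem**: Neural networks have *spectral bias*: gradient-based training fits low-frequency components first, so highly oscillatory solutions converge slowly or not at all.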

**Test**: Harmonic oscillator $\ddot{x} + \omega^2 x = 0$ on $t \in [0, 10]$ with increasing $\omega$ (more oscillations in the same window).

In [None]:
# === High frequency failure ===
# Sweep omega on a fixed time window and compare each PINN with the analytical
# solution. The pass/fail verdict is derived from this run, not hard-coded.
omega_values = [1.0, 3.0, 5.0, 10.0]
T_MAX_FREQ = 10.0
RMSE_TOL_FREQ = 0.1  # pass/fail threshold (also used in the summary report)
t_test = np.linspace(0, T_MAX_FREQ, 200)
results_freq = []

print("FAILURE MODE: High Frequencies")
print("="*60)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for ax, omega in zip(axes, omega_values):
    set_seed(SEED)  # same initialisation for every omega

    config = HarmonicOscillatorConfig(
        omega=omega, x0=1.0, v0=0.0, t_max=T_MAX_FREQ,
        n_collocation=200,
        hidden_dims=[32, 32, 32],
        epochs=3000,
    )

    pinn = HarmonicOscillatorPINN(config)
    _ = pinn.train(verbose=0)

    x_pred, _ = pinn.predict_with_velocity(t_test)
    x_true, _ = analytical_harmonic_oscillator(omega, 1.0, 0.0, t_test)

    rmse = np.sqrt(np.mean((x_pred - x_true)**2))
    n_osc = omega * T_MAX_FREQ / (2 * np.pi)  # full periods inside [0, t_max]

    results_freq.append({'omega': omega, 'rmse': rmse, 'n_osc': n_osc})

    ax.plot(t_test, x_true, 'k-', lw=2, label='True')
    ax.plot(t_test, x_pred, 'r--', lw=2, label='PINN')
    ax.set_title(f'ω={omega} ({n_osc:.1f} oscillations)\nRMSE = {rmse:.4f}')
    ax.set_xlabel('t')
    ax.set_ylabel('x(t)')
    ax.legend()

    status = "✓" if rmse < RMSE_TOL_FREQ else "✗"
    print(f"ω={omega:5.1f} | {n_osc:5.1f} osc | RMSE={rmse:.4f} {status}")

plt.tight_layout()
plt.savefig(reports_dir / '04_failure_high_frequency.png', dpi=150, bbox_inches='tight')
plt.show()

# Report which frequencies actually failed in this run.
failed_omegas = [r['omega'] for r in results_freq if r['rmse'] >= RMSE_TOL_FREQ]
if failed_omegas:
    print(f"\n⚠️ DIAGNOSIS: Standard PINN fails (RMSE ≥ {RMSE_TOL_FREQ}) for ω ∈ {failed_omegas}")
    print("   - Spectral bias: NNs are biased toward smooth solutions")
    print("   - High-frequency modes require more capacity or specialized architectures")
else:
    print(f"\nDIAGNOSIS: all tested ω reached RMSE < {RMSE_TOL_FREQ} in this run")

### Fix Attempt #1: More Collocation Points

In [None]:
# === Fix: More collocation points ===
omega_hard = 10.0
n_col_values = [100, 200, 500, 1000]
# Evaluation grid re-declared here so the cell does not depend on whichever
# cell last assigned `t_test`.
t_test = np.linspace(0, 10, 200)
# The reference solution does not depend on n_col: compute it once.
x_true, _ = analytical_harmonic_oscillator(omega_hard, 1.0, 0.0, t_test)
results_ncol = []

print(f"\nFIX ATTEMPT: More collocation points (ω={omega_hard})")
print("-"*50)

for n_col in n_col_values:
    set_seed(SEED)

    config = HarmonicOscillatorConfig(
        omega=omega_hard, x0=1.0, v0=0.0, t_max=10.0,
        n_collocation=n_col,
        hidden_dims=[32, 32, 32],
        epochs=3000,
    )

    pinn = HarmonicOscillatorPINN(config)
    _ = pinn.train(verbose=0)

    x_pred, _ = pinn.predict_with_velocity(t_test)
    rmse = np.sqrt(np.mean((x_pred - x_true)**2))

    results_ncol.append({'n_col': n_col, 'rmse': rmse})
    print(f"N_col={n_col:4d} | RMSE={rmse:.4f}")

print("\n💡 Verdict: Helps somewhat, but doesn't fully solve the problem")

### Fix Attempt #2: Larger Network

In [None]:
# === Fix: Larger network ===
arch_values = [[32, 32], [64, 64, 64], [128, 128, 128], [64, 64, 64, 64, 64]]
# Same 0..10 evaluation grid as the other omega_hard experiments, re-declared
# so the result does not depend on which cell last assigned `t_test`.
t_test = np.linspace(0, 10, 200)
# The reference solution is independent of the architecture: compute it once.
x_true, _ = analytical_harmonic_oscillator(omega_hard, 1.0, 0.0, t_test)
results_arch = []

print(f"\nFIX ATTEMPT: Larger network (ω={omega_hard})")
print("-"*50)

for hidden_dims in arch_values:
    set_seed(SEED)

    config = HarmonicOscillatorConfig(
        omega=omega_hard, x0=1.0, v0=0.0, t_max=10.0,
        n_collocation=500,
        hidden_dims=hidden_dims,
        epochs=3000,
    )

    pinn = HarmonicOscillatorPINN(config)
    _ = pinn.train(verbose=0)

    x_pred, _ = pinn.predict_with_velocity(t_test)
    rmse = np.sqrt(np.mean((x_pred - x_true)**2))

    n_params = sum(p.numel() for p in pinn.model.parameters())
    results_arch.append({'arch': str(hidden_dims), 'rmse': rmse, 'n_params': n_params})
    print(f"{str(hidden_dims):25s} | {n_params:6d} params | RMSE={rmse:.4f}")

print("\n💡 Verdict: Larger networks help but still struggle with ω=10")
print("   Better solution: Fourier features or periodic activations")

---
## 2. Failure Mode #2: Stiff ODEs

**Problem**: Systems with widely separated time scales.

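**Example**: Damped oscillator $\ddot{x} + 2\gamma\dot{x} + \omega^2 x = 0$. The sweep below raises the damping ratio $\gamma/\omega$ from 0.1 to 0.9. It therefore stays in the underdamped regime ($\gamma < \omega$) and only approaches critical damping; the strongly stiff regime $\gamma \gg \omega$ is not tested here.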

In [None]:
# === Custom stiff ODE PINN ===
# We'll create a simple damped oscillator manually

class StiffODEPINN:
    """Minimal PINN for the damped oscillator x'' + 2*gamma*x' + omega^2*x = 0.

    A fully connected tanh network maps t -> x(t). Training minimises the mean
    squared ODE residual on a fixed uniform collocation grid over [0, t_max],
    plus ``lambda_ic`` times the squared errors of x(0) and x'(0).

    Args:
        omega: Natural frequency.
        gamma: Damping coefficient.
        x0: Initial position x(0).
        v0: Initial velocity x'(0).
        t_max: Right end of the collocation interval.
        n_col: Number of collocation points.
        epochs: Number of Adam steps performed by ``train``.
        lr: Adam learning rate.
        lambda_ic: Weight of the initial-condition loss (default 10.0).
        hidden_width: Units in each of the three hidden layers (default 64).
    """

    def __init__(self, omega, gamma, x0, v0, t_max, n_col=200, epochs=5000, lr=1e-3,
                 lambda_ic=10.0, hidden_width=64):
        self.omega = omega
        self.gamma = gamma
        self.x0 = x0
        self.v0 = v0
        self.t_max = t_max
        self.lambda_ic = lambda_ic

        # Network: 1 -> hidden -> hidden -> hidden -> 1, tanh activations
        self.model = nn.Sequential(
            nn.Linear(1, hidden_width), nn.Tanh(),
            nn.Linear(hidden_width, hidden_width), nn.Tanh(),
            nn.Linear(hidden_width, hidden_width), nn.Tanh(),
            nn.Linear(hidden_width, 1)
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # Collocation grid as a column vector; requires_grad so that autograd
        # can differentiate the network output with respect to t.
        self.t_col = torch.linspace(0, t_max, n_col).view(-1, 1)
        self.t_col.requires_grad = True

        self.epochs = epochs

    def train(self, verbose=500):
        """Run ``self.epochs`` Adam steps and return ``{'loss': [total loss per epoch]}``.

        ``verbose`` > 0 prints the loss every ``verbose`` epochs; 0 disables printing.
        """
        # predict() switches the module to eval mode; restore training mode so
        # calling train() again after predict() behaves like the first call.
        self.model.train()
        history = {'loss': []}

        for epoch in range(self.epochs):
            self.optimizer.zero_grad()

            x = self.model(self.t_col)

            # First and second time derivatives; create_graph=True keeps them
            # differentiable so the residual loss can be backpropagated.
            x_t = torch.autograd.grad(x, self.t_col, torch.ones_like(x), create_graph=True)[0]
            x_tt = torch.autograd.grad(x_t, self.t_col, torch.ones_like(x_t), create_graph=True)[0]

            # ODE residual: x'' + 2*gamma*x' + omega^2*x = 0
            residual = x_tt + 2 * self.gamma * x_t + self.omega**2 * x
            loss_physics = torch.mean(residual**2)

            # Initial conditions at t = 0 (fresh leaf each epoch, so no stale .grad)
            t0 = torch.tensor([[0.0]], requires_grad=True)
            x0_pred = self.model(t0)
            v0_pred = torch.autograd.grad(x0_pred, t0, torch.ones_like(x0_pred), create_graph=True)[0]
            # squeeze the (1, 1) result so the total loss is a true 0-d scalar
            loss_ic = ((x0_pred - self.x0)**2 + (v0_pred - self.v0)**2).squeeze()

            loss = loss_physics + self.lambda_ic * loss_ic
            loss.backward()
            self.optimizer.step()

            history['loss'].append(loss.item())

            if verbose > 0 and (epoch + 1) % verbose == 0:
                print(f"Epoch {epoch+1}: loss={loss.item():.6f}")

        return history

    def predict(self, t):
        """Evaluate the network at times ``t``; returns a flat float32 numpy array."""
        self.model.eval()
        with torch.no_grad():
            t_tensor = torch.tensor(t, dtype=torch.float32).view(-1, 1)
            return self.model(t_tensor).numpy().flatten()

def damped_analytical(omega, gamma, x0, v0, t):
    """Closed-form solution of x'' + 2*gamma*x' + omega^2*x = 0 with x(0)=x0, x'(0)=v0.

    Handles all three damping regimes:
      * underdamped (gamma < omega): decaying oscillation at omega_d = sqrt(omega^2 - gamma^2)
      * critically damped (gamma == omega): (x0 + (v0 + gamma*x0) t) e^{-gamma t}
      * overdamped (gamma > omega): sum of two decaying exponentials

    Args:
        omega: Natural frequency.
        gamma: Damping coefficient.
        x0: Initial position.
        v0: Initial velocity.
        t: Time(s), scalar or array-like.

    Returns:
        x(t) with the same shape as ``t``.
    """
    t = np.asarray(t, dtype=float)

    if gamma < omega:
        omega_d = np.sqrt(omega**2 - gamma**2)  # Damped frequency
        A = x0
        B = (v0 + gamma * x0) / omega_d
        return np.exp(-gamma * t) * (A * np.cos(omega_d * t) + B * np.sin(omega_d * t))

    if gamma == omega:
        # omega_d = 0: the underdamped formula would divide by zero.
        return np.exp(-gamma * t) * (x0 + (v0 + gamma * x0) * t)

    # Overdamped: real roots r = -gamma +/- sqrt(gamma^2 - omega^2).
    # Coefficients solve c_fast + c_slow = x0 and r_fast*c_fast + r_slow*c_slow = v0.
    root = np.sqrt(gamma**2 - omega**2)
    r_fast = -gamma - root
    r_slow = -gamma + root
    c_fast = (r_slow * x0 - v0) / (r_slow - r_fast)
    c_slow = x0 - c_fast
    return c_fast * np.exp(r_fast * t) + c_slow * np.exp(r_slow * t)

In [None]:
# === Test stiffness ===
omega_base = 1.0
# Damping ratios gamma/omega = 0.1, 0.5, 0.9: all underdamped (gamma < omega),
# increasingly close to critical damping.
gamma_values = [0.1, 0.5, 0.9]

t_test = np.linspace(0, 10, 200)
results_stiff = []

print("FAILURE MODE: Stiff ODEs (Strong Damping)")
print("="*60)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, gamma in zip(axes, gamma_values):
    set_seed(SEED)

    pinn = StiffODEPINN(
        omega=omega_base, gamma=gamma,
        x0=1.0, v0=0.0, t_max=10.0,
        n_col=300, epochs=3000, lr=1e-3
    )
    _ = pinn.train(verbose=0)

    x_pred = pinn.predict(t_test)
    x_true = damped_analytical(omega_base, gamma, 1.0, 0.0, t_test)

    rmse = np.sqrt(np.mean((x_pred - x_true)**2))
    stiffness = gamma / omega_base  # damping ratio gamma/omega

    results_stiff.append({'gamma': gamma, 'stiffness': stiffness, 'rmse': rmse})

    ax.plot(t_test, x_true, 'k-', lw=2, label='True')
    ax.plot(t_test, x_pred, 'r--', lw=2, label='PINN')
    ax.set_title(f'γ={gamma} (γ/ω={stiffness:.1f})\nRMSE={rmse:.4f}')
    ax.set_xlabel('t')
    ax.set_ylabel('x(t)')
    ax.legend()

    status = "✓" if rmse < 0.05 else "✗"
    print(f"γ={gamma:.1f} | Stiffness={stiffness:.1f} | RMSE={rmse:.4f} {status}")

plt.tight_layout()
plt.savefig(reports_dir / '04_failure_stiffness.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n💡 Higher damping → faster decay → multi-scale dynamics")
print("   Stiff systems need adaptive time-stepping or curriculum learning")

---
## 3. Failure Mode #3: Long Time Horizons

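**Problem**: Accuracy degrades as the time horizon grows. A PINN fits the whole interval at once, so when the domain contains many oscillations it tends to match early times well, while the solution drifts in phase and amplitude at later times.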

In [None]:
# === Long time horizon test ===
t_max_values = [5.0, 10.0, 20.0, 50.0]
omega_long = 2.0          # fixed frequency: only the horizon grows
col_per_unit_time = 10    # collocation density kept constant as t_max grows
results_long = []

print("FAILURE MODE: Long Time Horizons")
print("="*60)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for ax, t_max in zip(axes, t_max_values):
    set_seed(SEED)

    config = HarmonicOscillatorConfig(
        omega=omega_long, x0=1.0, v0=0.0, t_max=t_max,
        n_collocation=int(col_per_unit_time * t_max),  # scale with time
        hidden_dims=[64, 64, 64],
        epochs=3000,
    )

    pinn = HarmonicOscillatorPINN(config)
    _ = pinn.train(verbose=0)

    # Per-horizon evaluation grid. Deliberately not named `t_test`, so the
    # shared 0..10 grid used by the other cells is not overwritten.
    t_eval = np.linspace(0, t_max, 200)
    x_pred, _ = pinn.predict_with_velocity(t_eval)
    x_true, _ = analytical_harmonic_oscillator(omega_long, 1.0, 0.0, t_eval)

    rmse = np.sqrt(np.mean((x_pred - x_true)**2))
    n_periods = omega_long * t_max / (2 * np.pi)

    results_long.append({'t_max': t_max, 'n_periods': n_periods, 'rmse': rmse})

    ax.plot(t_eval, x_true, 'k-', lw=2, label='True')
    ax.plot(t_eval, x_pred, 'r--', lw=2, label='PINN')
    ax.set_title(f't_max={t_max} ({n_periods:.1f} periods)\nRMSE={rmse:.4f}')
    ax.set_xlabel('t')
    ax.set_ylabel('x(t)')
    ax.legend()

    status = "✓" if rmse < 0.1 else "✗"
    print(f"t_max={t_max:5.1f} | {n_periods:5.1f} periods | RMSE={rmse:.4f} {status}")

plt.tight_layout()
plt.savefig(reports_dir / '04_failure_long_time.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n💡 Longer time horizons → more oscillations → higher error")
print("   Solution: time-windowing, causal training, or sequence-to-sequence approach")

---
## 4. Debugging Checklist

A systematic approach to diagnosing PINN problems.

In [None]:
# === Debugging Checklist Implementation ===

def diagnose_pinn(pinn, config, t_test, x_true):
    """Print a diagnostic report for a trained oscillator PINN and return its metrics.

    Args:
        pinn: Trained model exposing ``predict_with_velocity(t) -> (x, v)`` and
            ``model.parameters()``.
        config: The model's config; ``x0``, ``v0``, ``n_collocation`` and
            ``lambda_ic`` are read.
        t_test: Evaluation times. Assumed sorted ascending, so that the first and
            last quarter of the samples are the early and late times.
        x_true: Reference solution evaluated at ``t_test``.

    Returns:
        dict with keys 'ic_x_err', 'ic_v_err', 'rmse', 'max_err',
        'early_error', 'late_error'.
    """
    print("="*60)
    print("PINN DIAGNOSTIC REPORT")
    print("="*60)

    # 1. Initial-condition satisfaction: evaluate the network at t = 0 only.
    x0_pred, v0_pred = pinn.predict_with_velocity(np.array([0.0]))
    ic_x_err = abs(x0_pred[0] - config.x0)
    ic_v_err = abs(v0_pred[0] - config.v0)
    print("\n1. IC SATISFACTION:")
    print(f"   x(0) error: {ic_x_err:.6f} {'✓' if ic_x_err < 0.01 else '✗'}")
    print(f"   v(0) error: {ic_v_err:.6f} {'✓' if ic_v_err < 0.01 else '✗'}")

    # 2. Global error over the evaluation grid.
    x_pred_full, _ = pinn.predict_with_velocity(t_test)
    error_vs_t = np.abs(x_pred_full - x_true)
    rmse = np.sqrt(np.mean(error_vs_t**2))
    max_err = np.max(error_vs_t)
    print("\n2. PREDICTION ERROR:")
    print(f"   RMSE: {rmse:.6f}")
    print(f"   Max Error: {max_err:.6f}")

    # 3. Early vs late error: mean over the first and last quarter of samples.
    # Both windows have the same length, n // 4 and at least 1, so the means
    # stay defined for short grids. Writing `-n // 4` instead would parse as
    # `(-n) // 4` and can select one extra sample at the late end.
    n_window = max(1, len(t_test) // 4)
    early_error = np.mean(error_vs_t[:n_window])
    late_error = np.mean(error_vs_t[-n_window:])
    print("\n3. ERROR DISTRIBUTION:")
    print(f"   Early time (t<25%): {early_error:.6f}")
    print(f"   Late time (t>75%): {late_error:.6f}")
    if late_error > 2 * early_error:
        print("   ⚠️ Error grows over time - consider curriculum learning")

    # 4. Model / training configuration summary.
    n_params = sum(p.numel() for p in pinn.model.parameters())
    print("\n4. MODEL INFO:")
    print(f"   Parameters: {n_params}")
    print(f"   Collocation points: {config.n_collocation}")
    print(f"   λ_IC: {config.lambda_ic}")

    print("\n" + "="*60)

    return {
        'ic_x_err': ic_x_err,
        'ic_v_err': ic_v_err,
        'rmse': rmse,
        'max_err': max_err,
        'early_error': early_error,
        'late_error': late_error,
    }

# Test diagnostics on a problematic case (mid-frequency oscillator from Section 1).
# omega and t_max are defined once so the PINN and the reference solution agree.
OMEGA_DIAG = 5.0
T_MAX_DIAG = 10.0

set_seed(SEED)
config_diag = HarmonicOscillatorConfig(
    omega=OMEGA_DIAG, x0=1.0, v0=0.0, t_max=T_MAX_DIAG,
    n_collocation=200, hidden_dims=[32, 32, 32], epochs=3000
)
pinn_diag = HarmonicOscillatorPINN(config_diag)
_ = pinn_diag.train(verbose=0)

t_diag = np.linspace(0, T_MAX_DIAG, 200)
x_true_diag, _ = analytical_harmonic_oscillator(OMEGA_DIAG, config_diag.x0, config_diag.v0, t_diag)

diag_results = diagnose_pinn(pinn_diag, config_diag, t_diag, x_true_diag)

---
## 5. What Works / What Doesn't Summary

In [None]:
# === Generate summary report ===
# Table rows are built first; pass/fail thresholds match the experiment cells.
freq_rows = "\n".join(
    f"| {r['omega']:.1f} | {r['n_osc']:.1f} | {r['rmse']:.4f} | {'✓' if r['rmse'] < 0.1 else '✗'} |"
    for r in results_freq
)
stiff_rows = "\n".join(
    f"| {r['gamma']:.1f} | {r['stiffness']:.1f} | {r['rmse']:.4f} | {'✓' if r['rmse'] < 0.05 else '✗'} |"
    for r in results_stiff
)
long_rows = "\n".join(
    f"| {r['t_max']:.1f} | {r['n_periods']:.1f} | {r['rmse']:.4f} | {'✓' if r['rmse'] < 0.1 else '✗'} |"
    for r in results_long
)

report = f"""
# PINN Failure Modes & Debugging - Report

**Date**: {time.strftime('%Y-%m-%d %H:%M')}
**Seed**: {SEED}

## Executive Summary

This notebook provides an honest evaluation of where PINNs struggle and what can be done.

## Failure Mode #1: High Frequencies

| ω | # Oscillations | RMSE | Status |
|---|----------------|------|--------|
{freq_rows}

**Root cause**: Neural networks have spectral bias (prefer low frequencies).

**Partial fixes**:
- More collocation points (diminishing returns)
- Larger networks (helps somewhat)
- Fourier features or periodic activations (recommended)

## Failure Mode #2: Stiff ODEs

| γ | Stiffness (γ/ω) | RMSE | Status |
|---|-----------------|------|--------|
{stiff_rows}

**Root cause**: Multiple time scales require different learning rates.

**Partial fixes**:
- Input normalization
- Adaptive loss weighting
- Curriculum learning (start with short times)

## Failure Mode #3: Long Time Horizons

| t_max | # Periods | RMSE | Status |
|-------|-----------|------|--------|
{long_rows}

**Root cause**: Error accumulates; phase drift over many oscillations.

**Partial fixes**:
- Time windowing (train on shorter segments)
- Causal training (weight early times more)
- Curriculum learning

## Debugging Checklist

1. **Check IC/BC satisfaction first** - If violated, increase λ_IC/λ_BC
2. **Monitor loss components** - Which one dominates?
3. **Plot error vs time** - Growing error → temporal drift
4. **Visualize residual field** - Where is PDE residual largest?
5. **Try simpler problem first** - If that fails, fix fundamentals
6. **Input normalization** - Scale inputs to [0, 1] or [-1, 1]
7. **Learning rate schedule** - Reduce LR over training
8. **Gradient clipping** - If NaNs appear

## When to Use PINNs vs Classical Methods

| Scenario | Recommendation |
|----------|----------------|
| Simple ODE/PDE, well-posed | **Classical methods** (scipy, FD) |
| Inverse problem | **PINNs** |
| Irregular geometry | **PINNs** (mesh-free advantage) |
| Multi-physics coupling | **PINNs** (easier to combine) |
| High accuracy required | **Classical methods** |
| Real-time inference needed | **Classical methods** (faster) |
| Sparse/noisy data + physics | **PINNs** (physics regularization) |

## Key Takeaway

**PINNs are NOT universally better than classical methods.**

They excel when:
- Physics is known but data is scarce
- Inverse problems (parameter discovery)
- Geometry is complex (mesh-free)

They struggle when:
- High frequencies or multi-scale dynamics
- High accuracy is required
- Speed is critical

**Choose the right tool for the job!**
"""

# Explicit UTF-8: the report contains non-ASCII symbols (ω, γ, λ, ✓, →) that
# would fail to encode under a non-UTF-8 platform default encoding.
report_path = reports_dir / '04_failure_modes_report.md'
report_path.write_text(report, encoding='utf-8')

print(report)
print(f"\n✓ Report saved to {report_path}")

---
## 6. Mini Exercises

**Exercise 1**: Implement input normalization for the high-frequency case. Does it help?

**Exercise 2**: Try reducing the learning rate for the stiff ODE. What happens?

**Exercise 3**: Implement a simple curriculum: train on $t \in [0, 2]$, then $[0, 5]$, then $[0, 10]$.

**Exercise 4**: Add gradient clipping to the training loop. When does it help?

In [None]:
# === Exercise 1: Input normalization ===
# YOUR CODE HERE


In [None]:
# === Exercise 2: Learning rate for stiff ODE ===
# YOUR CODE HERE


In [None]:
# === Exercise 3: Curriculum learning ===
# YOUR CODE HERE


In [None]:
# === Exercise 4: Gradient clipping ===
# YOUR CODE HERE


---
## Key Takeaways

### ✅ What We Learned

1. **PINNs have real limitations**: High frequencies, stiffness, long times
2. **Spectral bias is fundamental**: NNs prefer smooth solutions
3. **Diagnostics are essential**: Systematic debugging saves time
4. **Know when to use alternatives**: Classical methods often win on simple problems

### ⚠️ Honest Assessment

PINNs are a **tool**, not a silver bullet:
- Great for inverse problems and irregular geometries
- Poor for high-accuracy forward simulation
- Always compare against classical baselines

### 💡 Practical Recommendations

1. **Start simple**: Test on easy problems first
2. **Use diagnostics**: Check IC/BC, monitor all losses
3. **Scale inputs**: Normalize to [0,1] or [-1,1]
4. **Tune λ weights**: Problem-specific, start with λ_IC > λ_physics
5. **Consider curriculum**: For hard problems, start easy