# PINN Fundamentals: 1D Heat Equation

## Problem Setup

We solve:
$$\frac{\partial u}{\partial t} = \alpha \frac{\partial^2 u}{\partial x^2}, \quad x \in [0,1],\; t \in [0,1]$$

with IC $u(x,0) = \sin(\pi x)$ and BCs $u(0,t) = u(1,t) = 0$.
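The problem has a closed-form solution: a single decaying Fourier mode, which is presumably what the notebook's `analytical_solution` helper evaluates. Differentiating it checks the setup and also shows how sensitive $u$ is to $\alpha$:

$$
\begin{aligned}
u(x,t) &= e^{-\alpha\pi^2 t}\sin(\pi x), \\
\frac{\partial u}{\partial t} &= -\alpha\pi^2\,u, \qquad \frac{\partial^2 u}{\partial x^2} = -\pi^2\,u \quad\Longrightarrow\quad \frac{\partial u}{\partial t} = \alpha\frac{\partial^2 u}{\partial x^2}, \\
\frac{\partial u}{\partial \alpha} &= -\pi^2 t\,u .
\end{aligned}
$$

It satisfies the IC directly and both BCs because $\sin(0) = \sin(\pi) = 0$. With $\alpha = 0.01$ the amplitude at $t = 1$ is $e^{-0.01\pi^2} \approx 0.906$, so the solution decays by only about 9% over the time window. To first order, a 10% change in $\alpha$ ($\delta\alpha = 0.001$) shifts $u$ by at most $\pi^2 \cdot 0.906 \cdot 0.001 \approx 0.009$, under 1% of the initial amplitude.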

- **Forward problem**: given $\alpha = 0.01$, find $u(x,t)$
- **Inverse problem**: given noisy measurements of $u$, recover $\alpha$
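Both problems are solved by minimising a weighted sum of mean-squared loss terms. A standard PINN form is sketched here. The exact terms and weights used inside `StrategicPINNTrainer` are an assumption and may differ:

$$
\mathcal{L}(\theta) = \lambda_r\,\mathcal{L}_r + \lambda_{bc}\,\mathcal{L}_{bc} + \lambda_{ic}\,\mathcal{L}_{ic} \;\big[+\; \lambda_d\,\mathcal{L}_d\big]
$$

$$
\mathcal{L}_r = \frac{1}{N_f}\sum_{j=1}^{N_f}\Big(\partial_t u_\theta - \alpha\,\partial_{xx} u_\theta\Big)^2\Big|_{(x_j,t_j)}, \qquad
\mathcal{L}_{bc} = \frac{1}{N_{bc}}\sum_{j=1}^{N_{bc}} u_\theta\big(x^{bc}_j, t^{bc}_j\big)^2,
$$

$$
\mathcal{L}_{ic} = \frac{1}{N_{ic}}\sum_{j=1}^{N_{ic}}\Big(u_\theta(x_j, 0) - \sin(\pi x_j)\Big)^2, \qquad
\mathcal{L}_d = \frac{1}{N_d}\sum_{j=1}^{N_d}\Big(u_\theta\big(x^d_j, t^d_j\big) - \tilde u_j\Big)^2.
$$

Here $N_f$, $N_{bc}$ and $N_{ic}$ correspond to the arguments of `HeatEquationData`, and $\tilde u_j$ are the noisy sensor readings. The bracketed data term is used only in the inverse problem, where $\alpha$ is trained jointly with the network weights $\theta$.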

This notebook covers:
1. Data generation
2. Forward problem — optimiser comparison (Adam, L-BFGS, adaptive weights)
3. Inverse problem — same optimiser comparison
4. Sensitivity analysis — noise, initial guess, data volume, uncertainty

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('..')

from data.heat_data import HeatEquationData
from models.heat_pinn_strategy import StrategicPINN
from training.trainer_strategy import StrategicPINNTrainer
from utils.plotter import plot_solution

torch.manual_seed(42)  # reproducible weight initialisation
print('Imports successful.')

## 1. Data Generation

In [None]:
data_gen = HeatEquationData(
    L=1.0, T=1.0, alpha=0.01,
    N_f=10000, N_bc=100, N_ic=200,
    N_sensors=10, N_time_measurements=10,
    noise_level=0.01,
    device='cpu',
    random_seed=42
)

data = data_gen.generate_full_dataset()
data_gen.visualize_data(data)

## 2. Forward Problem

We fix $\alpha = 0.01$ and train the network to approximate $u(x,t)$.
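The physics enters through the PDE residual $r = u_t - \alpha u_{xx}$ at the collocation points, with derivatives taken by automatic differentiation. The following is a minimal PyTorch sketch. It assumes the network is exposed as a callable `u_fn(x, t)`; the helper names `heat_residual` and `exact_u` are hypothetical and are not part of the repository. As a sanity check, the residual of the exact solution should vanish, and it should not vanish if the wrong $\alpha$ is used:

```python
import math
import torch

def heat_residual(u_fn, x, t, alpha):
    """PDE residual r = u_t - alpha * u_xx at points (x, t), via autograd.

    `u_fn` is any differentiable callable mapping (x, t) -> u, e.g. a PINN.
    """
    x = x.clone().requires_grad_(True)
    t = t.clone().requires_grad_(True)
    u = u_fn(x, t)
    ones = torch.ones_like(u)
    # create_graph=True keeps first derivatives differentiable, which is needed
    # for u_xx and, during training, for backpropagating the residual loss.
    u_t = torch.autograd.grad(u, t, grad_outputs=ones, create_graph=True)[0]
    u_x = torch.autograd.grad(u, x, grad_outputs=ones, create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
    return u_t - alpha * u_xx


def exact_u(x, t, alpha=0.01):
    """Exact solution exp(-alpha * pi^2 * t) * sin(pi * x)."""
    return torch.exp(-alpha * math.pi**2 * t) * torch.sin(math.pi * x)


# Usage: residual of the exact solution at random collocation points (float64 for precision)
torch.manual_seed(0)
x_c = torch.rand(256, 1, dtype=torch.float64)
t_c = torch.rand(256, 1, dtype=torch.float64)
r = heat_residual(exact_u, x_c, t_c, alpha=0.01)
print(f"max |residual|: {r.abs().max().item():.2e}")
```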

### 2.1 Adam optimizer only

Baseline: Adam with a `ReduceLROnPlateau` schedule and no second-order refinement.
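For reference, the snippet below shows how Adam is paired with PyTorch's `ReduceLROnPlateau` scheduler. The helper name `make_adam_with_plateau` and the `factor`/`patience` values are placeholder assumptions, not the settings inside `StrategicPINNTrainer`. The toy loop feeds a loss that never improves, so the learning-rate cut is visible:

```python
import torch

def make_adam_with_plateau(params, lr=1e-3, factor=0.5, patience=200):
    """Adam plus a scheduler that multiplies lr by `factor` once the monitored
    loss has failed to improve for more than `patience` consecutive steps."""
    optimizer = torch.optim.Adam(params, lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=factor, patience=patience
    )
    return optimizer, scheduler


# Toy usage: the parameter already sits at the minimum, so the loss stalls.
w = torch.nn.Parameter(torch.zeros(1))
optimizer, scheduler = make_adam_with_plateau([w], lr=1e-3, factor=0.5, patience=3)
for epoch in range(10):
    optimizer.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())  # the scheduler monitors the epoch's loss
print(f"learning rate after 10 stalled epochs: {optimizer.param_groups[0]['lr']:.2e}")
```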

In [None]:
model_fwd_adam = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    alpha_true=0.01,
    inverse=False
)

trainer_fwd_adam = StrategicPINNTrainer(
    model=model_fwd_adam,
    data=data,
    learning_rate=1e-3,
    switch_var=1e-12,   # disabled: never switch to L-BFGS
    switch_slope=1e-12,
    adaptive_weights=False,
)

trainer_fwd_adam.train(epochs=5000, print_every=1000, plot_every=2500)

In [None]:
plot_solution(model_fwd_adam, data, alpha_true=0.01)

### 2.2 Adam + L-BFGS

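Adam warms up the parameters. Once the loss plateaus (the `switch_var` and `switch_slope` thresholds control when this is detected), the trainer switches to L-BFGS, a quasi-Newton method, for refinement.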
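The switching logic itself is internal to `StrategicPINNTrainer`. The sketch below is a hypothetical stand-in that assumes the plateau test checks the variance and slope of the recent log-loss. The helpers `plateau_reached` and `lbfgs_refine` and the 100-epoch window are illustrative and are not the trainer's API. The sketch also shows the closure pattern that `torch.optim.LBFGS` requires:

```python
import numpy as np
import torch

def plateau_reached(losses, window=100, var_tol=0.1, slope_tol=1e-3):
    """Hypothetical plateau test on the last `window` losses.

    Fits a line to log10(loss) and reports a plateau when both the variance
    and the absolute slope of the log-loss fall below the given tolerances.
    """
    if len(losses) < window:
        return False
    recent = np.log10(np.asarray(losses[-window:], dtype=float))
    slope = np.polyfit(np.arange(window), recent, deg=1)[0]
    return bool(recent.var() < var_tol and abs(slope) < slope_tol)


def lbfgs_refine(model, loss_fn, max_iter=500):
    """Refine `model` with L-BFGS; `loss_fn()` must recompute the full loss."""
    optimizer = torch.optim.LBFGS(
        model.parameters(), lr=1.0, max_iter=max_iter,
        history_size=50, line_search_fn="strong_wolfe",
    )

    def closure():
        # L-BFGS may evaluate the loss several times per step (line search),
        # so the forward and backward passes live in a closure.
        optimizer.zero_grad()
        loss = loss_fn()
        loss.backward()
        return loss

    optimizer.step(closure)
    return loss_fn().item()


# Toy usage: a stalled loss history triggers the switch, then L-BFGS fits a line.
history = [1e-3] * 100
if plateau_reached(history):
    torch.manual_seed(0)
    model = torch.nn.Linear(1, 1)
    x = torch.linspace(-1, 1, 50).unsqueeze(1)
    y = 2 * x + 1
    final = lbfgs_refine(model, lambda: torch.mean((model(x) - y) ** 2))
    print(f"loss after L-BFGS: {final:.2e}")
```

The strong-Wolfe line search makes L-BFGS much less sensitive to `lr`. Without it, each step uses the fixed `lr` and can overshoot.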

In [None]:
model_fwd_lbfgs = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    alpha_true=0.01,
    inverse=False
)

trainer_fwd_lbfgs = StrategicPINNTrainer(
    model=model_fwd_lbfgs,
    data=data,
    learning_rate=1e-3,
    switch_var=0.1,
    switch_slope=0.001,
    adaptive_weights=False,
)

trainer_fwd_lbfgs.train(epochs=5000, print_every=1000, plot_every=2500)

In [None]:
plot_solution(model_fwd_lbfgs, data, alpha_true=0.01)

### 2.3 L-BFGS advantage

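Train a fresh Adam-only model for exactly as many epochs as the Adam phase of the run above (i.e. up to the L-BFGS switch point). Then compare its error with that of the Adam + L-BFGS model. The gap isolates what the L-BFGS phase adds.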

In [None]:
# Find the epoch at which L-BFGS kicked in
lbfgs_switch = next(
    (i for i, o in enumerate(trainer_fwd_lbfgs.history['optimizer']) if o == 'lbfgs'),
    len(trainer_fwd_lbfgs.history['optimizer'])
)
print(f'L-BFGS switch epoch: {lbfgs_switch}')

model_fwd_adam_same = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    alpha_true=0.01,
    inverse=False
)

trainer_fwd_adam_same = StrategicPINNTrainer(
    model=model_fwd_adam_same,
    data=data,
    learning_rate=1e-3,  # same Adam learning rate as the L-BFGS run
    switch_var=1e-12,    # disabled: never switch to L-BFGS
    switch_slope=1e-12,
    adaptive_weights=False,
)

trainer_fwd_adam_same.train(epochs=lbfgs_switch, print_every=500, plot_every=10000)

In [None]:
from models.heat_pinn import analytical_solution

x_eval = torch.linspace(0, 1, 100).reshape(-1, 1)
t_eval = torch.linspace(0, 1, 100).reshape(-1, 1)
X, T = torch.meshgrid(x_eval.squeeze(), t_eval.squeeze(), indexing='ij')
x_flat = X.flatten().reshape(-1, 1)
t_flat = T.flatten().reshape(-1, 1)

u_exact = analytical_solution(x_flat.numpy(), t_flat.numpy(), alpha=0.01)
u_adam  = model_fwd_adam_same.predict(x_flat, t_flat)
u_lbfgs = model_fwd_lbfgs.predict(x_flat, t_flat)

rel_adam  = np.sqrt(np.mean((u_adam  - u_exact)**2)) / np.sqrt(np.mean(u_exact**2)) * 100
rel_lbfgs = np.sqrt(np.mean((u_lbfgs - u_exact)**2)) / np.sqrt(np.mean(u_exact**2)) * 100

print(f'Adam only  ({lbfgs_switch} epochs): relative L2 = {rel_adam:.4f}%')
print(f'Adam+LBFGS ({lbfgs_switch}+ epochs): relative L2 = {rel_lbfgs:.4f}%')
print(f'Improvement factor: {rel_adam / rel_lbfgs:.1f}x')
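For reference, the relative $L^2$ error printed above is the RMS error normalised by the RMS of the exact solution over the $N = 100 \times 100$ evaluation grid. The $1/N$ factors cancel, so it is the ratio of discrete 2-norms:

$$
\varepsilon_{L^2} = \frac{\sqrt{\frac{1}{N}\sum_{k=1}^{N}\big(u_k - u^{*}_k\big)^2}}{\sqrt{\frac{1}{N}\sum_{k=1}^{N}\big(u^{*}_k\big)^2}} \times 100\% = \frac{\|u - u^{*}\|_2}{\|u^{*}\|_2} \times 100\%,
$$

where $u_k$ is the PINN prediction and $u^{*}_k$ the analytical solution at grid point $k$. The reported improvement factor is the ratio $\varepsilon_{L^2}^{\text{Adam}} / \varepsilon_{L^2}^{\text{Adam+L-BFGS}}$.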

### 2.4 Loss gradient norms

Training loss spikes often coincide with the optimiser stepping into high-curvature regions of the parameter landscape. Tracking the per-loss gradient norms $\|\nabla_\theta \mathcal{L}_i\|_2$ shows which loss component dominates the update at each spike.
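Computing these norms costs one extra backward pass per loss term. The following is a minimal sketch using `torch.autograd.grad`. The helper name `per_loss_grad_norms` and the loss names are hypothetical; in a PINN the dictionary would hold the residual, BC and IC losses. The toy losses here are chosen so the norms can be checked by hand:

```python
import torch
import torch.nn as nn

def per_loss_grad_norms(model: nn.Module, losses: dict) -> dict:
    """Return {name: ||grad_theta L_name||_2} for each scalar loss in `losses`.

    retain_graph=True lets several losses that share one forward pass be
    differentiated separately; allow_unused=True tolerates parameters a loss
    does not depend on (their gradient is treated as zero).
    """
    params = [p for p in model.parameters() if p.requires_grad]
    norms = {}
    for name, loss in losses.items():
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        sq = sum(g.pow(2).sum() for g in grads if g is not None)
        norms[name] = float(sq) ** 0.5
    return norms


# Toy usage with hand-checkable gradients
model = nn.Linear(1, 1)
with torch.no_grad():
    model.weight.fill_(3.0)
    model.bias.fill_(0.0)
losses = {
    "a": (model.weight ** 2).sum(),  # d/dw = 2w = 6; does not depend on the bias
    "b": 2.0 * model.bias.sum(),     # d/db = 2; does not depend on the weight
}
norms = per_loss_grad_norms(model, losses)
print(norms)
```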

In [None]:
model_fwd_gn = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    alpha_true=0.01,
    inverse=False
)

trainer_fwd_gn = StrategicPINNTrainer(
    model=model_fwd_gn,
    data=data,
    learning_rate=1e-3,
    switch_var=1e-12,
    switch_slope=1e-12,
    track_gradient_norms=True,
    adaptive_weights=False,
)

trainer_fwd_gn.train(epochs=2000, print_every=1000, plot_every=1000)

### 2.5 Adaptive weights

Adaptive weighting rebalances the loss terms during training based on their gradient magnitudes, using a proxy for the learning-rate annealing scheme of Wang et al. (2021). It is most useful when the loss components differ by orders of magnitude.
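For orientation, the annealing rule of Wang et al. (2021) weights each non-residual term $\mathcal{L}_i$ by $\lambda_i$. It first computes $\hat\lambda_i = \max_\theta|\nabla_\theta\mathcal{L}_r| \,/\, \overline{|\nabla_\theta\mathcal{L}_i|}$ and then smooths it with a moving average, $\lambda_i \leftarrow \beta\lambda_i + (1-\beta)\hat\lambda_i$. The sketch below is one reading of that rule. Published implementations differ in details, such as whether $\lambda_i$ appears inside the denominator and how the smoothing factor is applied, and the trainer's own "proxy" may differ again. The function name `update_annealing_weights` and $\beta = 0.9$ are assumptions:

```python
import torch
import torch.nn as nn

def update_annealing_weights(model, loss_r, losses, weights, beta=0.9):
    """One learning-rate-annealing step in the spirit of Wang et al. (2021).

    lambda_hat_i = max|grad L_r| / mean|grad L_i| (mean over all gradient
    entries), followed by lambda_i <- beta * lambda_i + (1 - beta) * lambda_hat_i.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    g_r = torch.autograd.grad(loss_r, params, retain_graph=True, allow_unused=True)
    max_grad_r = max(g.abs().max().item() for g in g_r if g is not None)
    new_weights = {}
    for name, loss in losses.items():
        g_i = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        mean_grad_i = torch.cat([g.abs().flatten() for g in g_i if g is not None]).mean().item()
        lam_hat = max_grad_r / (mean_grad_i + 1e-12)  # guard against division by zero
        new_weights[name] = beta * weights[name] + (1.0 - beta) * lam_hat
    return new_weights


# Toy usage: residual gradient 4 (weight only); BC gradient 2 for each parameter.
model = nn.Linear(1, 1)
with torch.no_grad():
    model.weight.fill_(1.0)
    model.bias.fill_(0.0)
loss_r = 4.0 * model.weight.sum()
loss_bc = 2.0 * model.weight.sum() + 2.0 * model.bias.sum()
weights = update_annealing_weights(model, loss_r, {"bc": loss_bc}, {"bc": 1.0})
print(weights)  # lambda_hat = 4 / 2 = 2, so lambda_bc = 0.9 * 1 + 0.1 * 2
```

In a training loop the total loss would be $\mathcal{L}_r + \sum_i \lambda_i \mathcal{L}_i$. The weights are refreshed every few epochs rather than every step, and that frequency is another tunable choice.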

In [None]:
model_fwd_aw = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    alpha_true=0.01,
    inverse=False
)

trainer_fwd_aw = StrategicPINNTrainer(
    model=model_fwd_aw,
    data=data,
    learning_rate=1e-3,
    switch_var=0.1,
    switch_slope=0.001,
    adaptive_weights=True,
)

trainer_fwd_aw.train(epochs=5000, print_every=1000, plot_every=2500)

In [None]:
plot_solution(model_fwd_aw, data, alpha_true=0.01)

**Forward problem summary**

| Configuration | Notes |
|---|---|
| Adam only | Reasonable baseline but stagnates |
| Adam + L-BFGS | ~8× lower relative L2 error than Adam alone stopped at the switch epoch (see 2.3) |
| Adaptive weights | Small benefit for this smooth PDE |

## 3. Inverse Problem

Now $\alpha$ is unknown. We add noisy sensor measurements as an extra loss term and recover $\alpha$ jointly with $u$.
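In the inverse problem $\alpha$ becomes a trainable parameter inside the residual. The toy sketch below isolates that mechanism under two stated assumptions. First, $\alpha$ is stored as $\log\alpha$ so it stays positive; this is a common choice, but not necessarily what `StrategicPINN` does internally. Second, the network is replaced by the exact solution for $\alpha = 0.01$, so only $\alpha$ is learned. Starting from the notebook's initial guess of 0.02, minimising the mean squared residual should recover the true value:

```python
import math
import torch

# Hypothetical parameterisation: optimise log(alpha) so alpha stays positive.
log_alpha = torch.nn.Parameter(torch.tensor(math.log(0.02), dtype=torch.float64))

def alpha():
    return log_alpha.exp()

# Stand-in for the network: the exact solution with the true alpha = 0.01,
# so only alpha is unknown here (in the notebook, u is learned jointly).
def u_true(x, t):
    return torch.exp(-0.01 * math.pi**2 * t) * torch.sin(math.pi * x)

torch.manual_seed(0)
x = torch.rand(512, 1, dtype=torch.float64, requires_grad=True)
t = torch.rand(512, 1, dtype=torch.float64, requires_grad=True)
u = u_true(x, t)
u_t = torch.autograd.grad(u, t, torch.ones_like(u), create_graph=True)[0]
u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
u_t, u_xx = u_t.detach(), u_xx.detach()  # derivatives are fixed; only alpha is trained

optimizer = torch.optim.Adam([log_alpha], lr=1e-2)
for step in range(2000):
    optimizer.zero_grad()
    loss = torch.mean((u_t - alpha() * u_xx) ** 2)
    loss.backward()
    optimizer.step()
print(f"recovered alpha = {alpha().item():.6f}")
```

In the full problem, $u_\theta$ and $\alpha$ are updated together, and the measurement term is essential. Without it, every $\alpha$ has a matching $u$ that satisfies the PDE, IC and BCs exactly, so $\alpha$ would not be identifiable.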

### 3.1 Adam only

In [None]:
model_inv_adam = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    inverse=True,
    alpha_init=0.02
)

trainer_inv_adam = StrategicPINNTrainer(
    model=model_inv_adam,
    data=data,
    learning_rate=1e-3,
    switch_var=1e-12,
    switch_slope=1e-12,
    adaptive_weights=False,
)

trainer_inv_adam.train(epochs=5000, print_every=1000, plot_every=2500)

In [None]:
def eval_inverse(model, alpha_true=0.01):
    pred = model.get_alpha()
    err  = abs(pred - alpha_true) / alpha_true * 100
    print(f'True α: {alpha_true:.6f}  |  Predicted α: {pred:.6f}  |  Error: {err:.2f}%')
    status = 'SUCCESS' if err < 5 else 'needs longer training'
    print(f'[{status}]')

eval_inverse(model_inv_adam)
plot_solution(model_inv_adam, data, alpha_true=0.01)

### 3.2 Adam + L-BFGS

In [None]:
model_inv_lbfgs = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    inverse=True,
    alpha_init=0.02
)

trainer_inv_lbfgs = StrategicPINNTrainer(
    model=model_inv_lbfgs,
    data=data,
    learning_rate=1e-3,
    switch_var=0.1,
    switch_slope=0.001,
    adaptive_weights=False,
)

trainer_inv_lbfgs.train(epochs=5000, print_every=1000, plot_every=2500)

In [None]:
eval_inverse(model_inv_lbfgs)
plot_solution(model_inv_lbfgs, data, alpha_true=0.01)

### 3.3 Adaptive weights + L-BFGS

The extra measurement loss term makes the inverse problem harder to balance, so adaptive weights are expected to help more here than in the forward problem.

In [None]:
model_inv_aw = StrategicPINN(
    layers=[2, 50, 50, 50, 50, 1],
    inverse=True,
    alpha_init=0.02
)

trainer_inv_aw = StrategicPINNTrainer(
    model=model_inv_aw,
    data=data,
    learning_rate=1e-3,
    switch_var=0.1,
    switch_slope=0.001,
    adaptive_weights=True,
)

trainer_inv_aw.train(epochs=5000, print_every=1000, plot_every=2500)

In [None]:
eval_inverse(model_inv_aw)
plot_solution(model_inv_aw, data, alpha_true=0.01)

**Inverse problem summary**

| Configuration | Typical α error |
|---|---|
| Adam only | ~1–2% |
| Adam + L-BFGS | ~0.5% |
| Adaptive + L-BFGS | ~0.1% |

Adaptive weights are more impactful for the inverse problem than the forward problem.

## 4. Sensitivity Analysis

All experiments below use the best configuration found above: Adam + L-BFGS with adaptive weights.

### 4.1 Effect of measurement noise

⚠️ **Long-running cell** — trains 4 models × 5000 epochs each.

In [None]:
noise_levels = [0.001, 0.01, 0.05, 0.1]
results_noise = []

for noise in noise_levels:
    print(f'\nNoise level: {noise:.1%}')
    dg = HeatEquationData(
        L=1.0, T=1.0, alpha=0.01,          # match the true α used for the error below
        N_f=10000, N_bc=100, N_ic=200,
        N_sensors=10, N_time_measurements=10,
        noise_level=noise, random_seed=42
    )
    d = dg.generate_full_dataset()
    m = StrategicPINN(layers=[2, 50, 50, 50, 50, 1], inverse=True, alpha_init=0.02)
    trainer = StrategicPINNTrainer(
        m, d, learning_rate=1e-3, switch_var=0.1, switch_slope=0.001, adaptive_weights=True
    )
    trainer.train(epochs=5000, print_every=2500, plot_every=10000)
    results_noise.append({'noise': noise, 'error': abs(m.get_alpha() - 0.01) / 0.01 * 100})

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot([r['noise'] for r in results_noise], [r['error'] for r in results_noise],
        'bo-', markersize=10, linewidth=2)
ax.axhline(y=5, color='r', linestyle='--', label='5% threshold')
ax.set_xscale('log')
ax.set_xlabel('Noise level'); ax.set_ylabel('α error (%)')
ax.set_title('Parameter recovery vs measurement noise')
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout(); plt.show()

### 4.2 Sensitivity to initial guess

⚠️ **Long-running cell** — trains 4 models × 5000 epochs each.

In [None]:
alpha_guesses = [0.005, 0.05, 0.5, 5.0]
results_guess = []

for guess in alpha_guesses:
    print(f'\nInitial guess: {guess}')
    m = StrategicPINN(layers=[2, 50, 50, 50, 50, 1], inverse=True, alpha_init=guess)
    trainer = StrategicPINNTrainer(
        m, data, learning_rate=1e-3, switch_var=0.1, switch_slope=0.001, adaptive_weights=True
    )
    trainer.train(epochs=5000, print_every=2500, plot_every=10000)
    results_guess.append({'guess': guess, 'error': abs(m.get_alpha() - 0.01) / 0.01 * 100})

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot([r['guess'] for r in results_guess], [r['error'] for r in results_guess],
        'bo-', markersize=10, linewidth=2)
ax.axhline(y=5, color='r', linestyle='--', label='5% threshold')
ax.axvline(x=0.01, color='g', linestyle=':', label='True α')
ax.set_xscale('log')
ax.set_xlabel('Initial α guess'); ax.set_ylabel('α error (%)')
ax.set_title('Parameter recovery vs initial guess')
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout(); plt.show()

### 4.3 Data requirements

How many sensor measurements are needed for reliable parameter recovery?

⚠️ **Long-running cell** — trains 4 models × 5000 epochs each.

In [None]:
x_sensor_counts = [2, 5, 10, 20]
t_sensors = 5
results_data = []

for n_s in x_sensor_counts:
    n_total = n_s * t_sensors
    print(f'\nSensors: {n_s} × {t_sensors} = {n_total} measurements')
    dg = HeatEquationData(
        L=1.0, T=1.0, alpha=0.01,          # match the true α used for the error below
        N_f=10000, N_bc=100, N_ic=200,
        N_sensors=n_s, N_time_measurements=t_sensors,
        noise_level=0.01, random_seed=42
    )
    d = dg.generate_full_dataset()
    m = StrategicPINN(layers=[2, 50, 50, 50, 50, 1], inverse=True, alpha_init=0.02)
    trainer = StrategicPINNTrainer(
        m, d, learning_rate=1e-3, switch_var=0.1, switch_slope=0.001, adaptive_weights=True
    )
    trainer.train(epochs=5000, print_every=2500, plot_every=10000)
    results_data.append({'n': n_total, 'error': abs(m.get_alpha() - 0.01) / 0.01 * 100})

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot([r['n'] for r in results_data], [r['error'] for r in results_data],
        'bo-', markersize=10, linewidth=2)
ax.axhline(y=5, color='r', linestyle='--', label='5% threshold')
ax.set_xlabel('Number of measurements'); ax.set_ylabel('α error (%)')
ax.set_title('Parameter recovery vs data volume')
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout(); plt.show()

### 4.4 Uncertainty quantification

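Train an ensemble of 10 models with different random seeds to estimate the run-to-run variability in $\hat{\alpha}$. Because the dataset is fixed, this captures variability from network initialisation and training, not from the measurement-noise realisation.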

⚠️ **Long-running cell** — trains 10 models × 5000 epochs each.

In [None]:
alphas_ensemble = []

for seed in range(10):
    torch.manual_seed(seed)
    m = StrategicPINN(layers=[2, 50, 50, 50, 50, 1], inverse=True, alpha_init=0.02)
    trainer = StrategicPINNTrainer(
        m, data, learning_rate=1e-3, switch_var=0.1, switch_slope=0.001, adaptive_weights=True
    )
    trainer.train(epochs=5000, print_every=5000, plot_every=10000)
    alphas_ensemble.append(m.get_alpha())

alpha_mean = np.mean(alphas_ensemble)
alpha_std  = np.std(alphas_ensemble, ddof=1)   # sample standard deviation across seeds
print(f'\nEnsemble results (n={len(alphas_ensemble)}):')
print(f'  α = {alpha_mean:.6f} ± {alpha_std:.6f}')
print('  True α = 0.01')
print(f'  Mean error: {abs(alpha_mean - 0.01) / 0.01 * 100:.2f}%')
print(f'  ≈95% spread of single runs (mean ± 2σ): [{alpha_mean - 2*alpha_std:.6f}, {alpha_mean + 2*alpha_std:.6f}]')

fig, ax = plt.subplots(figsize=(8, 4))
ax.scatter(range(10), alphas_ensemble, zorder=3)
ax.axhline(y=0.01, color='r', linestyle='--', label='True α')
ax.axhline(y=alpha_mean, color='b', linestyle='-', label=f'Mean = {alpha_mean:.5f}')
ax.fill_between(range(10),
                alpha_mean - alpha_std, alpha_mean + alpha_std,
                alpha=0.2, color='b', label='±1 std')
ax.set_xlabel('Seed'); ax.set_ylabel('α')
ax.set_title('Ensemble uncertainty in α recovery')
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout(); plt.show()
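The mean ± 2σ interval printed above describes the spread of individual runs, not the uncertainty of the ensemble mean, which shrinks like $\sigma/\sqrt{n}$. The helper below reports both. The name `summarize_ensemble` is hypothetical, and the usage numbers are illustrative only, not results from this notebook:

```python
import numpy as np

def summarize_ensemble(estimates):
    """Summary statistics for an ensemble of parameter estimates.

    Returns the mean, the sample standard deviation (ddof=1), the standard
    error of the mean, and the mean +/- 2 std spread interval.
    """
    a = np.asarray(estimates, dtype=float)
    n = a.size
    mean = a.mean()
    std = a.std(ddof=1)          # unbiased-variance estimate for a small sample
    sem = std / np.sqrt(n)       # uncertainty of the ensemble *mean*
    spread = (mean - 2 * std, mean + 2 * std)  # ~95% of single runs, if roughly normal
    return {"n": n, "mean": mean, "std": std, "sem": sem, "spread_2sd": spread}


# Usage with illustrative numbers (not results from this notebook)
summary = summarize_ensemble([1.0, 2.0, 3.0, 4.0, 5.0])
print(summary)
```

With $n = 10$, a t-based 95% interval for the mean would use a multiplier of about 2.26 on the standard error rather than 2.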