# 05 — Side-by-Side Comparison

Loads saved metrics from notebooks 02–04 and produces:
1. Overlay plot of both models' predictions vs. ground truth  
2. Quantitative metrics table (MSE, MAE, RMSE, max error) on train and test sets  
3. Training dynamics comparison  
4. Sparse-data ablation summary (shown only if notebook 04's results exist)

**Run notebooks 02 and 03 first; also run 04 to include the sparse-data ablation summary.**

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os

# ==================== LOAD EVERYTHING ====================
import warnings


def _load_npz(path):
    """Return every array stored in the ``.npz`` archive at *path* as a dict.

    ``np.load`` on an ``.npz`` hands back a lazy ``NpzFile`` that keeps the
    file handle open until it is closed; reading all arrays inside a ``with``
    block releases the handle immediately. Callers index the result exactly
    like the ``NpzFile`` (``archive['key']``).
    """
    with np.load(path) as archive:
        return {key: archive[key] for key in archive.files}


data = _load_npz('data/synthetic_m.npz')
t = data['t']
V = data['V']
# Flatten so that a column-vector (N, 1) array saved by any notebook cannot
# broadcast against a 1-D array into an (N, N) matrix in the error terms below.
m_true = np.ravel(data['m_true'])

dd = _load_npz('outputs/02_metrics.npz')
pinn = _load_npz('outputs/03_metrics.npz')

dd_pred   = np.ravel(dd['m_pred_full'])
pinn_pred = np.ravel(pinn['m_pred_full'])
# Every later cell compares predictions point-by-point with m_true.
for nb, pred in (('02', dd_pred), ('03', pinn_pred)):
    if pred.shape != m_true.shape:
        raise ValueError(
            f'outputs/{nb}_metrics.npz: m_pred_full has {pred.size} points, '
            f'but m_true has {m_true.size}.'
        )

train_idx = dd['train_idx']
test_idx  = dd['test_idx']
# Both models are scored on notebook 02's split. If notebook 03 saved a split
# of its own and it differs, PINN "test" points may have been trained on and
# the comparison below would be biased.
for key, idx in (('train_idx', train_idx), ('test_idx', test_idx)):
    if key in pinn and not np.array_equal(pinn[key], idx):
        warnings.warn(
            f'{key} in outputs/03_metrics.npz differs from outputs/02_metrics.npz; '
            'train/test metrics below are not directly comparable.'
        )

print('All data loaded successfully.')

In [None]:
# ==================== QUANTITATIVE METRICS ====================
def compute_metrics(y_true, y_pred, idx=None):
    """Return pointwise error statistics of *y_pred* against *y_true*.

    Parameters
    ----------
    y_true, y_pred : array_like
        Ground truth and prediction over the same points. Both are flattened
        after indexing, so a 1-D array and an (N, 1) column are compared
        element-wise instead of silently broadcasting to an (N, N) matrix.
    idx : array_like or None, optional
        Integer indices (or boolean mask) selecting the evaluated points.
        ``None`` evaluates every point.

    Returns
    -------
    dict
        Keys ``'MSE'``, ``'MAE'``, ``'Max Error'`` and ``'RMSE'``. All values
        are NaN when *idx* selects no points.

    Raises
    ------
    ValueError
        If the selected parts of *y_true* and *y_pred* differ in size.
    """
    truth = np.asarray(y_true)
    pred = np.asarray(y_pred)
    if idx is not None:
        truth = truth[idx]
        pred = pred[idx]
    truth = np.ravel(truth)
    pred = np.ravel(pred)
    if truth.shape != pred.shape:
        raise ValueError(
            f'y_true selects {truth.size} points but y_pred selects {pred.size}.'
        )

    # An empty selection has no errors to summarise; np.max would raise here.
    if truth.size == 0:
        nan = float('nan')
        return {'MSE': nan, 'MAE': nan, 'Max Error': nan, 'RMSE': nan}

    abs_err = np.abs(truth - pred)
    mse = np.mean(abs_err ** 2)
    return {
        'MSE': mse,
        'MAE': np.mean(abs_err),
        'Max Error': np.max(abs_err),
        'RMSE': np.sqrt(mse),
    }

# Score each model on the shared split. Evaluation order is DD train,
# DD test, PINN train, PINN test.
dd_train, dd_test = (compute_metrics(m_true, dd_pred, split) for split in (train_idx, test_idx))
pinn_train, pinn_test = (compute_metrics(m_true, pinn_pred, split) for split in (train_idx, test_idx))

# Table layout: one row per metric, one column per (model, split) pair.
print(f'{"":>14s}  {"DD Train":>12s}  {"DD Test":>12s}  {"PINN Train":>12s}  {"PINN Test":>12s}')
print('-' * 70)
for metric in ('MSE', 'MAE', 'RMSE', 'Max Error'):
    print(f'{metric:>14s}  {dd_train[metric]:>12.6f}  {dd_test[metric]:>12.6f}  {pinn_train[metric]:>12.6f}  {pinn_test[metric]:>12.6f}')

# Relative change in test MSE (positive = PINN lower). It is only reported
# when the data-driven baseline is strictly positive.
baseline = dd_test['MSE']
if baseline > 0:
    pct = (baseline - pinn_test['MSE']) / baseline * 100
    print(f'\nPINN test MSE improvement over data-driven: {pct:+.1f}%')

In [None]:
# ==================== OVERLAY PREDICTION PLOT ====================
fig, axs = plt.subplots(3, 1, figsize=(12, 10))
ax_volt, ax_pred, ax_err = axs

# Top panel: the applied voltage protocol.
ax_volt.plot(t, V, 'b-', lw=1.5)
ax_volt.set_ylabel('Voltage (mV)', fontsize=12)
ax_volt.set_title('Voltage Step Protocol', fontsize=13)
ax_volt.grid(True, alpha=0.3)

# Middle and bottom panels get one trace per model:
# (prediction, colour, prediction-label, error-label).
model_traces = (
    (dd_pred, '#e74c3c', 'Data-Driven (ReLU)', 'Data-Driven |error|'),
    (pinn_pred, '#2ecc71', 'PINN (Tanh + ODE)', 'PINN |error|'),
)

# Ground truth is drawn first so each model's dashed line sits on top of it.
ax_pred.plot(t, m_true, 'k-', label='Ground Truth', lw=2.5, alpha=0.8)
for pred, colour, pred_label, err_label in model_traces:
    ax_pred.plot(t, pred, '--', color=colour, label=pred_label, lw=2)
    ax_err.plot(t, np.abs(m_true - pred), color=colour, label=err_label, lw=1.5, alpha=0.8)
# Held-out points are marked on the truth curve.
ax_pred.scatter(t[test_idx], m_true[test_idx], c='orange', s=6, zorder=5, alpha=0.5, label='Test points')

ax_pred.set_ylabel('m(t)', fontsize=12)
ax_pred.set_title('Prediction Comparison', fontsize=13)
ax_pred.legend(fontsize=11)
ax_pred.grid(True, alpha=0.3)

ax_err.set_xlabel('Time (ms)', fontsize=12)
ax_err.set_ylabel('Absolute Error', fontsize=12)
ax_err.set_title('Pointwise Error', fontsize=13)
ax_err.legend(fontsize=11)
ax_err.grid(True, alpha=0.3)

plt.suptitle('PINN vs. Data-Driven: Full Comparison on HH m-gate', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.savefig('outputs/05_comparison_overlay.png', dpi=200)
plt.show()
print('Saved to outputs/05_comparison_overlay.png')

In [None]:
# ==================== TRAINING DYNAMICS ====================
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
ax_curves, ax_parts = axs

# Left panel: loss curves for both models. Colour identifies the model.
# Solid lines are training loss (for the PINN, its total loss); dashed
# lines are test loss.
held_out_style = {'linestyle': '--', 'alpha': 0.7}
for series, label, colour, extra in (
    (dd['train_losses'], 'DD Train', '#e74c3c', {}),
    (dd['test_losses'], 'DD Test', '#e74c3c', held_out_style),
    (pinn['train_losses'], 'PINN Total', '#2ecc71', {}),
    (pinn['test_losses'], 'PINN Test', '#2ecc71', held_out_style),
):
    ax_curves.plot(series, label=label, color=colour, **extra)

# Right panel: the PINN's data and physics loss terms next to its total.
ax_parts.plot(pinn['data_losses'], label='Data Loss', color='#3498db')
ax_parts.plot(pinn['phys_losses'], label='Physics Loss', color='#9b59b6')
ax_parts.plot(pinn['train_losses'], label='Total Loss', color='#2ecc71', lw=2)

# Both panels use the same epoch/loss axes on a logarithmic y-scale.
for ax, title in ((ax_curves, 'Training Curves'), (ax_parts, 'PINN Loss Decomposition')):
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs/05_training_dynamics.png', dpi=200)
plt.show()
print('Saved to outputs/05_training_dynamics.png')

In [None]:
# ==================== ABLATION SUMMARY (if notebook 04 was run) ====================
ablation_path = 'outputs/04_ablation_results.npz'
if os.path.exists(ablation_path):
    # Copy the arrays out and close the archive. A bare np.load(...) on an
    # .npz keeps the file handle open until the NpzFile is garbage-collected.
    with np.load(ablation_path) as archive:
        ab = {key: archive[key] for key in archive.files}
    sizes = ab['train_sizes']
    # Training-set sizes are counts. Render them as integers even if they
    # were saved in a float array; int formatting would reject floats.
    size_labels = [str(int(n)) for n in sizes]

    fig, ax = plt.subplots(figsize=(9, 6))
    ax.errorbar(sizes, ab['dd_means'], yerr=ab['dd_stds'],
                marker='o', capsize=5, lw=2, ms=8,
                label='Data-Driven (ReLU)', color='#e74c3c')
    ax.errorbar(sizes, ab['pinn_means'], yerr=ab['pinn_stds'],
                marker='s', capsize=5, lw=2, ms=8,
                label='PINN (Tanh + ODE)', color='#2ecc71')

    ax.set_xlabel('Training Set Size', fontsize=13)
    ax.set_ylabel('Test MSE', fontsize=13)
    ax.set_title('Sparse-Data Ablation: Generalization vs. Data Quantity', fontsize=14)
    ax.legend(fontsize=12)
    ax.set_xscale('log'); ax.set_yscale('log')
    ax.grid(True, which='both', alpha=0.3)
    ax.set_xticks(sizes)
    ax.set_xticklabels(size_labels)

    plt.tight_layout()
    plt.savefig('outputs/05_ablation_summary.png', dpi=200)
    plt.show()

    # Relative test-MSE change at each size (positive = PINN lower).
    print(f'{"N_train":>8}  {"DD MSE":>10}  {"PINN MSE":>10}  {"PINN advantage":>14}')
    print('-' * 48)
    for n, dd_mse, pinn_mse in zip(sizes, ab['dd_means'], ab['pinn_means']):
        # Same guard as the full-data comparison: a non-positive baseline has
        # no meaningful percentage (and zero would divide by zero).
        if dd_mse > 0:
            advantage = f'{(dd_mse - pinn_mse) / dd_mse * 100:>+12.1f}%'
        else:
            advantage = f'{"n/a":>13s}'
        print(f'{int(n):>8d}  {dd_mse:>10.6f}  {pinn_mse:>10.6f}  {advantage}')
else:
    print('Notebook 04 results not found — run 04_sparse_ablation.ipynb first.')

## Summary

Key findings to reference in the paper:

1. **Full-data regime:** Both models fit the HH m-gate well with 800 training points. The PINN's advantage is modest here.  
2. **Sparse-data regime:** As training data decreases, the data-driven model degrades significantly while the PINN remains accurate — the physics constraint acts as a powerful regularizer.  
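3. **Activation functions matter:** Tanh (smooth, infinitely differentiable) is well suited to PINNs because the physics loss is built from derivatives of the network output with respect to its input. A ReLU network's first derivative is piecewise constant (and its second derivative is zero almost everywhere), so the derivative terms in the ODE residual are poorly approximated and give a weak, noisy training signal. Note that this comparison changes the activation and the loss together, so it does not by itself isolate the activation's effect.  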
4. **Collocation needs no labels:** The PINN evaluates the physics residual at all time points, whether or not they are labeled. The ODE therefore supplies extra training signal without requiring more labeled data; the only cost is extra computation.