In [None]:
# === Setup ===
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import sys
import time

sys.path.insert(0, str(Path.cwd().parent.parent.parent))

from modules._import_helper import safe_import_from

set_seed = safe_import_from('00_repo_standards.src.mlphys_core', 'set_seed')
(ConservationConfig, ConservationConstrainedNN,
 generate_pendulum_data, compute_energy_violation) = safe_import_from(
    '07_physics_informed_ml.src.constrained_learning',
    'ConservationConfig', 'ConservationConstrainedNN',
    'generate_pendulum_data', 'compute_energy_violation'
)

# --- Plot styling: set once so every figure in the notebook shares it ---
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'font.size': 11,
    'axes.grid': True,
    'grid.alpha': 0.3,
})

# All figures and the markdown summary are written to a 'reports' directory
# that sits beside the current working directory.
reports_dir = Path.cwd().parent / 'reports'
reports_dir.mkdir(parents=True, exist_ok=True)

# --- Reproducibility ---
SEED = 42
set_seed(SEED)            # project helper; presumably seeds Python/NumPy RNGs -- confirm
torch.manual_seed(SEED)   # torch is seeded explicitly as well

print("✓ Setup complete")

---
## 1. The Physics Problem

### Simple Harmonic Oscillator (Conservative System)

$$\ddot{x} + \omega^2 x = 0$$

**State**: $(x, v)$ where $v = \dot{x}$

**Discrete dynamics** (Euler integration):
$$x_{n+1} = x_n + v_n \Delta t$$
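$$v_{n+1} = v_n - \omega^2 x_n \Delta t$$

Note that this explicit Euler map is not itself energy-conserving: substituting it into $E = \frac{1}{2}(v^2 + \omega^2 x^2)$ gives $E_{n+1} = (1 + \omega^2 \Delta t^2)\,E_n$. The energy constraint used below therefore encodes the continuous-time physics, not an exact property of the discrete update.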

### Energy Conservation

$$E = \frac{1}{2}m v^2 + \frac{1}{2}k x^2 = \text{constant}$$

With $m=1$, $k=\omega^2$:
$$E = \frac{1}{2}(v^2 + \omega^2 x^2)$$

### Learning Task

**Input**: Current state $(x_n, v_n)$  
**Output**: Next state $(x_{n+1}, v_{n+1})$  
**Constraint**: $E(x_{n+1}, v_{n+1}) = E(x_n, v_n)$

### Why This Matters

Without energy conservation, a learned model will:
- Accumulate errors over time
- Eventually predict unphysical states (spiral in/out in phase space)
- Fail catastrophically for long rollouts

In [None]:
# === Generate synthetic data ===
omega = 1.0
dt = 0.1
noise_std = 0.02  # Observation noise


def _oscillator_energy(states, omega):
    """Return E = 0.5 * (v^2 + omega^2 * x^2) for an array whose rows are (x, v)."""
    return 0.5 * (states[:, 1]**2 + omega**2 * states[:, 0]**2)


# Training data (noisy observations)
X_train, y_train = generate_pendulum_data(
    n=2000, omega=omega, dt=dt, noise_std=noise_std, seed=SEED
)

# Test data: a different seed gives new initial conditions, and noise is
# switched off so evaluation is against clean targets.
X_test, y_test = generate_pendulum_data(
    n=500, omega=omega, dt=dt, noise_std=0.0, seed=SEED+1
)

print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")
print(f"Noise std: {noise_std}")
print(f"Time step: {dt}")

# Energy of each training input and of its target; their difference is the
# per-sample energy change present in the training pairs.
E_train_in = _oscillator_energy(X_train, omega)
E_train_out = _oscillator_energy(y_train, omega)
delta_E_train = E_train_out - E_train_in

# Visualize data distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left: where the training inputs sit in phase space
ax = axes[0]
ax.scatter(X_train[:, 0], X_train[:, 1], c='b', alpha=0.3, s=10)
ax.set_xlabel('Position x')
ax.set_ylabel('Velocity v')
ax.set_title('Training Data (Current State)')
ax.axis('equal')

# Right: distribution of the energy change between input and target
ax = axes[1]
ax.hist(delta_E_train, bins=50, alpha=0.7, edgecolor='black')
ax.axvline(0, color='r', linestyle='--', lw=2)  # reference: perfect conservation
ax.set_xlabel('ΔE (output - input)')
ax.set_ylabel('Count')
ax.set_title(f'Energy Change Due to Noise (std={np.std(delta_E_train):.4f})')

plt.tight_layout()
plt.savefig(reports_dir / '03_data_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 2. Conservation-Constrained Learning

### Loss Function

$$\mathcal{L} = \mathcal{L}_{\text{MSE}} + \lambda \mathcal{L}_{\text{conservation}}$$

where:
$$\mathcal{L}_{\text{MSE}} = \frac{1}{N} \sum_i \| \hat{y}_i - y_i \|^2$$

$$\mathcal{L}_{\text{conservation}} = \frac{1}{N} \sum_i (E(\hat{y}_i) - E(x_i))^2$$

### The $\lambda$ Trade-off

- $\lambda = 0$: Pure data fitting (unconstrained)
- $\lambda \to \infty$: Pure energy conservation (ignores data)
- Optimal $\lambda$: Balances both objectives

In [None]:
# === Train unconstrained model ===
# Re-seed so this model and the constrained one start from the same RNG state.
set_seed(SEED)

config_unconstrained = ConservationConfig(
    input_dim=2,
    output_dim=2,
    hidden_dims=[64, 64],
    epochs=2000,
    lr=1e-3,
    batch_size=64,
    lambda_conservation=0.0,  # No constraint: the loss is the data term only
    conservation_type="energy",
)

model_unconstrained = ConservationConstrainedNN(config_unconstrained)
print("Training UNCONSTRAINED model (λ=0)...")
# NOTE(review): the test split is passed as the validation set here. Presumably
# it is only used for monitoring -- confirm train() does no model selection on it.
history_unconstrained = model_unconstrained.train(
    X_train, y_train, X_test, y_test, verbose=500
)
print("✓ Done")

In [None]:
# === Train constrained model ===
# Re-seed so this model and the unconstrained one start from the same RNG state.
set_seed(SEED)

config_constrained = ConservationConfig(
    input_dim=2,
    output_dim=2,
    hidden_dims=[64, 64],
    epochs=2000,
    lr=1e-3,
    batch_size=64,
    lambda_conservation=10.0,  # Strong energy constraint
    conservation_type="energy",
)

model_constrained = ConservationConstrainedNN(config_constrained)
print("Training CONSTRAINED model (λ=10)...")
# NOTE(review): the test split is passed as the validation set here. Presumably
# it is only used for monitoring -- confirm train() does no model selection on it.
history_constrained = model_constrained.train(
    X_train, y_train, X_test, y_test, verbose=500
)
print("✓ Done")

In [None]:
# === Compare on test set ===
y_pred_unconstrained = model_unconstrained.predict(X_test)
y_pred_constrained = model_constrained.predict(X_test)

# MSE against the clean test targets
mse_unconstrained = np.mean((y_pred_unconstrained - y_test)**2)
mse_constrained = np.mean((y_pred_constrained - y_test)**2)

# Per-sample energy violation between each input state and its predicted next state
E_violation_unconstrained = compute_energy_violation(X_test, y_pred_unconstrained, spring_k=omega**2)
E_violation_constrained = compute_energy_violation(X_test, y_pred_constrained, spring_k=omega**2)

# Column widths: 30 + 1 + 14 + 1 + 14 = 60, matching the 60-character rules.
# A value column must be at least 13 wide or the 'Unconstrained' header
# overflows and shifts the header row relative to the numbers.
comparison_rows = [
    ('MSE', mse_unconstrained, mse_constrained),
    ('Mean |ΔE|', np.mean(E_violation_unconstrained), np.mean(E_violation_constrained)),
    ('Max |ΔE|', np.max(E_violation_unconstrained), np.max(E_violation_constrained)),
]

print("="*60)
print("TEST SET COMPARISON")
print("="*60)
print(f"{'Metric':<30} {'Unconstrained':>14} {'Constrained':>14}")
print("-"*60)
for metric_name, value_unconstrained, value_constrained in comparison_rows:
    print(f"{metric_name:<30} {value_unconstrained:>14.6f} {value_constrained:>14.6f}")
print("="*60)

In [None]:
# === Visualization ===
# 2x3 panel: top row shows training histories, bottom row shows test-set behaviour.
fig, axes = plt.subplots(2, 3, figsize=(15, 9))

# Training curves (total loss)
ax = axes[0, 0]
ax.semilogy(history_unconstrained['loss'], 'b-', lw=2, label='Unconstrained')
ax.semilogy(history_constrained['loss'], 'r-', lw=2, label='Constrained')
ax.set_xlabel('Epoch')
ax.set_ylabel('Total Loss')
ax.set_title('Training Loss')
ax.legend()

# MSE only
ax = axes[0, 1]
ax.semilogy(history_unconstrained['loss_mse'], 'b-', lw=2, label='Unconstrained')
ax.semilogy(history_constrained['loss_mse'], 'r-', lw=2, label='Constrained')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title('Data Fitting Loss')
ax.legend()

# Conservation loss (only shown for the constrained model)
ax = axes[0, 2]
ax.semilogy(history_constrained['loss_conservation'], 'r-', lw=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Conservation Loss')
ax.set_title('Energy Conservation Loss (Constrained)')

# Energy violation histogram.
# Both models share one set of bin edges; otherwise each histogram gets its own
# bin width and the overlaid counts are not comparable.
ax = axes[1, 0]
violation_bins = np.histogram_bin_edges(
    np.concatenate([np.ravel(E_violation_unconstrained),
                    np.ravel(E_violation_constrained)]),
    bins=30,
)
ax.hist(E_violation_unconstrained, bins=violation_bins, alpha=0.7, label='Unconstrained', color='b')
ax.hist(E_violation_constrained, bins=violation_bins, alpha=0.7, label='Constrained', color='r')
ax.set_xlabel('|ΔE|')
ax.set_ylabel('Count')
ax.set_title('Energy Violation Distribution (Test)')
ax.legend()

# Phase space predictions.
# A locally seeded generator keeps the subsample identical on every re-run of
# this cell, and the size is capped so a small test set does not raise.
ax = axes[1, 1]
n_phase_samples = min(100, len(X_test))
idx_sample = np.random.default_rng(SEED).choice(
    len(X_test), size=n_phase_samples, replace=False
)
ax.scatter(y_test[idx_sample, 0], y_test[idx_sample, 1],
           c='k', s=30, alpha=0.5, label='True')
ax.scatter(y_pred_unconstrained[idx_sample, 0], y_pred_unconstrained[idx_sample, 1],
           c='b', s=30, alpha=0.5, marker='x', label='Unconstrained')
ax.scatter(y_pred_constrained[idx_sample, 0], y_pred_constrained[idx_sample, 1],
           c='r', s=30, alpha=0.5, marker='+', label='Constrained')
ax.set_xlabel('x_next')
ax.set_ylabel('v_next')
ax.set_title('Prediction Samples in Phase Space')
ax.legend()
ax.axis('equal')

# Prediction error vs energy violation (one point per test sample)
ax = axes[1, 2]
pred_error_unconstrained = np.linalg.norm(y_pred_unconstrained - y_test, axis=1)
pred_error_constrained = np.linalg.norm(y_pred_constrained - y_test, axis=1)
ax.scatter(E_violation_unconstrained, pred_error_unconstrained,
           c='b', alpha=0.3, s=10, label='Unconstrained')
ax.scatter(E_violation_constrained, pred_error_constrained,
           c='r', alpha=0.3, s=10, label='Constrained')
ax.set_xlabel('|ΔE| (Energy Violation)')
ax.set_ylabel('Prediction Error')
ax.set_title('Trade-off: Error vs Energy Violation')
ax.legend()

plt.tight_layout()
plt.savefig(reports_dir / '03_constrained_vs_unconstrained.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 3. Experiment: The $\lambda$ Sweep

**Question**: What is the optimal trade-off between data fit and physics constraint?

In [None]:
# === Lambda sweep ===
# Retrain the same architecture for each penalty weight and record the
# single-step test MSE and energy violation.
lambda_values = [0.0, 0.1, 1.0, 5.0, 10.0, 50.0, 100.0]
results_lambda = []

# λ=0 cannot be placed on a log axis; it is drawn at this floor value instead.
LAMBDA_PLOT_FLOOR = 0.01

print("Experiment: Effect of λ_conservation")
print("="*70)

for lam in lambda_values:
    set_seed(SEED)  # identical RNG state for every λ so only the penalty differs

    config_test = ConservationConfig(
        input_dim=2, output_dim=2,
        hidden_dims=[64, 64],
        epochs=2000,
        lr=1e-3,
        batch_size=64,
        lambda_conservation=lam,
        # Stated explicitly so the sweep uses the same constraint as the two
        # reference models trained above, rather than relying on the default.
        conservation_type="energy",
    )

    model_test = ConservationConstrainedNN(config_test)
    _ = model_test.train(X_train, y_train, verbose=0)

    y_pred = model_test.predict(X_test)
    mse = np.mean((y_pred - y_test)**2)
    E_viol = compute_energy_violation(X_test, y_pred, spring_k=omega**2)
    mean_E_viol = np.mean(E_viol)
    max_E_viol = np.max(E_viol)

    results_lambda.append({
        'lambda': lam,
        'mse': mse,
        'mean_E_viol': mean_E_viol,
        'max_E_viol': max_E_viol,
    })

    print(f"λ={lam:6.1f} | MSE={mse:.6f} | Mean|ΔE|={mean_E_viol:.6f} | Max|ΔE|={max_E_viol:.6f}")

# Plot Pareto front
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

mses = [r['mse'] for r in results_lambda]
E_viols = [r['mean_E_viol'] for r in results_lambda]
lambdas = [r['lambda'] for r in results_lambda]

# Pareto front: each point is one λ, joined in order of increasing λ
ax = axes[0]
scatter = ax.scatter(mses, E_viols, c=lambdas, cmap='viridis', s=100, edgecolor='black')
ax.plot(mses, E_viols, 'k--', alpha=0.5)
for lam, mse_value, viol_value in zip(lambdas, mses, E_viols):
    ax.annotate(f'λ={lam}', (mse_value, viol_value), fontsize=9,
                xytext=(5, 5), textcoords='offset points')
plt.colorbar(scatter, ax=ax, label='λ')
ax.set_xlabel('MSE (Data Fit)')
ax.set_ylabel('Mean |ΔE| (Energy Violation)')
ax.set_title('Trade-off Curve (Pareto Front)')

# Lambda vs metrics on twin y-axes (MSE left, energy violation right)
ax = axes[1]
lambdas_for_log_axis = [max(lam, LAMBDA_PLOT_FLOOR) for lam in lambdas]
ax.semilogx(lambdas_for_log_axis, mses, 'bo-', lw=2, ms=8, label='MSE')
ax2 = ax.twinx()
ax2.semilogx(lambdas_for_log_axis, E_viols, 'rs-', lw=2, ms=8, label='Mean|ΔE|')
ax.set_xlabel('λ_conservation')
ax.set_ylabel('MSE', color='b')
ax2.set_ylabel('Mean |ΔE|', color='r')
ax.set_title('MSE and Energy Violation vs λ')
ax.legend(loc='upper left')
ax2.legend(loc='upper right')

plt.tight_layout()
plt.savefig(reports_dir / '03_lambda_sweep.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 4. Long-Horizon Rollout Test

**Key test**: Does the constraint help with multi-step predictions?

In [None]:
# === Long rollout comparison ===
def rollout(model, x0, v0, n_steps, omega=1.0, dt=0.1):
    """Roll a one-step model forward by feeding its output back in as input.

    Args:
        model: object with a ``predict`` method; it is called with a (1, 2)
            array ``[[x, v]]`` and the first row of its result is read as the
            next ``(x, v)``.
        x0: initial position.
        v0: initial velocity.
        n_steps: number of times the model is applied.
        omega: accepted but not used by this function.
        dt: accepted but not used by this function.

    Returns:
        Array with ``n_steps + 1`` rows of ``(x, v)``; row 0 is the initial state.
    """
    states = [(x0, v0)]

    for _ in range(n_steps):
        position, velocity = states[-1]
        prediction = model.predict(np.array([[position, velocity]]))[0]
        states.append((prediction[0], prediction[1]))

    return np.array(states)

def true_rollout(x0, v0, n_steps, omega=1.0, dt=0.1):
    """Evaluate the closed-form harmonic-oscillator solution on a uniform time grid.

    Args:
        x0: initial position.
        v0: initial velocity.
        n_steps: number of steps; the grid has ``n_steps + 1`` points starting at t=0.
        omega: angular frequency (used as a divisor, so it must be non-zero).
        dt: spacing of the time grid.

    Returns:
        Array of shape ``(n_steps + 1, 2)`` with columns ``(x, v)``.
    """
    times = np.arange(n_steps + 1) * dt
    cos_phase = np.cos(omega * times)
    sin_phase = np.sin(omega * times)

    positions = x0 * cos_phase + (v0 / omega) * sin_phase
    velocities = -x0 * omega * sin_phase + v0 * cos_phase
    return np.column_stack((positions, velocities))

# Initial condition
x0, v0 = 1.0, 0.0
n_steps = 200  # 20 seconds with dt=0.1

# Rollouts: analytical reference vs. each learned model applied iteratively
traj_true = true_rollout(x0, v0, n_steps, omega, dt)
traj_unconstrained = rollout(model_unconstrained, x0, v0, n_steps, omega, dt)
traj_constrained = rollout(model_constrained, x0, v0, n_steps, omega, dt)

# Compute energies E = 0.5 * (v^2 + omega^2 * x^2) along each trajectory
E_true = 0.5 * (traj_true[:, 1]**2 + omega**2 * traj_true[:, 0]**2)
E_unconstrained = 0.5 * (traj_unconstrained[:, 1]**2 + omega**2 * traj_unconstrained[:, 0]**2)
E_constrained = 0.5 * (traj_constrained[:, 1]**2 + omega**2 * traj_constrained[:, 0]**2)

t_rollout = np.arange(n_steps + 1) * dt

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Position over time
ax = axes[0, 0]
ax.plot(t_rollout, traj_true[:, 0], 'k-', lw=2, label='True')
ax.plot(t_rollout, traj_unconstrained[:, 0], 'b--', lw=2, label='Unconstrained')
ax.plot(t_rollout, traj_constrained[:, 0], 'r:', lw=2.5, label='Constrained')
ax.set_xlabel('Time')
ax.set_ylabel('Position x')
ax.set_title('Long Horizon Rollout (Position)')
ax.legend()

# Phase space
ax = axes[0, 1]
ax.plot(traj_true[:, 0], traj_true[:, 1], 'k-', lw=2, label='True')
ax.plot(traj_unconstrained[:, 0], traj_unconstrained[:, 1], 'b--', lw=2, label='Unconstrained')
ax.plot(traj_constrained[:, 0], traj_constrained[:, 1], 'r:', lw=2.5, label='Constrained')
ax.plot(x0, v0, 'go', ms=10)  # starting point
ax.set_xlabel('Position x')
ax.set_ylabel('Velocity v')
ax.set_title('Phase Space')
ax.legend()
ax.axis('equal')

# Energy
ax = axes[1, 0]
ax.plot(t_rollout, E_true, 'k-', lw=2, label='True')
ax.plot(t_rollout, E_unconstrained, 'b--', lw=2, label='Unconstrained')
ax.plot(t_rollout, E_constrained, 'r:', lw=2.5, label='Constrained')
ax.axhline(E_true[0], color='gray', linestyle=':', alpha=0.5)  # initial energy
ax.set_xlabel('Time')
ax.set_ylabel('Energy E')
ax.set_title('Energy Over Time')
ax.legend()

# Trajectory error at each time step (distance to the analytical state).
# Every trajectory starts at (x0, v0), so the error at t=0 is exactly zero and
# cannot be drawn on a log axis; the curves therefore start from step 1.
ax = axes[1, 1]
error_unconstrained = np.linalg.norm(traj_unconstrained - traj_true, axis=1)
error_constrained = np.linalg.norm(traj_constrained - traj_true, axis=1)
ax.semilogy(t_rollout[1:], error_unconstrained[1:], 'b-', lw=2, label='Unconstrained')
ax.semilogy(t_rollout[1:], error_constrained[1:], 'r-', lw=2, label='Constrained')
ax.set_xlabel('Time')
ax.set_ylabel('Trajectory Error')
ax.set_title('Error Accumulation Over Rollout')
ax.legend()

plt.tight_layout()
plt.savefig(reports_dir / '03_long_horizon_rollout.png', dpi=150, bbox_inches='tight')
plt.show()

# Summary table. Column widths: 30 + 1 + 14 + 1 + 14 = 60, matching the rules;
# a 12-wide column is too narrow for the 13-character 'Unconstrained' header.
print("="*60)
print("LONG HORIZON ROLLOUT SUMMARY")
print("="*60)
print(f"{'Metric':<30} {'Unconstrained':>14} {'Constrained':>14}")
print("-"*60)
print(f"{'Final trajectory error':<30} {error_unconstrained[-1]:>14.4f} {error_constrained[-1]:>14.4f}")
print(f"{'Final energy deviation':<30} {np.abs(E_unconstrained[-1] - E_true[0]):>14.4f} {np.abs(E_constrained[-1] - E_true[0]):>14.4f}")
print(f"{'Mean energy deviation':<30} {np.mean(np.abs(E_unconstrained - E_true[0])):>14.4f} {np.mean(np.abs(E_constrained - E_true[0])):>14.4f}")
print("="*60)

---
## 5. Summary Results

In [None]:
# === Save summary ===
# Builds a markdown report from the metrics computed in earlier cells, writes
# it to the reports directory and echoes it to the cell output.
summary_path = reports_dir / '03_constrained_learning_summary.md'

# One markdown table row per λ in the sweep.
lambda_rows = "\n".join(
    f"| {r['lambda']:.1f} | {r['mse']:.6f} | {r['mean_E_viol']:.6f} |"
    for r in results_lambda
)

# Literal '|' characters inside table cells are written as '\|' so the markdown
# renderer does not treat them as column separators.
# NOTE(review): the "Key Findings" and "Failure Modes" text is static; it is not
# derived from the numbers above, so check it against each run's results.
summary = f"""
# Constrained Learning: Conservation Penalties - Results Summary

**Date**: {time.strftime('%Y-%m-%d %H:%M')}
**Seed**: {SEED}

## Problem Setup

- System: Harmonic oscillator (ω={omega})
- Time step: dt={dt}
- Training noise: std={noise_std}
- Training samples: {len(X_train)}

## Single-Step Comparison

| Metric | Unconstrained | Constrained (λ=10) |
|--------|---------------|--------------------|
| MSE | {mse_unconstrained:.6f} | {mse_constrained:.6f} |
| Mean \\|ΔE\\| | {np.mean(E_violation_unconstrained):.6f} | {np.mean(E_violation_constrained):.6f} |
| Max \\|ΔE\\| | {np.max(E_violation_unconstrained):.6f} | {np.max(E_violation_constrained):.6f} |

## Lambda Sweep Results

| λ | MSE | Mean \\|ΔE\\| |
|---|-----|----------|
{lambda_rows}

## Long Horizon Rollout ({n_steps} steps)

| Metric | Unconstrained | Constrained |
|--------|---------------|-------------|
| Final trajectory error | {error_unconstrained[-1]:.4f} | {error_constrained[-1]:.4f} |
| Final energy deviation | {np.abs(E_unconstrained[-1] - E_true[0]):.4f} | {np.abs(E_constrained[-1] - E_true[0]):.4f} |

## Key Findings

1. **Constraint helps long-horizon predictions**: Unconstrained model drifts, constrained stays on orbit
2. **Trade-off exists**: Higher λ → worse single-step MSE but better energy conservation
3. **Optimal λ depends on use case**: Short predictions → low λ; long rollouts → high λ
4. **Phase space behavior**: Unconstrained spirals in/out; constrained stays on ellipse

## Failure Modes

- **λ too high**: Underfits data, predictions may be smooth but wrong
- **λ too low**: Good single-step fit but catastrophic long-term drift
- **Noisy data**: Harder to satisfy exact conservation; consider relaxed constraints
"""

# The report contains non-ASCII characters, so the encoding is fixed to UTF-8
# instead of depending on the platform default.
summary_path.write_text(summary, encoding='utf-8')

print(summary)
print(f"\n✓ Summary saved to {summary_path}")

---
## 6. Failure Modes & Debugging

### When Conservation Constraints Help

✅ **Good for**:
- Long-horizon rollouts (iterative predictions)
- Noisy data where physics provides regularization
- Out-of-distribution generalization (physics is always valid)
- Physical plausibility requirements

### When Conservation Constraints Hurt

❌ **Problematic when**:
- System is NOT conservative (dissipation, external forcing)
- λ too high → underfits data
- Constraint is approximate (numerical errors, discrete time)
- Single-step accuracy is paramount

### Debugging Checklist

1. **Check constraint formulation**: Is the physics formula correct?
2. **Monitor both losses**: MSE and conservation should both decrease
3. **Verify constraint satisfaction**: Compute |ΔE| on validation set
4. **Test on clean data first**: Remove noise to isolate constraint effects
5. **Tune λ on validation set**: Not too high, not too low

---
## 7. Mini Exercises

**Exercise 1**: Train with much higher noise (noise_std=0.1). Does the constraint help more?

**Exercise 2**: Implement momentum conservation for a 2-body system.

**Exercise 3**: Try a 3-layer network. Does it change the optimal λ?

**Exercise 4**: Add damping to the system (not conservative). What happens to the constraint?

**Exercise 5**: Implement a "soft" constraint that allows energy to change by at most ε.

In [None]:
# === Exercise 1: High noise regime ===
# YOUR CODE HERE


In [None]:
# === Exercise 2: Momentum conservation ===
# YOUR CODE HERE


In [None]:
# === Exercise 3: Deeper network ===
# YOUR CODE HERE


In [None]:
# === Exercise 4: Non-conservative system ===
# YOUR CODE HERE


In [None]:
# === Exercise 5: Soft constraint ===
# YOUR CODE HERE


---
## Solutions

In [None]:
# === Solution 1: High noise ===
# Uncomment to see solution:

# X_noisy, y_noisy = generate_pendulum_data(n=2000, omega=1.0, dt=0.1, 
#                                           noise_std=0.1, seed=42)  # Higher noise!
# 
# for lam in [0.0, 10.0]:
#     set_seed(42)
#     config_ex = ConservationConfig(hidden_dims=[64,64], epochs=2000, lambda_conservation=lam)
#     model_ex = ConservationConstrainedNN(config_ex)
#     _ = model_ex.train(X_noisy, y_noisy, verbose=0)
#     
#     traj = rollout(model_ex, 1.0, 0.0, 100)
#     E_traj = 0.5 * (traj[:, 1]**2 + traj[:, 0]**2)
#     print(f"λ={lam}: Final energy deviation = {abs(E_traj[-1] - 0.5):.4f}")

---
## Key Takeaways

### ✅ What We Learned

1. **Conservation constraints as regularizers** improve physical plausibility
2. **Trade-off exists**: Data fit vs constraint satisfaction
3. **Long-horizon critical**: Constraints prevent error accumulation
4. **λ tuning is essential**: Too high → underfit, too low → violate physics

### ⚠️ Limitations

1. **Only works for conservative systems** (or known dissipation patterns)
2. **Soft constraints** may still violate physics slightly
3. **Discrete time** introduces small conservation errors even in truth
4. **Multiple conservation laws** need careful balancing

### 💡 When to Use Conservation Constraints?

- **Always for long rollouts**: Prevents catastrophic drift
- **When physics is well-known**: Energy, momentum, mass, etc.
- **As regularization**: When data is noisy or limited
- **NOT when**: System is dissipative, or constraint is approximate