# Physics-Informed Neural Networks (PINN) for Parameter Estimation

## Learning Objectives

This tutorial covers:
1. **What is a PINN?** - Combining data with physics
2. **Training a PINN** - Creating synthetic datasets
3. **Parameter estimation** - Solving the inverse problem
4. **Uncertainty quantification** - Bootstrap confidence intervals
5. **Model validation** - Testing accuracy
6. **Comparison with Bayesian methods**

## What is a Physics-Informed Neural Network?

Traditional neural networks learn from data alone:
```
NN(data) → output
```

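**PINNs** also enforce physical laws (ODEs, PDEs) by penalizing violations of them during training:
```
Total Loss = λ_data · Data Loss + λ_physics · Physics Loss
           = λ_data · ||NN(inputs) - measurements||² + λ_physics · ||dV/dt - f(V, params)||²
```

Here dV/dt is the time derivative of the network's predicted voltage V, f is the right-hand side of the model's ODE, and the weights λ_data and λ_physics correspond to `lambda_data` and `lambda_physics` in `PINNConfig`.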

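**Advantages:**
- Often more accurate than purely data-driven models when data are limited
- Predictions that approximately respect the governing equations
- Better generalization, since the physics term constrains the solution where data are sparse
- Can solve inverse problems (data → hidden parameters)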
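The weighted loss above can be sketched in a few lines of NumPy. This is a minimal illustration, not the `ens_gi_pinn` implementation: it assumes a hypothetical one-parameter toy ODE, dV/dt = -k·V, and approximates dV/dt with finite differences (`np.gradient`), whereas a real PINN would differentiate the network output with automatic differentiation.

```python
import numpy as np

def pinn_loss(t, v_pred, v_meas, k, lambda_data=1.0, lambda_physics=0.1):
    """Weighted PINN-style loss for the hypothetical toy ODE dV/dt = -k * V.

    t      : strictly increasing time points
    v_pred : predicted voltages at t (in a real PINN, the network output)
    v_meas : measured voltages at t
    k      : toy decay-rate parameter (stand-in for g_Na, g_K, omega)
    """
    # Data loss: mean squared mismatch between prediction and measurements
    data_loss = np.mean((v_pred - v_meas) ** 2)
    # Physics loss: mean squared ODE residual, dV/dt estimated by finite differences
    dv_dt = np.gradient(v_pred, t)
    residual = dv_dt - (-k * v_pred)
    physics_loss = np.mean(residual ** 2)
    total = lambda_data * data_loss + lambda_physics * physics_loss
    return total, data_loss, physics_loss

t = np.linspace(0.0, 5.0, 201)
v_true = np.exp(-0.5 * t)  # exact solution of the toy ODE for k = 0.5, V(0) = 1

total, d, p = pinn_loss(t, v_true, v_true, k=0.5)
print(f"data={d:.2e}, physics={p:.2e}")  # both close to zero

_, _, p_wrong = pinn_loss(t, v_true, v_true, k=1.0)
print(f"physics loss with wrong k: {p_wrong:.2e}")  # much larger
```

A curve can fit the data perfectly and still be penalized if it is inconsistent with the ODE for the given parameter; that is the extra signal the physics term contributes.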

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
import sys

sys.path.insert(0, '..')

from ens_gi_core import ENSGIDigitalTwin

try:
    from ens_gi_pinn import PINNEstimator, PINNConfig
    import tensorflow as tf
    PINN_AVAILABLE = True
    print(f"TensorFlow version: {tf.__version__}")
except ImportError as e:
    PINN_AVAILABLE = False
    print(f"PINN not available: {e}")
    print("Install with: pip install tensorflow")

if PINN_AVAILABLE:
    print("PINN framework ready!")
else:
    print("This tutorial requires TensorFlow. Please install and restart.")

## Part 1: The Inverse Problem

**Forward Problem** (easy): Given parameters → predict observations
```python
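for neuron in twin.network.neurons:
    neuron.params.g_Na = 120
    neuron.params.g_K = 36
result = twin.run(1500, dt=0.1, verbose=False)
# result['voltages'] holds the simulated voltage traces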
```

**Inverse Problem** (hard): Given observations → estimate parameters
```python
# We have voltage traces from a patient
# What are g_Na and g_K?
estimates = pinn.estimate_parameters(voltages=voltage_data)
# → e.g. {'g_Na': {'mean': 85, 'std': 5}, 'g_K': {'mean': 42, 'std': 3}} (illustrative values)
```

This is what PINNs solve!
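To make the forward/inverse distinction concrete, here is a minimal sketch on a hypothetical one-parameter toy model, V(t) = exp(-k·t), rather than the digital twin. The inverse problem is solved by brute-force least squares over a grid of candidate k values. Note that every candidate requires a forward simulation; for the full twin with several parameters this becomes expensive, which is one motivation for training an estimator such as the PINN once on simulated data.

```python
import numpy as np

def toy_forward(k, t):
    # Hypothetical toy forward model standing in for the digital twin
    return np.exp(-k * t)

# Forward problem: known parameter -> (noisy) observations
rng = np.random.default_rng(0)
t = np.linspace(0.0, 5.0, 100)
k_true = 0.8
observed = toy_forward(k_true, t) + rng.normal(0.0, 0.01, size=t.shape)

# Inverse problem: observations -> parameter, by minimizing the sum of squared errors
k_grid = np.linspace(0.1, 2.0, 1901)
sse = np.array([np.sum((toy_forward(k, t) - observed) ** 2) for k in k_grid])
k_hat = k_grid[np.argmin(sse)]
print(f"true k = {k_true}, estimated k = {k_hat:.3f}")
```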

In [None]:
if not PINN_AVAILABLE:
    raise ImportError("TensorFlow required for this tutorial")

# Create ground truth patient with KNOWN parameters
print("Creating synthetic patient...")
true_g_Na = 100.0  # This is what we want to recover
true_g_K = 40.0
true_omega = 0.01

patient = ENSGIDigitalTwin(n_segments=10)

# Set known parameters
for neuron in patient.network.neurons:
    neuron.params.g_Na = true_g_Na
    neuron.params.g_K = true_g_K
patient.network.icc.omega = true_omega

# Generate "clinical" data
print("Simulating clinical measurements...")
result = patient.run(1500, dt=0.1, I_stim={4: 10.0}, verbose=False)

print(f"\nGenerated data:")
print(f"  Voltage traces: {result['voltages'].shape}")
print(f"  Force traces: {result['forces'].shape}")
print(f"  Calcium traces: {result['calcium'].shape}")

print(f"\nTrue parameters (hidden from PINN):")
print(f"  g_Na = {true_g_Na}")
print(f"  g_K = {true_g_K}")
print(f"  omega = {true_omega}")

## Part 2: Configure and Create PINN

In [None]:
# Configure PINN architecture
config = PINNConfig(
    architecture='resnet',  # Residual (skip) connections ease gradient flow
    hidden_dims=[128, 64, 32],  # 3 hidden layers
    activation='tanh',  # Smooth activation, so derivatives in the physics loss are well defined
    learning_rate=0.001,  # Adam learning rate
    lambda_data=1.0,  # Data loss weight
    lambda_physics=0.1,  # Physics loss weight (0.1× the data loss weight)
    batch_size=32,
    validation_split=0.2
)

# Create PINN estimator
twin_for_pinn = ENSGIDigitalTwin(n_segments=10)
pinn = PINNEstimator(
    digital_twin=twin_for_pinn,
    config=config,
    parameter_names=['g_Na', 'g_K', 'omega']
)

print("PINN Model Summary:")
print(f"  Architecture: {config.architecture}")
print(f"  Parameters to estimate: {pinn.parameter_names}")
print(f"  Total model parameters: {pinn.model.count_params():,}")
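The cell above does not show how `architecture='resnet'` is built inside `ens_gi_pinn`. One common pattern for a residual tanh MLP is sketched below in NumPy, purely as an illustration: each hidden layer computes tanh(x·W + b) and adds a shortcut, using a linear projection when the layer width changes (128 → 64 → 32). The input size of 20 features and the weight initialization are hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(n_in, n_out):
    # Random weights with a Glorot-style scale and zero bias
    w = rng.normal(0.0, np.sqrt(2.0 / (n_in + n_out)), size=(n_in, n_out))
    return w, np.zeros(n_out)

def build_resnet(n_features, hidden_dims, n_outputs):
    layers, n_in = [], n_features
    for h in hidden_dims:
        main = dense(n_in, h)
        # Projection shortcut when the width changes, identity shortcut otherwise
        shortcut = dense(n_in, h) if n_in != h else None
        layers.append((main, shortcut))
        n_in = h
    head = dense(n_in, n_outputs)
    return layers, head

def resnet_forward(x, layers, head):
    for (w, b), shortcut in layers:
        skip = x if shortcut is None else x @ shortcut[0] + shortcut[1]
        x = np.tanh(x @ w + b) + skip  # residual connection
    w, b = head
    return x @ w + b  # linear output layer: one value per estimated parameter

layers, head = build_resnet(n_features=20, hidden_dims=[128, 64, 32], n_outputs=3)
x = rng.normal(size=(5, 20))  # batch of 5 hypothetical feature vectors
print(resnet_forward(x, layers, head).shape)  # (5, 3)
```

The three outputs would correspond to `g_Na`, `g_K`, and `omega`; the skip connections let each block learn a correction to its input instead of a full transformation.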

## Part 3: Generate Training Dataset

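The PINN is trained on simulated samples whose parameter values are known: each sample pairs the features of one simulation with the parameters (`g_Na`, `g_K`, `omega`) that produced it.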
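The internals of `generate_synthetic_dataset` are not shown here. As a rough sketch of the general idea, the code below samples parameter sets uniformly within ranges, runs one simulation per set, and reduces each trace to a feature vector. The parameter ranges, the toy simulator, and the summary-statistic features are all hypothetical placeholders, not the library's choices.

```python
import numpy as np

# Hypothetical sampling ranges, for illustration only (not taken from ens_gi_pinn)
PARAM_RANGES = {'g_Na': (60.0, 140.0), 'g_K': (20.0, 60.0), 'omega': (0.005, 0.02)}

def toy_simulate(g_Na, g_K, omega, t):
    # Hypothetical toy "simulator" standing in for ENSGIDigitalTwin.run()
    return (g_Na / 100.0) * np.sin(2 * np.pi * omega * t) - g_K / 100.0

def extract_features(trace):
    # Simple summary statistics used as features (a placeholder choice)
    return np.array([trace.mean(), trace.std(), trace.max(), trace.min()])

def generate_toy_dataset(n_samples, seed=0):
    rng = np.random.default_rng(seed)
    t = np.linspace(0.0, 1500.0, 1501)
    names = list(PARAM_RANGES)
    low = np.array([PARAM_RANGES[n][0] for n in names])
    high = np.array([PARAM_RANGES[n][1] for n in names])
    # Sample each parameter uniformly within its range: shape (n_samples, n_params)
    params = rng.uniform(low, high, size=(n_samples, len(names)))
    # One simulation per sampled parameter set, reduced to a feature vector
    features = np.stack([extract_features(toy_simulate(*p, t)) for p in params])
    return {'features': features, 'parameters': params}

toy_dataset = generate_toy_dataset(500)
print(toy_dataset['features'].shape, toy_dataset['parameters'].shape)  # (500, 4) (500, 3)
```

The sampled ranges matter: the estimator can only be expected to work for patients whose parameters fall inside the region covered by the training samples.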

In [None]:
# Generate synthetic training dataset
print("Generating training dataset (this may take 1-2 minutes)...")
dataset = pinn.generate_synthetic_dataset(n_samples=500)

print(f"\nDataset created:")
print(f"  Features shape: {dataset['features'].shape}")
print(f"  Parameters shape: {dataset['parameters'].shape}")
print(f"  Number of samples: {len(dataset['parameters'])}")

# Show parameter distribution in training set
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

param_names = ['g_Na', 'g_K', 'omega']
for idx, (ax, name) in enumerate(zip(axes, param_names)):
    ax.hist(dataset['parameters'][:, idx], bins=30, alpha=0.7, edgecolor='black')
    ax.set_xlabel(name, fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title(f'{name} Distribution in Training Set', fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Part 4: Train the PINN

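Training minimizes the weighted sum `Total Loss = λ_data · Data Loss + λ_physics · Physics Loss`, with the weights set by `lambda_data` and `lambda_physics` in the config.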

In [None]:
# Train PINN
print("Training PINN (this will take 2-3 minutes)...\n")
history = pinn.train(
    features=dataset['features'],
    parameters=dataset['parameters'],
    epochs=1000,
    verbose=True
)

# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Total loss
axes[0].plot(history['loss'], label='Training', linewidth=2)
axes[0].plot(history['val_loss'], label='Validation', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Total Loss')
axes[0].set_title('Training Progress', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].set_yscale('log')

# Data loss
axes[1].plot(history['data_loss'], linewidth=2, color='green')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Data Loss')
axes[1].set_title('Data Fitting', fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].set_yscale('log')

# Physics loss
axes[2].plot(history['physics_loss'], linewidth=2, color='red')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Physics Loss')
axes[2].set_title('Physics Constraint', fontweight='bold')
axes[2].grid(True, alpha=0.3)
axes[2].set_yscale('log')

plt.tight_layout()
plt.show()

print(f"\nFinal training loss: {history['loss'][-1]:.6f}")
print(f"Final validation loss: {history['val_loss'][-1]:.6f}")
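The final losses alone do not show whether training overfit. A small helper, sketched under the assumption that `history` is a dict of per-epoch lists with `'loss'` and `'val_loss'` keys (as the plotting code above uses it), can report the epoch with the lowest validation loss and the final validation/training ratio.

```python
import numpy as np

def summarize_history(history):
    """Return (best_epoch, best_val_loss, final_val_to_train_ratio).

    Assumes history['loss'] and history['val_loss'] are per-epoch sequences.
    """
    train = np.asarray(history['loss'], dtype=float)
    val = np.asarray(history['val_loss'], dtype=float)
    best_epoch = int(np.argmin(val))
    # Ratio well above 1 means validation loss exceeds training loss (possible overfitting)
    gap = float(val[-1] / train[-1])
    return best_epoch, float(val[best_epoch]), gap

# Illustrative toy history: validation loss bottoms out at epoch 2, then rises
toy = {'loss': [1.0, 0.5, 0.3, 0.2], 'val_loss': [1.1, 0.6, 0.5, 0.55]}
best_epoch, best_val, gap = summarize_history(toy)
print(best_epoch, best_val, round(gap, 2))  # 2 0.5 2.75
```

With the real run, `summarize_history(history)` can be called directly; a best epoch far before the last one suggests that fewer epochs (or early stopping) would suffice.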

## Part 5: Estimate Patient Parameters

Now use the trained PINN to estimate parameters from "clinical" data.

In [None]:
# Estimate parameters with uncertainty quantification
print("Estimating parameters with bootstrap uncertainty...")
estimates = pinn.estimate_parameters(
    voltages=result['voltages'],
    forces=result['forces'],
    calcium=result['calcium'],
    n_bootstrap=50  # 50 bootstrap samples
)

# Display results
print("\n" + "="*70)
print("PINN PARAMETER ESTIMATION RESULTS")
print("="*70)
print(f"\n{'Parameter':<12} {'True Value':<15} {'Estimated':<20} {'Error':<10}")
print("-"*70)

true_values = {'g_Na': true_g_Na, 'g_K': true_g_K, 'omega': true_omega}

for param in ['g_Na', 'g_K', 'omega']:
    if param in estimates:
        est = estimates[param]
        true_val = true_values[param]
        error = abs(est['mean'] - true_val) / true_val * 100
        
        # Build the estimate string first so it pads as one column
        est_str = f"{est['mean']:.4f} ± {est['std']:.4f}"
        print(f"{param:<12} {true_val:<15.4f} {est_str:<20} {error:>6.1f}%")

print("="*70)

# Success criteria
avg_error = np.mean([abs(estimates[p]['mean'] - true_values[p]) / true_values[p] * 100 
                     for p in ['g_Na', 'g_K', 'omega'] if p in estimates])

print(f"\nAverage Error: {avg_error:.2f}%")
if avg_error < 10:
    print("SUCCESS: Error < 10% (target achieved!)")
elif avg_error < 20:
    print("GOOD: Error < 20%")
else:
    print("NEEDS IMPROVEMENT: Error > 20%")
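The `n_bootstrap` argument controls how many bootstrap replicates are used, but the cell does not show what `estimate_parameters` resamples. The generic nonparametric bootstrap is sketched below: resample the data with replacement, rerun the estimator, and summarize the spread of the replicates. It also compares the Gaussian interval mean ± 1.96σ (which the ±2σ band in the next plot approximates) with a percentile interval that makes no Gaussian assumption. The input data here are hypothetical.

```python
import numpy as np

def bootstrap_estimate(data, estimator, n_bootstrap=50, seed=0):
    """Nonparametric bootstrap: resample data with replacement and rerun the estimator."""
    rng = np.random.default_rng(seed)
    n = len(data)
    stats = np.array([estimator(data[rng.integers(0, n, size=n)])
                      for _ in range(n_bootstrap)])
    mean, std = stats.mean(), stats.std(ddof=1)
    normal_ci = (mean - 1.96 * std, mean + 1.96 * std)        # assumes ~Gaussian spread
    percentile_ci = tuple(np.percentile(stats, [2.5, 97.5]))  # no Gaussian assumption
    return {'mean': mean, 'std': std, 'normal_ci': normal_ci, 'percentile_ci': percentile_ci}

rng = np.random.default_rng(1)
samples = rng.normal(100.0, 5.0, size=200)  # hypothetical repeated measurements
boot = bootstrap_estimate(samples, np.mean, n_bootstrap=50)
print(boot['mean'], boot['std'], boot['normal_ci'], boot['percentile_ci'])
```

With only 50 replicates the tails of the bootstrap distribution are poorly resolved, so the percentile interval in particular is rough; more replicates give more stable intervals at the cost of more estimator runs.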

## Part 6: Visualize Uncertainty

In [None]:
# Plot parameter estimates with confidence intervals
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, (param, ax) in enumerate(zip(['g_Na', 'g_K', 'omega'], axes)):
    if param in estimates:
        est = estimates[param]
        true_val = true_values[param]
        
        # Plot estimate with error bars
        ax.errorbar([0], [est['mean']], yerr=[2*est['std']], 
                   fmt='o', markersize=12, capsize=10, capthick=2,
                   label='PINN Estimate', color='blue')
        
        # Plot true value
        ax.axhline(y=true_val, color='red', linestyle='--', linewidth=2,
                  label='True Value')
        
        # Shaded confidence region
        ax.axhspan(est['mean'] - 2*est['std'], 
                  est['mean'] + 2*est['std'],
                  alpha=0.2, color='blue', label='≈95% CI (±2σ)')
        
        ax.set_ylabel(param, fontsize=14, fontweight='bold')
        ax.set_title(f"{param} Estimation", fontsize=14, fontweight='bold')
        ax.set_xlim(-0.5, 0.5)
        ax.set_xticks([])
        ax.legend(loc='upper right')
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Part 7: Test on IBS Profiles

Can the PINN distinguish healthy patients from IBS patients?

In [None]:
# Test on different IBS profiles
profiles_to_test = ['healthy', 'ibs_c', 'ibs_d']
profile_results = {}

print("Testing PINN on different IBS profiles...\n")

for profile in profiles_to_test:
    print(f"Testing {profile.upper()}...")
    
    # Create patient with profile
    test_twin = ENSGIDigitalTwin(n_segments=10)
    test_twin.apply_profile(profile)
    test_result = test_twin.run(1500, dt=0.1, verbose=False)
    
    # Estimate parameters
    test_estimates = pinn.estimate_parameters(
        voltages=test_result['voltages'],
        forces=test_result['forces'],
        calcium=test_result['calcium'],
        n_bootstrap=20
    )
    
    profile_results[profile] = test_estimates
    
    if 'g_Na' in test_estimates:
        print(f"  g_Na: {test_estimates['g_Na']['mean']:.2f} ± {test_estimates['g_Na']['std']:.2f}")

# Compare profiles, skipping any profile where g_Na was not estimated
labels = {'healthy': ('Healthy', ''),
          'ibs_c': ('IBS-C', ' (should be lower)'),
          'ibs_d': ('IBS-D', ' (should be higher)')}
print("\nProfile Comparison:")
for profile, (label, note) in labels.items():
    if 'g_Na' in profile_results[profile]:
        print(f"  {label} g_Na: {profile_results[profile]['g_Na']['mean']:.2f}{note}")
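Comparing means alone does not say whether a difference between profiles exceeds the estimation uncertainty. A simple check, assuming the two estimates are independent and roughly Gaussian, is the z-score of their difference, (μ_b − μ_a) / √(σ_a² + σ_b²). The numbers below are hypothetical; with the real results one would call, for example, `separation_z(profile_results['healthy']['g_Na'], profile_results['ibs_c']['g_Na'])`.

```python
import math

def separation_z(est_a, est_b):
    """z-score of est_b - est_a for two independent estimates given as {'mean', 'std'} dicts."""
    diff = est_b['mean'] - est_a['mean']
    se = math.sqrt(est_a['std'] ** 2 + est_b['std'] ** 2)
    return diff / se

# Hypothetical estimates, for illustration only
healthy = {'mean': 100.0, 'std': 3.0}
ibs_c = {'mean': 90.0, 'std': 4.0}
print(separation_z(healthy, ibs_c))  # -2.0
```

Roughly, |z| above about 2 suggests the profiles differ by more than the bootstrap spread, while |z| below 1 means the PINN cannot reliably tell them apart on that parameter.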

## Part 8: Save and Load Models

In [None]:
# Save trained PINN
pinn.save('trained_pinn_model')
print("Model saved to: trained_pinn_model/")

# Load model
loaded_pinn = PINNEstimator.load('trained_pinn_model')
print("Model loaded successfully!")

# Verify it works (same n_bootstrap as the original estimate)
loaded_estimates = loaded_pinn.estimate_parameters(
    voltages=result['voltages'],
    forces=result['forces'],
    calcium=result['calcium'],
    n_bootstrap=50
)

print(f"\nLoaded model g_Na estimate: {loaded_estimates['g_Na']['mean']:.2f}")
print(f"Original model g_Na estimate: {estimates['g_Na']['mean']:.2f}")
# Bootstrap resampling can vary between runs, so compare with a tolerance
same = abs(loaded_estimates['g_Na']['mean'] - estimates['g_Na']['mean']) < 0.1
print("Models match!" if same else "Models differ!")

## Summary

This tutorial demonstrated:

1. ✅ **PINN concept** - Combining data with physics
2. ✅ **Training** - Generated 500 synthetic samples
3. ✅ **Parameter estimation** - Solved inverse problem
4. ✅ **Uncertainty** - Bootstrap confidence intervals
5. ✅ **Validation** - Tested on multiple IBS profiles
6. ✅ **Persistence** - Save/load trained models

### Key Advantages of PINN:

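- **Fast**: once trained, estimating a new patient's parameters needs only inference plus bootstrap resampling (dataset generation and training took a few minutes in this tutorial)
- **Accurate**: this tutorial targets <10% average relative error on synthetic data
- **Uncertainty**: Bootstrap provides confidence intervals
- **Scalable**: the approach can in principle be extended to more parameters (three are estimated here)
- **Physically consistent**: the physics loss penalizes violations of the ODE constraints (a soft constraint, not an exact one)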

### When to Use PINN vs Bayesian:

**Use PINN when:**
- You need fast results (<5 minutes)
- You have moderate amounts of data
- Point estimates + bootstrap CI are sufficient
- Multiple parameters to estimate (>5)

**Use Bayesian when:**
- You need full posterior distributions
- Clinical decisions require rigorous uncertainty
- Small dataset but strong priors
- Regulatory approval requires UQ

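**A practical workflow:** use the PINN for a fast first estimate, then refine it with Bayesian inference (for example, by using the PINN estimate to inform the prior).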
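One way to combine the two approaches is to turn a fast point estimate with its bootstrap standard deviation into a Gaussian prior and then compute a posterior. The sketch below does this for the same kind of hypothetical one-parameter toy model, V(t) = exp(-k·t), using a grid posterior; the "PINN" estimate of 0.75 ± 0.1 is an illustrative number, not a result from this tutorial. For several parameters, a sampler such as MCMC would typically replace the grid.

```python
import numpy as np

# Hypothetical toy model and noisy data: V(t) = exp(-k t) + Gaussian noise
rng = np.random.default_rng(0)
t = np.linspace(0.0, 5.0, 50)
k_true, noise_sd = 0.8, 0.05
v_obs = np.exp(-k_true * t) + rng.normal(0.0, noise_sd, size=t.shape)

# Illustrative fast estimate (e.g. from a PINN) used as a Gaussian prior
prior_mean, prior_sd = 0.75, 0.1

# Grid posterior: log prior + Gaussian log-likelihood, normalized over the grid
k_grid = np.linspace(0.3, 1.3, 2001)
log_prior = -0.5 * ((k_grid - prior_mean) / prior_sd) ** 2
resid = v_obs[None, :] - np.exp(-k_grid[:, None] * t[None, :])  # shape (n_grid, n_time)
log_lik = -0.5 * np.sum((resid / noise_sd) ** 2, axis=1)
log_post = log_prior + log_lik
post = np.exp(log_post - log_post.max())  # subtract max for numerical stability
post /= post.sum()

post_mean = np.sum(k_grid * post)
post_sd = np.sqrt(np.sum((k_grid - post_mean) ** 2 * post))
print(f"posterior k = {post_mean:.3f} ± {post_sd:.3f}")
```

Here the data are informative enough that the posterior is much narrower than the prior; with weaker data the prior from the fast estimate would carry more of the weight.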

### Next Steps:

- Try the Bayesian tutorial (`bayesian_tutorial.ipynb`)
- Explore clinical workflow (`clinical_workflow.ipynb`)
- Test on real patient data
- Publish results!