# 03 — Physics-Informed Neural Network (PINN)

Trains a PINN whose loss adds the residual of the Hodgkin-Huxley gating ODE, $\frac{dm}{dt} = \alpha_m(V)\,(1-m) - \beta_m(V)\,m$, to the usual data-fit term.  
Uses **Tanh** activations (smooth, nonzero higher-order derivatives — critical for autograd-based physics loss).  
Same 80/20 train/test split as the baseline for fair comparison.

In [None]:
import os
os.makedirs('outputs', exist_ok=True)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# ==================== LOAD DATA ====================
# np.load on an .npz returns an archive object that keeps the file handle open;
# use it as a context manager so the handle is released once the arrays have
# been read into memory.
with np.load('data/synthetic_m.npz') as data:
    t = data['t']
    V = data['V']
    m_true = data['m_true']

# ==================== SAME TRAIN / TEST SPLIT ====================
with np.load('data/split_indices.npz') as split_data:
    train_idx = split_data['train_idx']
    test_idx  = split_data['test_idx']

# Fail fast on inconsistent inputs instead of silently misaligning samples:
# np.stack below needs 1-D series of equal length, and m_true must line up
# with them index-for-index.
if t.ndim != 1 or not (t.shape == V.shape == m_true.shape):
    raise ValueError(
        f'expected 1-D t, V, m_true of equal length, got shapes '
        f'{t.shape}, {V.shape}, {m_true.shape}'
    )
# Overlapping indices would leak training labels into the test metric.
if np.intersect1d(train_idx, test_idx).size:
    raise ValueError('train_idx and test_idx overlap; test MSE would be optimistic')

# Model input columns: [t, V].
X_all = np.stack([t, V], axis=1)

X_train = torch.tensor(X_all[train_idx], dtype=torch.float32)
y_train = torch.tensor(m_true[train_idx], dtype=torch.float32).unsqueeze(1)
X_test  = torch.tensor(X_all[test_idx],  dtype=torch.float32)
y_test  = torch.tensor(m_true[test_idx],  dtype=torch.float32).unsqueeze(1)
X_full  = torch.tensor(X_all, dtype=torch.float32)

print(f'Training samples: {len(train_idx)}')
print(f'Test samples:     {len(test_idx)}')

In [None]:
# ==================== HH RATE FUNCTIONS (PyTorch) ====================
def alpha_m(V):
    """Return the Hodgkin-Huxley m-gate opening rate alpha_m(V).

    Closed form: 0.1 * (V + 40) / (1 - exp(-(V + 40) / 10)), V in mV.

    With x = (V + 40) / 10 this equals x / (1 - exp(-x)). That expression has a
    removable 0/0 singularity at V = -40 mV (limit 1.0) and suffers
    catastrophic cancellation near it in float32. Evaluating it naively yields
    NaN values and NaN autograd gradients there, which would poison the
    physics loss. We therefore use expm1 for accuracy and switch to the Taylor
    series 1 + x/2 + x**2/12 when |x| is tiny.

    Args:
        V: floating-point voltage tensor (any shape).

    Returns:
        Tensor of the same shape as V, finite (with finite gradients) everywhere.
    """
    x = (V + 40.0) / 10.0
    near_zero = x.abs() < 1e-4
    # Replace masked-out lanes with a harmless value so the unused branch of
    # torch.where never evaluates 0/0 (its NaN would otherwise leak into the
    # backward pass even though the forward value is discarded).
    x_safe = torch.where(near_zero, torch.ones_like(x), x)
    exact = x_safe / -torch.expm1(-x_safe)
    series = 1.0 + x / 2.0 + x * x / 12.0
    return torch.where(near_zero, series, exact)

def beta_m(V):
    """Return the Hodgkin-Huxley m-gate closing rate, 4 * exp(-(V + 65) / 18).

    Args:
        V: voltage tensor (any shape); the result has the same shape.
    """
    exponent = -(V + 65.0) / 18.0
    decay = torch.exp(exponent)
    return 4.0 * decay

# ==================== PINN MODEL (Tanh activations) ====================
class PINN(nn.Module):
    """Fully connected network predicting the gating variable m from [t, V].

    Architecture: 2 -> 64 -> Tanh -> 64 -> Tanh -> 1. Tanh is smooth, so the
    input derivatives taken by autograd in ``physics_residual`` (and the
    second-order terms needed to backpropagate through them) are well defined.
    Input column 0 is time ``t``; column 1 is voltage ``V``.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        """Return predictions of shape (N, 1) for inputs x of shape (N, 2)."""
        return self.net(x)

    def physics_residual(self, x, dV_dt=None):
        """Compute the ODE residual dm/dt - [alpha_m(V)(1 - m) - beta_m(V) m].

        Args:
            x: (N, 2) tensor of collocation points, columns [t, V]. It is
                cloned, so the caller's tensor is not modified.
            dV_dt: optional voltage time-derivative at each point, shape (N,)
                or (N, 1) (tensor or array-like). Because the network receives
                V(t) as a separate input, the time derivative of m along the
                trajectory is, by the chain rule,
                    dm/dt = (partial m / partial t) + (partial m / partial V) * dV/dt.
                When ``dV_dt`` is given this total derivative is used. When it
                is None (default) only the partial derivative w.r.t. t is used,
                which was the original behaviour.

        Returns:
            (N, 1) residual tensor, differentiable w.r.t. the model parameters
            (``create_graph=True``) so it can be used inside the loss.
        """
        x = x.clone().requires_grad_(True)
        m_pred = self(x)
        # Gradient of each row's output w.r.t. that row's inputs; rows are
        # independent, so a ones vector for grad_outputs gives per-row grads.
        grads = torch.autograd.grad(
            m_pred, x, grad_outputs=torch.ones_like(m_pred),
            create_graph=True
        )[0]
        dm_dt = grads[:, 0:1]  # partial derivative w.r.t. t (first input column)
        if dV_dt is not None:
            dV_dt = torch.as_tensor(dV_dt, dtype=grads.dtype, device=grads.device)
            dm_dt = dm_dt + grads[:, 1:2] * dV_dt.reshape(-1, 1)

        V = x[:, 1:2]
        residual = dm_dt - (alpha_m(V) * (1 - m_pred) - beta_m(V) * m_pred)
        return residual

In [None]:
# ==================== TRAINING ====================
# Seed before constructing the model so weight initialisation is reproducible.
torch.manual_seed(0)
model = PINN()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
lambda_phys = 0.1  # weight of the physics (ODE residual) term in the total loss
epochs = 3000

# Per-epoch histories; the plotting and saving cells read these lists.
train_losses, test_losses, data_losses, phys_losses = [], [], [], []

# Collocation points for the physics term. The ODE residual needs no labels,
# so it is enforced on every sample of the series (train and test alike).
X_colloc = X_full.clone()

print('Training PINN (physics embedded in loss)...')
for step in range(epochs):
    optimizer.zero_grad()

    # Supervised term: MSE on the labelled training points only.
    fit_term = criterion(model(X_train), y_train)
    # Physics term: mean squared ODE residual over all collocation points.
    ode_term = model.physics_residual(X_colloc).pow(2).mean()
    total = fit_term + lambda_phys * ode_term

    total.backward()
    optimizer.step()

    # Training histories record this epoch's losses as computed before the step.
    train_losses.append(total.item())
    data_losses.append(fit_term.item())
    phys_losses.append(ode_term.item())

    # Held-out error after the parameter update, without building a graph.
    with torch.no_grad():
        held_out = criterion(model(X_test), y_test).item()
    test_losses.append(held_out)

    if step % 500 == 0:
        print(f'   Epoch {step:4d} | Total: {total.item():.6f} | Data: {fit_term.item():.6f} | Phys: {ode_term.item():.6f} | Test: {held_out:.6f}')

print(f'\nFinal Total Loss: {train_losses[-1]:.6f}')
print(f'Final Data Loss:  {data_losses[-1]:.6f}')
print(f'Final Phys Loss:  {phys_losses[-1]:.6f}')
print(f'Final Test MSE:   {test_losses[-1]:.6f}')

In [None]:
# ==================== PLOTS ====================
# Prediction over the whole series. This is pure inference, so skip building
# an autograd graph (the saving cell also reads m_pred).
with torch.no_grad():
    m_pred = model(X_full).numpy().flatten()

fig, axs = plt.subplots(3, 1, figsize=(10, 10))

# Panel 1: input voltage trace.
axs[0].plot(t, V, 'b-', label='Voltage')
axs[0].set_ylabel('Voltage (mV)')
axs[0].legend(); axs[0].grid(True)

# Panel 2: ground truth vs. prediction, with held-out points highlighted.
axs[1].plot(t, m_true, 'r-', label='True m(t)', lw=2)
axs[1].plot(t, m_pred, 'g--', label='PINN prediction', lw=2)
axs[1].scatter(t[test_idx], m_true[test_idx], c='orange', s=8, zorder=5, label='Test points')
axs[1].set_ylabel('Gating variable m(t)')
axs[1].legend(); axs[1].grid(True)

# Panel 3: loss histories. The terms differ by orders of magnitude and decay
# during training, so a log y-axis keeps every curve readable.
axs[2].plot(train_losses, label='Total Loss')
axs[2].plot(data_losses, label='Data Loss')
axs[2].plot(phys_losses, label='Physics Loss')
axs[2].plot(test_losses, label='Test Loss', alpha=0.7)
axs[2].set_yscale('log')
axs[2].set_xlabel('Epoch'); axs[2].set_ylabel('Loss')
axs[2].legend(); axs[2].grid(True)

plt.suptitle('Physics-Informed Neural Network — Tanh activations, ODE constraint')
plt.tight_layout()
plt.savefig('outputs/03_pinn_results.png', dpi=200)
plt.show()
print('Plot saved to outputs/03_pinn_results.png')

In [None]:
# ==================== SAVE MODEL + METRICS ====================
# Persist only the learned parameters (a state_dict, not the module object).
torch.save(model.state_dict(), 'outputs/03_pinn_weights.pt')

# Loss histories, the full-series prediction and the split that was used,
# collected under the archive key names and written as one .npz file.
metrics = {
    'train_losses': np.array(train_losses),
    'test_losses': np.array(test_losses),
    'data_losses': np.array(data_losses),
    'phys_losses': np.array(phys_losses),
    'm_pred_full': m_pred,
    'train_idx': train_idx,
    'test_idx': test_idx,
}
np.savez('outputs/03_metrics.npz', **metrics)
print('Model weights and metrics saved.')