# In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Create the directory that the figure is saved into further down
# (outputs/03_pinn_results.png); exist_ok makes re-running the cell safe.
os.makedirs('outputs', exist_ok=True)

# ==================== LOAD DATA ====================
# NOTE(review): for an .npz archive np.load returns an NpzFile that keeps the
# file handle open while `data` is referenced; consider closing it once no
# later cell needs `data`.
data = np.load('data/synthetic_m.npz')
t, V, m_true = (data[name] for name in ('t', 'V', 'm_true'))

# Network inputs are rows of (t, V); targets are m as a column vector.
# Assumes t, V and m_true are equal-length 1-D arrays (N = 1000 for this
# dataset according to earlier notes) — TODO confirm against the generator.
features = np.stack([t, V], axis=1)
X = torch.tensor(features, dtype=torch.float32)          # (N, 2)
y = torch.tensor(m_true, dtype=torch.float32)[:, None]   # (N, 1)

# ==================== HH RATE FUNCTIONS ====================
def alpha_m(V):
    """Return the Hodgkin-Huxley m-gate opening rate alpha_m(V).

    Implements 0.1 * (V + 40) / (1 - exp(-(V + 40) / 10)) for a tensor ``V``.
    With u = (V + 40) / 10 this simplifies to u / (1 - exp(-u)). The original
    form is 0/0 at V = -40 (a removable singularity whose limit is 1.0), which
    produced NaN and poisoned any loss computed from it.

    Near u = 0 the Taylor expansion 1 + u/2 is used. Elsewhere the denominator
    is evaluated with expm1, which avoids the cancellation in 1 - exp(-u).
    Both the values and the autograd gradients stay finite for every V.
    """
    u = (V + 40.0) / 10.0
    near_zero = u.abs() < 1e-4
    # Feed a harmless value into the exact branch wherever the Taylor branch
    # is selected. Otherwise torch.where would still backpropagate the
    # NaN/inf gradient of the unused 0/0 branch.
    u_safe = torch.where(near_zero, torch.ones_like(u), u)
    exact = u_safe / -torch.expm1(-u_safe)
    # Taylor: u / (1 - e^-u) = 1 + u/2 + u^2/12 + ...; the u^2 term is
    # below float precision for |u| < 1e-4.
    return torch.where(near_zero, 1.0 + 0.5 * u, exact)

def beta_m(V):
    """Return the Hodgkin-Huxley m-gate closing rate 4 * exp(-(V + 65) / 18).

    ``V`` must be a torch tensor; the result has the same shape.
    """
    decay = torch.exp(-(V + 65.0) / 18.0)
    return 4.0 * decay

# ==================== PINN MODEL ====================
class PINN(nn.Module):
    """MLP mapping (t, V) to the gating variable m, with an ODE residual.

    Args:
        hidden: Width of each of the two hidden layers (default 64).
        activation: Zero-argument callable returning an ``nn.Module`` that is
            placed after each hidden layer (default ``nn.ReLU``). A ReLU
            network is piecewise linear in its inputs, so its dm/dt is
            piecewise constant. A smooth activation such as ``nn.Tanh`` is a
            common alternative when the derivative enters the loss.
    """

    def __init__(self, hidden=64, activation=nn.ReLU):
        super().__init__()
        # Linear / act / Linear / act / Linear keeps the state_dict keys
        # net.0, net.2 and net.4.
        self.net = nn.Sequential(
            nn.Linear(2, hidden),
            activation(),
            nn.Linear(hidden, hidden),
            activation(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return the predicted m for each row of ``x`` as an (N, 1) tensor.

        Assumes ``x`` is (N, 2) with column 0 = t and column 1 = V.
        """
        return self.net(x)

    def physics_residual(self, x):
        """Return dm/dt - (alpha_m(V) * (1 - m) - beta_m(V) * m) per row.

        Column 0 of ``x`` is treated as t and column 1 as V. The input is
        detached into a fresh leaf that requires grad. The caller's tensor is
        therefore never mutated, and gradients do not pile up in ``x.grad``
        across training steps.
        """
        x = x.detach().requires_grad_(True)
        m_pred = self(x)
        # The MLP acts row-wise, so the gradient of the summed output w.r.t.
        # x gives per-sample derivatives; column 0 is dm/dt.
        # create_graph=True keeps dm/dt differentiable w.r.t. the parameters.
        dm_dt = torch.autograd.grad(m_pred.sum(), x, create_graph=True)[0][:, 0:1]

        V = x[:, 1:2]
        return dm_dt - (alpha_m(V) * (1 - m_pred) - beta_m(V) * m_pred)

lambda_phys = 0.1   # weight of the ODE-residual term in the total loss (tune later if needed)
criterion = nn.MSELoss(reduction='mean')  # data term: MSE between prediction and target m
model = PINN()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# ==================== TRAINING ====================
# Full-batch Adam on: data MSE + lambda_phys * mean squared ODE residual.
epochs = 3000
losses, data_losses, phys_losses = [], [], []

print("Training PINN (physics embedded in loss)...")
for epoch in range(epochs):
    optimizer.zero_grad()

    # Supervised term: match the sampled m values.
    m_pred = model(X)
    data_loss = criterion(m_pred, y)

    # Physics term: penalise violations of dm/dt = alpha_m(V)(1 - m) - beta_m(V) m.
    phys_res = model.physics_residual(X)
    phys_loss = phys_res.pow(2).mean()

    loss = data_loss + lambda_phys * phys_loss
    loss.backward()
    optimizer.step()

    # Record scalar history once per epoch (values from before this step).
    total_val, data_val, phys_val = loss.item(), data_loss.item(), phys_loss.item()
    losses.append(total_val)
    data_losses.append(data_val)
    phys_losses.append(phys_val)

    if not epoch % 500:
        print(f"   Epoch {epoch:4d} | Total Loss: {total_val:.6f} | Data: {data_val:.6f} | Phys: {phys_val:.6f}")

# ==================== PLOTS ====================
# Inference only: no_grad skips building an autograd graph for this pass.
with torch.no_grad():
    m_pred = model(X).numpy().flatten()

fig, axs = plt.subplots(3, 1, figsize=(10, 10))

# Panel 1: input voltage trace.
axs[0].plot(t, V, 'b-', label='Voltage')
axs[0].set_ylabel('Voltage (mV)')

# Panel 2: ground-truth vs. predicted gating variable.
axs[1].plot(t, m_true, 'r-', label='True m(t)', lw=2)
axs[1].plot(t, m_pred, 'g--', label='PINN prediction', lw=2)
axs[1].set_ylabel('Gating variable m(t)')

# Panel 3: per-epoch loss history recorded during training.
axs[2].plot(losses, label='Total Loss')
axs[2].plot(data_losses, label='Data Loss')
axs[2].plot(phys_losses, label='Physics Loss')
axs[2].set_xlabel('Epoch')
axs[2].set_ylabel('Loss')

for ax in axs:
    ax.legend()
    ax.grid(True)

# The title previously contained mojibake ("â€”") in place of an em dash.
fig.suptitle('Physics-Informed Neural Network (PINN) — embeds the ODE directly')
fig.tight_layout()
fig.savefig('outputs/03_pinn_results.png', dpi=200)
plt.show()

print("PINN training complete!")
print(f"   Final Total Loss: {losses[-1]:.6f}")
print(f"   Final Data Loss:  {data_losses[-1]:.6f}")
print("Plot saved to outputs/03_pinn_results.png")