# In [None]:
import os
os.makedirs(name='outputs', exist_ok=True)  # make sure the figure output directory exists (no error if it already does)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# ==================== LOAD SYNTHETIC DATA ====================
# np.load on an .npz returns an NpzFile that keeps the archive open until it is
# closed. Indexing it reads each array fully into memory, so the arrays remain
# valid after the `with` block closes the file.
with np.load('data/synthetic_m.npz') as data:
    t = data['t']
    V = data['V']
    m_true = data['m_true']

print("Data loaded!")
print(f"   Time points: {len(t)}")
print("   Training on (t, V) -> m(t)")

# Every time point needs a target. Otherwise MSELoss would broadcast or fail
# later with a far less obvious error message.
if len(m_true) != len(t):
    raise ValueError(
        f"m_true has {len(m_true)} samples but t has {len(t)}"
    )

# Prepare tensors.
# NOTE(review): t and V are fed to the network unscaled. If their numeric
# ranges differ widely, standardizing them may help training. Left unchanged
# here so the baseline matches the original setup.
X = torch.tensor(np.stack([t, V], axis=1), dtype=torch.float32)  # (N, 2): columns [t, V]
y = torch.tensor(m_true, dtype=torch.float32).unsqueeze(1)        # (N, 1) regression targets

# ==================== SIMPLE DATA-DRIVEN NN ====================
class DataDrivenNN(nn.Module):
    """Plain multilayer-perceptron baseline with no physics terms in it.

    With the default sizes, it maps a (batch, 2) input (here the stacked
    (t, V) pairs) to a (batch, 1) output. It uses two ReLU hidden layers
    of width 64.

    Args:
        in_features: Number of input features per sample (default 2).
        hidden: Width of each of the two hidden layers (default 64).
        out_features: Number of outputs per sample (default 1).
    """

    def __init__(self, in_features: int = 2, hidden: int = 64, out_features: int = 1):
        super().__init__()
        # The layer order is fixed so the state_dict keys (net.0, net.2, net.4)
        # stay identical to the original hard-coded version's checkpoints.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        """Return the network output for ``x`` of shape (batch, in_features)."""
        return self.net(x)

# Build the baseline network, a mean-squared-error objective, and an Adam
# optimizer over all of the network's parameters (learning rate 1e-3).
model = DataDrivenNN()
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# ==================== TRAINING LOOP ====================
epochs = 2000
log_every = 400  # print progress every `log_every` epochs
losses = []      # per-epoch full-batch training loss, measured before that epoch's optimizer step

print("Training data-driven baseline (no physics)...")
for epoch in range(epochs):
    # Full-batch step: the whole of X is used every epoch.
    optimizer.zero_grad()
    pred = model(X)
    loss = criterion(pred, y)
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    # Also log the last epoch. With range(epochs), multiples of log_every
    # alone would never print the end of training.
    if epoch % log_every == 0 or epoch == epochs - 1:
        print(f"   Epoch {epoch:4d} | Loss: {loss_value:.6f}")

# ==================== PLOTS ====================
# Evaluate the trained model without building an autograd graph. The final MSE
# is recomputed here because losses[-1] was measured *before* the last
# optimizer step, so it does not describe the model being plotted.
with torch.no_grad():
    pred_final = model(X)
    final_mse = criterion(pred_final, y).item()
m_pred = pred_final.numpy().flatten()

fig, axs = plt.subplots(2, 1, figsize=(10, 8))

# Top panel: the input voltage protocol.
axs[0].plot(t, V, 'b-', label='Voltage protocol')
axs[0].set_ylabel('Voltage (mV)')
axs[0].legend()
axs[0].grid(True)

# Bottom panel: ground-truth gating variable vs. the network's fit.
axs[1].plot(t, m_true, 'r-', label='True m(t)', linewidth=2)
axs[1].plot(t, m_pred, 'g--', label='Data-driven NN prediction', linewidth=2)
axs[1].set_xlabel('Time (ms)')
axs[1].set_ylabel('Gating variable m(t)')
axs[1].legend()
axs[1].grid(True)

fig.suptitle('Data-Driven Neural Network Baseline (no physics prior)')
fig.tight_layout()
fig.savefig('outputs/02_data_driven_baseline.png', dpi=200)
plt.show()

print("Training complete! Plot saved to outputs/02_data_driven_baseline.png")
print(f"   Final MSE: {final_mse:.6f}")