# 02 — Data-Driven Baseline Neural Network

Trains a standard feedforward NN on `(t, V) → m(t)` with **no physics prior**.  
Uses a random 80/20 train/test split (fixed seed) to measure generalization on held-out samples.

In [None]:
import os
os.makedirs('outputs', exist_ok=True)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# ==================== LOAD SYNTHETIC DATA ====================
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open;
# the context manager closes it once the arrays have been read into memory.
with np.load('data/synthetic_m.npz') as data:
    t = data['t']
    V = data['V']
    m_true = data['m_true']

# Every sample needs a (t, V) input and an m target. A longer m_true would
# otherwise be indexed without error and silently misalign the targets.
if not (len(t) == len(V) == len(m_true)):
    raise ValueError(
        f'Length mismatch in data/synthetic_m.npz: '
        f't={len(t)}, V={len(V)}, m_true={len(m_true)}'
    )

# ==================== TRAIN / TEST SPLIT (80/20) ====================
# Fixed seed so the shuffled split is reproducible from run to run.
np.random.seed(42)
n = len(t)
idx = np.random.permutation(n)
split = int(0.8 * n)  # first 80% of the shuffled indices go to training
train_idx, test_idx = idx[:split], idx[split:]

# Input matrix: one row per sample, columns are (t, V).
X_all = np.stack([t, V], axis=1)

# unsqueeze(1) turns the 1-D targets into column vectors of shape (N, 1).
X_train = torch.tensor(X_all[train_idx], dtype=torch.float32)
y_train = torch.tensor(m_true[train_idx], dtype=torch.float32).unsqueeze(1)
X_test  = torch.tensor(X_all[test_idx],  dtype=torch.float32)
y_test  = torch.tensor(m_true[test_idx],  dtype=torch.float32).unsqueeze(1)

print(f'Training samples: {len(train_idx)}')
print(f'Test samples:     {len(test_idx)}')

# Save split indices for consistency across notebooks
np.savez('data/split_indices.npz', train_idx=train_idx, test_idx=test_idx)

In [None]:
# ==================== DATA-DRIVEN NN ====================
class DataDrivenNN(nn.Module):
    """Fully connected regressor with two ReLU hidden layers.

    With the default arguments this is the 2 -> 64 -> 64 -> 1 network used
    as the data-driven baseline. The layers live in ``self.net`` (an
    ``nn.Sequential``), so ``state_dict`` keys are ``net.0.*``, ``net.2.*``
    and ``net.4.*``.

    Args:
        in_features: Number of input features per sample.
        hidden: Width of both hidden layers.
        out_features: Number of outputs per sample.
    """

    def __init__(self, in_features: int = 2, hidden: int = 64,
                 out_features: int = 1):
        super().__init__()
        # Layer construction order is kept fixed: it determines the order in
        # which the seeded RNG is consumed during weight initialisation.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of shape (N, in_features) to shape (N, out_features)."""
        return self.net(x)

# Seed before building the model so the initial weights are reproducible.
torch.manual_seed(0)
model = DataDrivenNN()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ==================== TRAINING LOOP ====================
epochs = 2000
train_losses = []
test_losses  = []

print('Training data-driven baseline (no physics)...')
for epoch in range(epochs):
    # One full-batch gradient step on the training set.
    train_loss = criterion(model(X_train), y_train)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    train_losses.append(train_loss.item())

    # Held-out error, measured after this epoch's update; no graph is needed.
    with torch.no_grad():
        test_loss = criterion(model(X_test), y_test)
    test_losses.append(test_loss.item())

    # Progress report every 400 epochs (including epoch 0).
    if not epoch % 400:
        print(f'   Epoch {epoch:4d} | Train MSE: {train_losses[-1]:.6f} | Test MSE: {test_losses[-1]:.6f}')

print(f'\nFinal Train MSE: {train_losses[-1]:.6f}')
print(f'Final Test MSE:  {test_losses[-1]:.6f}')

In [None]:
# ==================== PLOTS ====================
# Predict over every sample (train and test) so the curve spans the whole time axis.
X_full = torch.tensor(X_all, dtype=torch.float32)
full_output = model(X_full)
m_pred = full_output.detach().numpy().flatten()

fig, axs = plt.subplots(3, 1, figsize=(10, 10))
ax_voltage, ax_gating, ax_loss = axs

# Panel 1: the voltage input.
ax_voltage.plot(t, V, 'b-', label='Voltage protocol')
ax_voltage.set_ylabel('Voltage (mV)')
ax_voltage.legend()
ax_voltage.grid(True)

# Panel 2: ground truth vs. network output, with the held-out samples marked.
ax_gating.plot(t, m_true, 'r-', label='True m(t)', linewidth=2)
ax_gating.plot(t, m_pred, 'g--', label='Data-driven prediction', linewidth=2)
ax_gating.scatter(t[test_idx], m_true[test_idx], c='orange', s=8, zorder=5, label='Test points')
ax_gating.set(xlabel='Time (ms)', ylabel='Gating variable m(t)')
ax_gating.legend()
ax_gating.grid(True)

# Panel 3: per-epoch loss curves.
ax_loss.plot(train_losses, label='Train Loss')
ax_loss.plot(test_losses, label='Test Loss', alpha=0.7)
ax_loss.set(xlabel='Epoch', ylabel='MSE')
ax_loss.legend()
ax_loss.grid(True)

fig.suptitle('Data-Driven Neural Network Baseline (no physics prior)')
fig.tight_layout()
fig.savefig('outputs/02_data_driven_baseline.png', dpi=200)
plt.show()
print('Plot saved to outputs/02_data_driven_baseline.png')

In [None]:
# ==================== SAVE MODEL + METRICS ====================
torch.save(model.state_dict(), 'outputs/02_data_driven_weights.pt')

# Loss histories, the full-domain prediction and the split indices go into
# one archive; each dict key becomes the array's name inside the .npz.
metrics = {
    'train_losses': np.array(train_losses),
    'test_losses': np.array(test_losses),
    'm_pred_full': m_pred,
    'train_idx': train_idx,
    'test_idx': test_idx,
}
np.savez('outputs/02_metrics.npz', **metrics)
print('Model weights and metrics saved.')