# 04 — Sparse-Data Ablation Study

**Core experiment:** How do physics-informed neural networks (PINNs) and purely data-driven NNs compare as training data becomes scarce?

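We train both models on random subsets of 25, 50, 100, 200, 500, and 800 points drawn from the pool left after holding out the test set (3 subsets per size; at 800 the subset is the whole remaining pool, so only the weight initialization varies between trials),  
evaluate on a fixed 200-point test set, and plot generalization error vs. training set size.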

Hypothesis: the PINN's physics constraint acts as a strong regularizer,  
so it should degrade much more gracefully than the data-driven model under sparse data.

In [None]:
import os
os.makedirs('outputs', exist_ok=True)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from scipy.integrate import odeint

# ==================== LOAD DATA ====================
# Load the synthetic trace. The context manager closes the .npz archive;
# indexing an NpzFile reads each array fully into memory, so the arrays
# remain valid after the file is closed.
with np.load('data/synthetic_m.npz') as data:
    t = data['t']
    V = data['V']
    m_true = data['m_true']

# Model inputs: one (t, V) row per sample.
X_all = np.stack([t, V], axis=1)

# Fixed test set: N_TEST randomly chosen points, held constant across all experiments.
# A local RandomState(42) gives the same permutation as np.random.seed(42) followed by
# np.random.permutation, without mutating the global NumPy RNG (matching the local
# RandomState usage in the later cells).
N_TEST = 200
n_samples = len(t)
if n_samples <= N_TEST:
    raise ValueError(
        f'Need more than {N_TEST} samples to hold out a test set; got {n_samples}'
    )
all_idx = np.random.RandomState(42).permutation(n_samples)
test_idx = all_idx[:N_TEST]
pool_idx = all_idx[N_TEST:]  # every non-test point; training subsets are drawn from here

X_test = torch.tensor(X_all[test_idx], dtype=torch.float32)
y_test = torch.tensor(m_true[test_idx], dtype=torch.float32).unsqueeze(1)
# All inputs (pool + test). Used later as PINN collocation points and for plotting.
X_full = torch.tensor(X_all, dtype=torch.float32)

print(f'Fixed test set: {len(test_idx)} points')
print(f'Training pool:  {len(pool_idx)} points')

In [None]:
# ==================== HH RATE FUNCTIONS (PyTorch) ====================
def alpha_m(V):
    """Opening rate of the HH m-gate: 0.1 * (V + 40) / (1 - exp(-(V + 40) / 10)).

    The closed form is 0/0 at V = -40. Its limit there is 1.0, because
    x / (1 - exp(-x/10)) -> 10 as x -> 0. That point is special-cased so the
    value and its autograd gradient both stay finite. ``-expm1(-y)`` replaces
    ``1 - exp(-y)`` to avoid catastrophic cancellation when V is close to -40.

    Args:
        V: tensor of membrane voltages (any shape).

    Returns:
        Tensor of the same shape as ``V``.
    """
    x = V + 40.0
    near_singular = torch.abs(x) < 1e-6
    # Put a harmless value into the division first. Otherwise the unused branch
    # of torch.where would still compute 0/0 and leak NaN into the backward pass.
    x_safe = torch.where(near_singular, torch.ones_like(x), x)
    rate = 0.1 * x_safe / (-torch.expm1(-x_safe / 10.0))
    return torch.where(near_singular, torch.ones_like(x), rate)

def beta_m(V):
    """Closing rate of the HH m-gate: 4 * exp(-(V + 65) / 18), elementwise over ``V``."""
    decay = (V + 65.0) / 18.0
    return 4.0 * torch.exp(-decay)

# ==================== MODEL DEFINITIONS ====================
class DataDrivenNN(nn.Module):
    """Plain MLP regressor (t, V) -> m with two 64-unit ReLU hidden layers.

    No physics term: it learns only from labelled samples.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Layers are created in the same order as a literal Sequential, so the
        # torch.manual_seed-controlled initialisation is unchanged.
        layers = [
            nn.Linear(2, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs with a trailing dimension of 2 to outputs with a trailing dimension of 1."""
        return self.net(x)

class PINN(nn.Module):
    """MLP for m(t, V) with tanh activations, plus the HH m-gate ODE residual used as a physics loss."""

    def __init__(self):
        super().__init__()
        width = 64
        # Same construction order as a literal Sequential, so initialisation is unchanged.
        layers = [
            nn.Linear(2, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs with a trailing dimension of 2 to outputs with a trailing dimension of 1."""
        return self.net(x)

    def physics_residual(self, x):
        """Residual of dm/dt = alpha_m(V) * (1 - m) - beta_m(V) * m at each row of ``x``.

        Column 0 of ``x`` is the variable differentiated against (time) and
        column 1 is read as V. Returns a column tensor (one row per input row)
        that is zero where the network satisfies the ODE.
        """
        inputs = x.clone().requires_grad_(True)
        m_pred = self(inputs)
        # create_graph=True keeps the derivative differentiable so the residual
        # can be minimised with respect to the network weights.
        (grad_inputs,) = torch.autograd.grad(
            m_pred, inputs, grad_outputs=torch.ones_like(m_pred), create_graph=True
        )
        dm_dt = grad_inputs[:, 0:1]
        voltage = inputs[:, 1:2]
        # NOTE(review): this is the partial derivative dm/dt at fixed V. If V varies
        # along the trajectory, the total derivative also has a (dm/dV) * dV/dt term.
        # Confirm the data (e.g. voltage clamp) justifies dropping it.
        rhs = alpha_m(voltage) * (1 - m_pred) - beta_m(voltage) * m_pred
        return dm_dt - rhs

In [None]:
# ==================== TRAINING FUNCTIONS ====================
def train_data_driven(X_train, y_train, epochs=2000, seed=0, X_eval=None, y_eval=None):
    """Train a DataDrivenNN with full-batch Adam on MSE and report held-out MSE.

    Args:
        X_train: training inputs, passed straight to the model.
        y_train: training targets, shaped like the model output.
        epochs: number of full-batch optimisation steps.
        seed: torch seed set before building the model (controls initialisation).
        X_eval, y_eval: evaluation set. When omitted, the module-level
            ``X_test`` / ``y_test`` are used (the notebook's default behaviour).

    Returns:
        (test_mse, model): evaluation MSE as a Python float, and the trained model.
    """
    if X_eval is None:
        X_eval = X_test
    if y_eval is None:
        y_eval = y_test

    torch.manual_seed(seed)
    model = DataDrivenNN()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    # NOTE(review): the default here is 2000 epochs but train_pinn defaults to 3000.
    # The ablation therefore gives the two models unequal optimisation budgets;
    # confirm this is intended.
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(X_train), y_train)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        test_mse = criterion(model(X_eval), y_eval).item()
    return test_mse, model


def train_pinn(X_train, y_train, X_colloc, epochs=3000, lambda_phys=0.1, seed=0,
               X_eval=None, y_eval=None):
    """Train a PINN on data MSE plus a weighted ODE-residual penalty and report held-out MSE.

    Loss = MSE(model(X_train), y_train) + lambda_phys * mean(residual(X_colloc) ** 2).
    The collocation points need no labels.

    Args:
        X_train: labelled training inputs.
        y_train: training targets, shaped like the model output.
        X_colloc: unlabelled inputs at which the physics residual is enforced.
        epochs: number of full-batch optimisation steps.
        lambda_phys: weight of the physics term relative to the data term.
        seed: torch seed set before building the model (controls initialisation).
        X_eval, y_eval: evaluation set. When omitted, the module-level
            ``X_test`` / ``y_test`` are used (the notebook's default behaviour).

    Returns:
        (test_mse, model): evaluation MSE as a Python float, and the trained model.
    """
    if X_eval is None:
        X_eval = X_test
    if y_eval is None:
        y_eval = y_test

    torch.manual_seed(seed)
    model = PINN()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    for _ in range(epochs):
        optimizer.zero_grad()
        data_loss = criterion(model(X_train), y_train)
        phys_res = model.physics_residual(X_colloc)
        phys_loss = torch.mean(phys_res ** 2)
        loss = data_loss + lambda_phys * phys_loss
        loss.backward()
        optimizer.step()

    # Data MSE only; the physics term is not part of the reported error.
    with torch.no_grad():
        test_mse = criterion(model(X_eval), y_eval).item()
    return test_mse, model

In [None]:
# ==================== ABLATION SWEEP ====================
# For every training-set size, train both models on n_trials different random
# subsets of the pool and record their test MSEs.
train_sizes = [25, 50, 100, 200, 500, 800]
n_trials = 3  # average over random subsets for robustness

dd_results = {size: [] for size in train_sizes}
pinn_results = {size: [] for size in train_sizes}


def _labelled_subset(n_points, seed):
    """Draw ``n_points`` distinct pool indices with RandomState(seed); return (X, y) float32 tensors."""
    picked = np.random.RandomState(seed).choice(pool_idx, size=n_points, replace=False)
    X_sub = torch.tensor(X_all[picked], dtype=torch.float32)
    y_sub = torch.tensor(m_true[picked], dtype=torch.float32).unsqueeze(1)
    return X_sub, y_sub


for n_train in train_sizes:
    print(f'\n===== Training size: {n_train} =====')
    for trial in range(n_trials):
        X_train, y_train = _labelled_subset(n_train, seed=trial * 100 + n_train)

        dd_mse = train_data_driven(X_train, y_train, seed=trial)[0]
        dd_results[n_train].append(dd_mse)

        # The PINN's collocation set is X_full, which also contains the test inputs
        # (inputs only, never their labels): the physics residual needs no targets.
        pinn_mse = train_pinn(X_train, y_train, X_full, seed=trial)[0]
        pinn_results[n_train].append(pinn_mse)

        print(f'   Trial {trial+1}: DD={dd_mse:.6f}  PINN={pinn_mse:.6f}')

print('\nAblation sweep complete!')

In [None]:
# ==================== RESULTS TABLE ====================
# Mean and spread of test MSE across trials for each training-set size.
# np.std uses its default ddof=0 (population std).
# NOTE(review): with only n_trials samples, ddof=1 may be the more appropriate spread.
dd_means = [np.mean(dd_results[n]) for n in train_sizes]
dd_stds = [np.std(dd_results[n]) for n in train_sizes]
pinn_means = [np.mean(pinn_results[n]) for n in train_sizes]
pinn_stds = [np.std(pinn_results[n]) for n in train_sizes]

# MSEs can span orders of magnitude, so cells use scientific notation; a
# fixed-point '.4f' std collapses to 0.0000 for small errors. Header and row
# widths share CELL_W so the columns line up.
CELL_W = 20  # fits 'mean±std', e.g. '1.234e-04±5.678e-05' (19 chars)


def _mse_cell(mean, std):
    """Format one 'mean±std' table cell in scientific notation."""
    return f'{mean:.3e}±{std:.3e}'


print(f'{"N_train":>8s}  {"DD MSE":>{CELL_W}s}  {"PINN MSE":>{CELL_W}s}  {"Improvement":>12s}')
print('-' * (8 + 2 + CELL_W + 2 + CELL_W + 2 + 12))
for n, dd_m, dd_s, pinn_m, pinn_s in zip(train_sizes, dd_means, dd_stds, pinn_means, pinn_stds):
    # Relative MSE reduction of the PINN vs. the data-driven baseline
    # (positive = PINN better); undefined when the baseline MSE is exactly 0.
    improv = (dd_m - pinn_m) / dd_m * 100 if dd_m != 0 else float('nan')
    print(f'{n:>8d}  {_mse_cell(dd_m, dd_s):>{CELL_W}s}  '
          f'{_mse_cell(pinn_m, pinn_s):>{CELL_W}s}  {improv:>+11.1f}%')

In [None]:
# ==================== MAIN PLOT: MSE vs TRAINING SIZE ====================
# One error-bar series per model (mean ± std test MSE), both axes logarithmic.
PLOT_PATH = 'outputs/04_sparse_ablation.png'
errorbar_style = dict(capsize=5, linewidth=2, markersize=8)
series = [
    (dd_means, dd_stds, 'o', 'Data-Driven NN (ReLU)', '#e74c3c'),
    (pinn_means, pinn_stds, 's', 'PINN (Tanh + ODE constraint)', '#2ecc71'),
]

fig, ax = plt.subplots(figsize=(9, 6))
for means, stds, marker, label, color in series:
    ax.errorbar(train_sizes, means, yerr=stds, marker=marker,
                label=label, color=color, **errorbar_style)

ax.set_xlabel('Number of Training Points', fontsize=13)
ax.set_ylabel('Test MSE', fontsize=13)
ax.set_title('Generalization Error vs. Training Set Size\n(Hodgkin-Huxley m-gate)', fontsize=14)
ax.legend(fontsize=12)
for set_scale in (ax.set_xscale, ax.set_yscale):
    set_scale('log')
ax.grid(True, which='both', alpha=0.3)
# Ticks must be placed after set_xscale, which installs fresh default locators.
ax.set_xticks(train_sizes)
ax.set_xticklabels(train_sizes)

plt.tight_layout()
plt.savefig(PLOT_PATH, dpi=200)
plt.show()
print(f'Plot saved to {PLOT_PATH}')

In [None]:
# ==================== BONUS: PREDICTION CURVES AT KEY SIZES ====================
# For a few training-set sizes, retrain both models and overlay their predictions
# on the ground-truth m(t), one row per size (left: data-driven, right: PINN).
highlight_sizes = [25, 100, 800]
# squeeze=False keeps axs 2-D even when highlight_sizes has a single entry,
# so the axs[row] indexing below always works.
fig, axs = plt.subplots(len(highlight_sizes), 2, figsize=(14, 4 * len(highlight_sizes)),
                        squeeze=False)

for row, n_train in enumerate(highlight_sizes):
    # Seeding with n_train reproduces the sweep's trial-0 subset for this size
    # (that cell seeds with trial * 100 + n_train).
    rng = np.random.RandomState(n_train)
    sub_idx = rng.choice(pool_idx, size=n_train, replace=False)
    X_tr = torch.tensor(X_all[sub_idx], dtype=torch.float32)
    y_tr = torch.tensor(m_true[sub_idx], dtype=torch.float32).unsqueeze(1)

    _, dd_model = train_data_driven(X_tr, y_tr, seed=99)
    _, pinn_model = train_pinn(X_tr, y_tr, X_full, seed=99)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        dd_pred = dd_model(X_full).numpy().flatten()
        pinn_pred = pinn_model(X_full).numpy().flatten()

    for ax, (pred, name) in zip(axs[row], [(dd_pred, 'Data-Driven'), (pinn_pred, 'PINN')]):
        ax.plot(t, m_true, 'r-', label='True m(t)', lw=2)
        ax.plot(t, pred, 'g--', label=f'{name} pred', lw=2)
        ax.scatter(t[sub_idx], m_true[sub_idx], c='blue', s=10, zorder=5, alpha=0.5, label='Train pts')
        ax.set_title(f'{name} — {n_train} training points')
        ax.legend(fontsize=8)
        ax.grid(True)
        ax.set_ylim(-0.1, 1.1)

plt.suptitle('Prediction Quality at Different Training Sizes', fontsize=14)
plt.tight_layout()
plt.savefig('outputs/04_sparse_ablation_curves.png', dpi=200)
plt.show()
print('Curve comparison saved to outputs/04_sparse_ablation_curves.png')

In [None]:
# ==================== SAVE ABLATION RESULTS ====================
# Persist the summary statistics and the raw per-trial MSEs. The raw arrays have
# shape (len(train_sizes), n_trials), one row per training size, so other
# statistics (e.g. a ddof=1 std) can be recomputed without rerunning the sweep.
np.savez('outputs/04_ablation_results.npz',
         train_sizes=np.array(train_sizes),
         dd_means=np.array(dd_means),
         dd_stds=np.array(dd_stds),
         pinn_means=np.array(pinn_means),
         pinn_stds=np.array(pinn_stds),
         dd_trials=np.array([dd_results[n] for n in train_sizes]),
         pinn_trials=np.array([pinn_results[n] for n in train_sizes]))
print('Ablation results saved.')