In [17]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import pairwise_distances
import matplotlib.pyplot as plt
from tqdm import tqdm

# Global matplotlib defaults applied to every figure created after this point:
# serif text at 14 pt, 300-dpi figures, and no external LaTeX text rendering.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 14
plt.rcParams["figure.dpi"] = 300
plt.rcParams["text.usetex"] = False

# ------------------ SIMULATOR ------------------
def sir_simulate(beta, gamma, S0=0.99, I0=0.01, R0=0.0, T=50, dt=1.0):
    """Integrate the SIR compartment model with explicit (forward) Euler steps.

    Each step advances the state ``(S, I, R)`` by ``dt`` times

        dS = -beta * S * I
        dI =  beta * S * I - gamma * I
        dR =  gamma * I

    Args:
        beta: Infection-rate coefficient.
        gamma: Recovery-rate coefficient.
        S0: Initial susceptible value.
        I0: Initial infected value.
        R0: Initial recovered value.
        T: Number of time points to produce, counting the initial state.
        dt: Euler step size.

    Returns:
        Array with one row per compartment (S, I, R) and one column per time
        point. The initial state is always recorded, so there is at least one
        column even when ``T`` is less than 1.
    """
    history = [(S0, I0, R0)]
    for _ in range(T - 1):
        s, i, r = history[-1]
        new_infections = beta * s * i
        new_recoveries = gamma * i
        history.append((
            s - new_infections * dt,
            i + (new_infections - new_recoveries) * dt,
            r + new_recoveries * dt,
        ))
    # Transpose the per-step tuples into one sequence per compartment.
    susceptible, infected, recovered = zip(*history)
    return np.stack([susceptible, infected, recovered], axis=0)

def sample_params(n):
    """Draw ``n`` (beta, gamma) pairs from independent uniform priors.

    Uses NumPy's global random state; all betas are drawn first, then all
    gammas.

    Args:
        n: Number of parameter pairs to draw.

    Returns:
        Tuple ``(betas, gammas)`` of arrays of length ``n``, with betas drawn
        from U(0.1, 0.5) and gammas from U(0.05, 0.2).
    """
    beta_low, beta_high = 0.1, 0.5
    gamma_low, gamma_high = 0.05, 0.2
    beta_draws = np.random.uniform(low=beta_low, high=beta_high, size=n)
    gamma_draws = np.random.uniform(low=gamma_low, high=gamma_high, size=n)
    return beta_draws, gamma_draws

# ------------------ DATASET ------------------
class SIRDataset(Dataset):
    """Simulated SIR trajectories split into an observed window and a forecast target.

    Every sample is generated by drawing ``(beta, gamma)`` with
    ``sample_params`` and simulating with ``sir_simulate`` at its default
    settings. The first 40 time points form the input and the remaining
    time points form the target.
    """

    def __init__(self, n_samples):
        """Simulate ``n_samples`` trajectories and cache them as float32 tensors."""
        betas, gammas = sample_params(n_samples)
        trajectories = [sir_simulate(b, g) for b, g in zip(betas, gammas)]

        # Observed window: columns 0..39 of each trajectory.
        self.X_in = [
            torch.tensor(traj[:, :40], dtype=torch.float32)
            for traj in trajectories
        ]
        # Forecast target: every column from index 40 onward.
        self.X_out = [
            torch.tensor(traj[:, 40:], dtype=torch.float32)
            for traj in trajectories
        ]
        # Ground-truth simulator parameters, stored as [beta, gamma].
        self.theta = [
            torch.tensor([b, g], dtype=torch.float32)
            for b, g in zip(betas, gammas)
        ]

    def __len__(self):
        """Return the number of simulated trajectories."""
        return len(self.X_in)

    def __getitem__(self, i):
        """Return sample ``i`` as a dict with keys ``'x'``, ``'y'`` and ``'theta'``."""
        return {
            'x': self.X_in[i],
            'y': self.X_out[i],
            'theta': self.theta[i],
        }

# ------------------ MODEL ------------------
class EncoderForecastNet(nn.Module):
    """1-D convolutional encoder with an MLP head that forecasts future steps.

    The encoder maps a ``(batch, n_vars, length)`` series to a fixed-size
    embedding (global average pooling makes it independent of ``length``).
    The decoder maps that embedding to ``horizon`` future values for each of
    the ``n_vars`` variables.

    Args:
        emb_dim: Size of the embedding produced by the encoder.
        n_vars: Number of variables (channels) in both the input series and
            the forecast. Defaults to 3.
        horizon: Number of future time steps to predict. Defaults to 10.
    """

    def __init__(self, emb_dim=128, n_vars=3, horizon=10):
        super().__init__()
        self.n_vars = n_vars
        self.horizon = horizon
        # Layer order is kept fixed so state-dict keys and the order in which
        # parameters are initialised do not change.
        self.encoder = nn.Sequential(
            nn.Conv1d(n_vars, 32, 3, padding=1), nn.ReLU(),
            nn.Conv1d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1), nn.Flatten(),
            nn.Linear(128, emb_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(emb_dim, 128), nn.ReLU(),
            nn.Linear(128, n_vars * horizon)  # n_vars variables x horizon steps
        )

    def forward(self, x):
        """Encode ``x`` and forecast the next ``horizon`` steps.

        Args:
            x: Tensor of shape ``(batch, n_vars, length)``.

        Returns:
            Tuple ``(out, z)``: ``out`` has shape ``(batch, n_vars, horizon)``
            and ``z`` is the ``(batch, emb_dim)`` embedding.
        """
        z = self.encoder(x)
        out = self.decoder(z).view(-1, self.n_vars, self.horizon)
        return out, z

# ------------------ TRAINING + ATTRIBUTION ------------------
def run_experiment(epochs=3, bandwidths=None):
    """Train the forecaster and track how well its embeddings recover parameters.

    Trains an ``EncoderForecastNet`` on simulated SIR trajectories. Five times
    per epoch it rebuilds an embedding library from the training set and
    records the mean attribution error on the test set (bandwidth 0.1). After
    training it sweeps the kernel bandwidth and saves both curves to
    ``back2sim.pdf``.

    Args:
        epochs: Number of passes over the training loader.
        bandwidths: Kernel bandwidths for the final sweep. Defaults to
            ``[0.1, 0.25, 0.5, 1.0]``.

    Returns:
        None. Progress is printed and the figure is written to disk.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if bandwidths is None:
        bandwidths = [0.1, 0.25, 0.5, 1.0]
    else:
        bandwidths = list(bandwidths)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_ds = SIRDataset(25000)
    test_ds = SIRDataset(500)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    model = EncoderForecastNet(emb_dim=128).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)
    loss_fn = nn.MSELoss()

    l2_errors_by_step = []

    steps_per_epoch = 5
    # Integer division: up to steps_per_epoch - 1 trailing batches of each
    # shuffled epoch are left unused.
    batches_per_step = len(train_loader) // steps_per_epoch

    for ep in range(epochs):
        print(f"Epoch {ep+1}/{epochs}")
        batch_iter = iter(train_loader)

        for step in range(steps_per_epoch):
            # The evaluation at the end of the previous step leaves the model
            # in eval mode, so training mode is re-enabled for every step.
            model.train()
            # steps_per_epoch * batches_per_step never exceeds the number of
            # batches in the loader, so the iterator cannot run out here.
            for _ in range(batches_per_step):
                batch = next(batch_iter)
                x = batch['x'].to(device)
                y = batch['y'].to(device)
                opt.zero_grad()
                y_hat, _ = model(x)
                loss = loss_fn(y_hat, y)
                loss.backward()
                opt.step()

            # Run attribution at each intra-epoch step
            model.eval()
            lib = build_library(model, train_ds, size=2000, device=device)
            errors = attribution(test_ds, model, lib, n_eval=200, bandwidth=0.1)
            mean_error = np.mean(errors)
            print(f"  Step {step+1}/{steps_per_epoch} | L2 Error = {mean_error:.4f}")
            l2_errors_by_step.append(mean_error)

    # Final bandwidth sweep
    model.eval()
    final_lib = build_library(model, train_ds, size=2000, device=device)
    mean_errors = []
    for bw in bandwidths:
        errors = attribution(test_ds, model, final_lib, n_eval=200, bandwidth=bw)
        mean_err = np.mean(errors)
        print(f"[Bandwidth {bw}] Mean L2 error = {mean_err:.4f}")
        mean_errors.append(mean_err)

    # ----------- Plotting ----------- #
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))

    # The k-th evaluation happens after k / steps_per_epoch epochs of
    # training, so the axis runs from 1/steps_per_epoch up to `epochs`.
    x_vals = np.linspace(1 / steps_per_epoch, epochs, len(l2_errors_by_step))
    axs[0].plot(x_vals, l2_errors_by_step, marker='o', linewidth=2)
    axs[0].set_title('Attribution Error During Training', fontsize=14)
    axs[0].set_xlabel('Training Epoch', fontsize=13)
    axs[0].set_ylabel('Mean L2 Error to True Parameters', fontsize=13)
    axs[0].tick_params(axis='both', which='major', labelsize=12)

    axs[1].plot(bandwidths, mean_errors, marker='o', linewidth=2)
    axs[1].set_title('Bandwidth Sensitivity of Attribution', fontsize=14)
    axs[1].set_xlabel('RBF Kernel Bandwidth $h$', fontsize=13)
    axs[1].set_ylabel('Mean L2 Error to True Parameters', fontsize=13)
    axs[1].tick_params(axis='both', which='major', labelsize=12)

    plt.tight_layout()
    plt.savefig("back2sim.pdf", format="pdf", bbox_inches='tight')
    plt.close()

def build_library(model, dataset, size=2000, device='cpu'):
    """Embed the first ``size`` samples of ``dataset`` and pair them with their parameters.

    Args:
        model: Module exposing an ``encoder`` submodule. It is switched to
            eval mode and left in that mode.
        dataset: Indexable collection whose items are dicts with tensor
            entries ``'x'`` (encoder input without a batch axis) and
            ``'theta'``.
        size: Maximum number of library entries. When the dataset holds fewer
            samples, the whole dataset is used instead of raising IndexError.
        device: Device the inputs are moved to before encoding.

    Returns:
        List of dicts ``{'z': flattened embedding, 'theta': parameter array}``,
        both as NumPy arrays, in dataset order.
    """
    n_entries = min(size, len(dataset))
    lib = []
    model.eval()
    with torch.no_grad():
        for idx in range(n_entries):
            sample = dataset[idx]
            # Add a batch axis of size 1 so the encoder sees a one-sample batch.
            x = sample['x'].unsqueeze(0).to(device)
            emb = model.encoder(x).cpu().numpy().flatten()
            lib.append({'z': emb, 'theta': sample['theta'].numpy()})
    return lib

def attribution(test_set, model, lib, n_eval=200, bandwidth=0.5):
    """Estimate each test sample's parameters by kernel-weighted averaging over ``lib``.

    Each test input is embedded with ``model.encoder``. Library entries are
    weighted by ``exp(-d**2 / bandwidth**2)``, where ``d`` is the Euclidean
    distance between embeddings. The weights are normalised to sum to one and
    the estimate is the weighted mean of the library parameters.

    Args:
        test_set: Indexable collection of dicts with tensor entries ``'x'``
            and ``'theta'``.
        model: Module with an ``encoder`` submodule and at least one
            parameter (used to find the device). Left in eval mode.
        lib: Non-empty list of dicts with NumPy entries ``'z'`` and
            ``'theta'``, as produced by ``build_library``.
        n_eval: Maximum number of test samples to evaluate; clamped to
            ``len(test_set)``.
        bandwidth: Kernel bandwidth ``h``.

    Returns:
        List with one L2 error ``||theta_hat - theta_true||`` per evaluated
        sample.

    Raises:
        ValueError: If ``lib`` is empty.
    """
    if len(lib) == 0:
        raise ValueError("attribution requires a non-empty library")

    model.eval()
    device = next(model.parameters()).device  # loop-invariant, looked up once

    # float64 keeps the squared distances and kernel weights well-conditioned.
    Z_lib = np.stack([entry['z'] for entry in lib]).astype(np.float64)
    theta_lib = np.stack([entry['theta'] for entry in lib])

    errors = []
    for i in range(min(n_eval, len(test_set))):
        sample = test_set[i]
        x = sample['x'].unsqueeze(0).to(device)
        theta_true = sample['theta'].numpy()
        with torch.no_grad():
            z = model.encoder(x).cpu().numpy().reshape(1, -1)

        sq_dists = np.sum((Z_lib - z) ** 2, axis=1)
        logits = -sq_dists / bandwidth**2
        # Shifting by the largest logit leaves the normalised weights
        # unchanged but guarantees that at least one weight equals 1. Without
        # it, a small bandwidth can underflow every weight to 0 and the
        # normalisation becomes 0/0 = NaN.
        weights = np.exp(logits - logits.max())
        weights /= weights.sum()

        theta_hat = np.sum(weights[:, None] * theta_lib, axis=0)
        errors.append(np.linalg.norm(theta_hat - theta_true))
    return errors

# Run the experiment only when this file is executed directly (script or
# notebook cell), so that importing it does not start a training run.
if __name__ == "__main__":
    run_experiment()

Epoch 1/3
  Step 1/5 | L2 Error = 0.0549
  Step 2/5 | L2 Error = 0.0315
  Step 3/5 | L2 Error = 0.0197
  Step 4/5 | L2 Error = 0.0130
  Step 5/5 | L2 Error = 0.0111
Epoch 2/3
  Step 1/5 | L2 Error = 0.0097
  Step 2/5 | L2 Error = 0.0080
  Step 3/5 | L2 Error = 0.0070
  Step 4/5 | L2 Error = 0.0064
  Step 5/5 | L2 Error = 0.0058
Epoch 3/3
  Step 1/5 | L2 Error = 0.0056
  Step 2/5 | L2 Error = 0.0055
  Step 3/5 | L2 Error = 0.0054
  Step 4/5 | L2 Error = 0.0054
  Step 5/5 | L2 Error = 0.0054
[Bandwidth 0.1] Mean L2 error = 0.0054
[Bandwidth 0.25] Mean L2 error = 0.0116
[Bandwidth 0.5] Mean L2 error = 0.0206
[Bandwidth 1.0] Mean L2 error = 0.0352
