# SINDy-SHRED: Synthetic Data Example (Low-Level API)

This notebook demonstrates the detailed workflow of SINDy-SHRED on synthetic data.
For a simpler high-level interface, see `synthetic_data_sindy_shred_refactor.ipynb`.

**Synthetic Data:** The FitzHugh-Nagumo model is integrated once, and time-delayed copies of its state $(v, w)$ are stacked along a "spatial" axis to create spatio-temporal data.

## 1. Setup and Imports

In [None]:
import copy
import json
import os
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pysindy as ps
import seaborn as sns
import torch
from scipy.integrate import solve_ivp
from sklearn.preprocessing import MinMaxScaler

# Local modules
import plotting
import sindy
import sindy_shred_net
from utils import get_device, TimeSeriesDataset

# Suppress all warnings to keep the notebook output uncluttered
# (note: this also hides potentially useful ones).
warnings.filterwarnings("ignore")

# Every figure and artifact produced below is written into this directory
RESULTS_DIR = "results/synthetic_data"
os.makedirs(name=RESULTS_DIR, exist_ok=True)
print("Results will be saved to: " + RESULTS_DIR)

In [None]:
# Plotting configuration: seaborn "paper" context on a white grid
sns.set_context("paper")
sns.set_style("whitegrid")

# Shared pcolormesh settings: fixed symmetric color range so all heatmaps
# in the notebook are directly comparable
pcolor_kwargs = dict(
    vmin=-3,
    vmax=3,
    cmap="RdBu_r",
    rasterized=True,
)

### Device and Seed Configuration

In [None]:
# Pick the compute device via the project helper and report it
device = get_device()
print(f"Using device: {device}")

# Seed the Python, NumPy and PyTorch RNGs for reproducibility
SEED = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if device.type == "cuda":
    torch.cuda.manual_seed(SEED)

## 2. Data Generation

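Integrate the FitzHugh-Nagumo model, then build a spatio-temporal field whose spatial locations are time-delayed copies of the $(v, w)$ trajectory (10 delays between 0 and 2 time units, i.e. 20 spatial dimensions).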

In [None]:
def rhs_FNM(t, x, tau, a, b, Iext):
    """Right-hand side of the FitzHugh-Nagumo ODE.

        dv/dt = v - v**3 / 3 - w + Iext
        dw/dt = (v + a - b * w) / tau

    Parameters
    ----------
    t : float
        Current time. Unused (the system is autonomous), but required by the
        ``fun(t, y, *args)`` signature of ``scipy.integrate.solve_ivp``.
    x : array_like
        Current state, unpacked as ``(v, w)``.
    tau, a, b, Iext : float
        Model parameters (time-scale ratio, recovery offset, recovery
        coupling and external current).

    Returns
    -------
    numpy.ndarray
        The derivatives ``[dv/dt, dw/dt]``.
    """
    v, w = x
    dv_dt = v - v**3 / 3 - w + Iext
    dw_dt = (1 / tau) * (v + a - b * w)
    return np.array([dv_dt, dw_dt])

In [None]:
# Time integration parameters: integrate on [0, T] and report the solution
# every dt_solve time units
T = 64
dt_solve = 0.0001 * 8
t_solution = np.arange(0, T, dt_solve)

# FitzHugh-Nagumo parameters
x0 = np.array([-1.110, -0.125])  # initial state (v, w)
tau1 = 2
a = 0.7
b = 0.8
Iext = 0.65

# Solve the ODE (solve_ivp defaults: RK45 with its default tolerances),
# evaluating the solution on the t_solution grid
solution_fn = solve_ivp(
    rhs_FNM, [0, T], x0, t_eval=t_solution, args=(tau1, a, b, Iext)
)
# solve_ivp does not raise when integration fails: it returns a truncated
# solution with success=False, which would only show up later as a confusing
# shape mismatch when building the spatial copies. Fail loudly here instead.
if not solution_fn.success:
    raise RuntimeError(f"FitzHugh-Nagumo integration failed: {solution_fn.message}")

print(f"FitzHugh-Nagumo solution shape: {solution_fn.y.shape}")

In [None]:
# Build the spatio-temporal field: spatial slot i holds the (v, w) trajectory
# delayed by delays[i] time units, so there are 2 * num_space_dims columns.
num_space_dims = 10
delays = np.linspace(0, 2, num_space_dims)  # delays in time units
uv_spatial = np.zeros((len(t_solution), 2 * num_space_dims))

trajectory = solution_fn.y.T  # one row per time sample, columns (v, w)
for i, delay in enumerate(delays):
    shift = int(delay / dt_solve)
    cols = slice(2 * i, 2 * i + 2)
    if shift:
        # Hold the initial condition for the first `shift` rows, then continue
        # with the trajectory lagged by `shift` samples
        padding = np.tile(x0, (shift, 1))
        uv_spatial[:, cols] = np.vstack([padding, trajectory[:-shift, :]])
    else:
        uv_spatial[:, cols] = trajectory

# Keep only every `substep`-th sample for computational efficiency
substep = 50
uv_spatial = uv_spatial[::substep, :]
t_solution = t_solution[::substep]
time = t_solution
dt_data = time[1] - time[0]

# Dimensions: rows are time samples, columns are spatial dimensions
n_time, n_space_dims = uv_spatial.shape

# Final data matrix is space x time (for visualization); later cells
# transpose it back to time x space for processing
data_original = uv_spatial.T

print(f"Data shape (space x time): {data_original.shape}")
print(f"Time step: {dt_data:.6f}")
print(f"Number of time samples: {n_time}")
print(f"Spatial dimension: {n_space_dims}")

In [None]:
# Show the generated field as a space-time heatmap and save it as a PDF
space_dim = np.arange(n_space_dims)

fig, ax = plt.subplots(figsize=(8, 3))
ax.pcolormesh(time, space_dim, data_original, **pcolor_kwargs)
ax.set_title(r"Spatio-temporal data $\mathbf{x}$", loc="left")
ax.set(ylabel="Space", xlabel="Time")
fig.tight_layout()

fig.savefig(f"{RESULTS_DIR}/data_original.pdf", bbox_inches="tight", dpi=300)
plt.show()

## 3. Configuration and Data Preprocessing

In [None]:
# Sensor configuration (fixed for reproducibility): column indices into the
# spatial dimension of the data matrix
sensor_locations = np.array([5, 14, 7])
num_sensors = len(sensor_locations)

# Model hyperparameters
latent_dim = 2
poly_order = 3
include_sine = False

# Number of SINDy library terms for these settings (helper from the local
# `sindy` module)
library_dim = sindy.library_size(latent_dim, poly_order, include_sine, include_constant=True)

# Extra temporal subsampling applied on top of the data-generation step
subsample_factor = 4

# Data split configuration. `lags` is given at the pre-subsampling resolution
# and converted below; `train_length` counts trajectory windows.
lags = 120
train_length = 750 // 4
validate_length = 0

# Prepare data (transpose to time x space, then subsample). The deepcopy keeps
# data_original untouched by anything done to load_X.
load_X = copy.deepcopy(data_original).T[::subsample_factor]
dt = dt_data * subsample_factor
lags = lags // subsample_factor

n, m = load_X.shape  # n time samples, m spatial columns

# Fail here on misconfiguration rather than deep inside training/indexing
assert np.all((sensor_locations >= 0) & (sensor_locations < m)), "sensor index out of range"
assert 0 < train_length < n - lags, "train_length must leave at least one test window"

# SINDy threshold
sindy_threshold = 0.10

print(f"Data shape after preprocessing: {load_X.shape}")
print(f"Number of sensors: {num_sensors}")
print(f"Latent dimension: {latent_dim}")
print(f"Library dimension: {library_dim}")
print(f"Trajectory length (lags): {lags}")
print(f"Training length: {train_length}")
print(f"Time step: {dt:.6f}")

In [None]:
# Visualize the raw (unscaled) sensor time series on the subsampled time grid
t_plot = time[::4]

# Window i spans rows i .. i+lags-1 and is trained to reconstruct row i+lags-1.
# So the first reconstructed row is lags-1, and training (inputs and targets)
# covers rows up to train_length+lags-2. The first test target is therefore
# row train_length+lags-1, not row train_length.
first_target_idx = lags - 1
split_idx = train_length + lags - 1

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(t_plot, load_X[:, sensor_locations], color='b', alpha=0.7)
ax.axvline(t_plot[split_idx], color='k', linestyle='--', label='Train/Test split')
ax.axvline(t_plot[first_target_idx], color='r', linestyle=':', label='Lag window')
ax.set_xlabel('Time')
ax.set_ylabel('Sensor value')
ax.set_title('Sensor Time Series')
ax.legend()
plt.show()

In [None]:
# Train/test split over trajectory windows: window i covers rows i .. i+lags-1,
# so there are n - lags windows. The first train_length windows are used for
# training and the remaining ones for testing.
num_windows = n - lags
assert train_length <= num_windows, "train_length exceeds the number of windows"
train_indices = np.arange(0, train_length)
test_indices = np.arange(train_length, num_windows)

if validate_length > 0:
    # Hold out the first test windows for validation and remove them from the
    # test set so validation and test windows never overlap.
    valid_indices = test_indices[:validate_length]
    test_indices = test_indices[validate_length:]
else:
    # No validation set is held out. fit() is still called with a validation
    # dataset, so the first *training* window is reused as a placeholder; the
    # validation errors seen during training are therefore not an independent
    # measure of generalization.
    valid_indices = train_indices[:1]

print(f"Train samples: {len(train_indices)}")
print(f"Test samples: {len(test_indices)}")

In [None]:
# Min-max scale each spatial column. The scaler is fitted on
# load_X[train_indices], i.e. rows 0 .. train_length-1.
# NOTE(review): training windows also touch rows up to train_length + lags - 2,
# which the scaler does not see.
sc = MinMaxScaler().fit(load_X[train_indices])
transformed_X = sc.transform(load_X)

# Input sequences: window i holds the scaled sensor readings at rows
# i .. i+lags-1, giving an array of shape (n - lags, lags, num_sensors).
sensor_series = transformed_X[:, sensor_locations]
sensor_windows = np.lib.stride_tricks.sliding_window_view(sensor_series, lags, axis=0)
all_data_in = sensor_windows[: n - lags].transpose(0, 2, 1).copy()


def _to_device_tensor(array):
    """Return ``array`` as a float32 tensor on the notebook's compute device."""
    return torch.tensor(array, dtype=torch.float32).to(device)


# Inputs are sensor windows; targets are the full scaled state at each
# window's last row (i + lags - 1).
train_data_in = _to_device_tensor(all_data_in[train_indices])
valid_data_in = _to_device_tensor(all_data_in[valid_indices])
test_data_in = _to_device_tensor(all_data_in[test_indices])

train_data_out = _to_device_tensor(transformed_X[train_indices + lags - 1])
valid_data_out = _to_device_tensor(transformed_X[valid_indices + lags - 1])
test_data_out = _to_device_tensor(transformed_X[test_indices + lags - 1])

# Wrap input/target pairs in the project's dataset class
train_dataset = TimeSeriesDataset(train_data_in, train_data_out)
valid_dataset = TimeSeriesDataset(valid_data_in, valid_data_out)
test_dataset = TimeSeriesDataset(test_data_in, test_data_out)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

## 4. Model Creation and Training

In [None]:
# Build the SINDy-SHRED network: sensor sequences (input_size=num_sensors) are
# encoded to a latent state of size latent_dim and decoded to all m spatial
# columns. NOTE(review): l1/l2 are presumably decoder layer widths — confirm.
network_settings = dict(
    input_size=num_sensors,
    output_size=m,
    hidden_size=latent_dim,
    hidden_layers=2,
    l1=350,
    l2=400,
    dropout=0.1,
    library_dim=library_dim,
    poly_order=poly_order,
    include_sine=include_sine,
    dt=dt,
)
shred = sindy_shred_net.SINDy_SHRED_net(**network_settings).to(device)

n_parameters = sum(param.numel() for param in shred.parameters())
print("SINDy-SHRED network created")
print(f"Total parameters: {n_parameters}")

In [None]:
# Train the model on train_dataset, passing valid_dataset for validation; the
# return value of fit() is kept as validation_errors.
# NOTE(review): the in-training `threshold` (0.05) differs from the post-hoc
# sindy_threshold (0.10) used for SINDy discovery — confirm this is intended.
fit_settings = {
    "batch_size": 64,
    "num_epochs": 600,
    "lr": 1e-3,
    "verbose": True,
    "threshold": 0.05,
    "patience": 5,
    "sindy_regularization": 10.0,
    "optimizer": "AdamW",
    "thres_epoch": 100,
}
validation_errors = sindy_shred_net.fit(shred, train_dataset, valid_dataset, **fit_settings)

## 5. Evaluation

In [None]:
# Compute test reconstruction error.
# Switch to inference mode first: fit() may leave the network in training
# mode, in which case dropout (dropout=0.1) would make the reconstruction,
# and hence the reported error, stochastic. no_grad avoids building an
# autograd graph for this forward pass.
shred.eval()
with torch.no_grad():
    test_recons_scaled = shred(test_dataset.X).detach().cpu().numpy()
test_recons = sc.inverse_transform(test_recons_scaled)
test_ground_truth = sc.inverse_transform(test_dataset.Y.detach().cpu().numpy())

# Relative Frobenius-norm error over all test windows and spatial columns
relative_error = np.linalg.norm(test_recons - test_ground_truth) / np.linalg.norm(test_ground_truth)
print(f"Test set relative reconstruction error: {relative_error:.4f}")

In [None]:
# Heatmaps of test ground truth (top) vs reconstruction (bottom); the x axis
# is the test window index.
fig, (ax_truth, ax_recon) = plt.subplots(2, 1, figsize=(10, 5), sharex=True)

panels = (
    (ax_truth, test_ground_truth, "Ground Truth"),
    (ax_recon, test_recons, "SINDy-SHRED Reconstruction"),
)
for ax, field, title in panels:
    ax.pcolormesh(field.T, **pcolor_kwargs)
    ax.set_title(title)
    ax.set_ylabel("Space")
ax_recon.set_xlabel("Time step")

fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/reconstruction_comparison.pdf", bbox_inches="tight", dpi=300)
plt.show()

## 6. Post-hoc SINDy Discovery

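Extract the latent trajectories produced by the trained encoder on the training windows, normalize them to $[-1, 1]$, and fit a sparse polynomial SINDy model to discover their governing equations.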

In [None]:
# Extract latent trajectories from training data
# NOTE(review): index 0 of axis 1 is taken from gru_outputs' result; its exact
# meaning depends on sindy_shred_net — confirm.
gru_outs_train, _ = shred.gru_outputs(train_dataset.X, sindy=True)
gru_outs_train = gru_outs_train[:, 0, :]

# Save min/max for each latent dimension (needed for denormalization later)
latent_min = torch.min(gru_outs_train, dim=0).values
latent_max = torch.max(gru_outs_train, dim=0).values

# Normalize latent trajectories to [-1, 1] (vectorized over latent dims).
# A constant (collapsed) latent dimension has max == min, which would give
# 0/0 = NaN and break the SINDy fit; divide by 1 instead so it maps to -1.
latent_range = latent_max - latent_min
safe_range = torch.where(latent_range > 0, latent_range, torch.ones_like(latent_range))
gru_outs_normalized = 2 * ((gru_outs_train - latent_min) / safe_range) - 1

x_train = gru_outs_normalized.detach().cpu().numpy()
gru_outs_train_np = gru_outs_train.detach().cpu().numpy()
print(f"Latent trajectories shape: {x_train.shape}")
print(f"Latent min: {latent_min.detach().cpu().numpy()}")
print(f"Latent max: {latent_max.detach().cpu().numpy()}")

In [None]:
# SINDy discovery: fit a sparse polynomial model (STLSQ thresholding) to the
# normalized latent trajectories, using finite-difference derivatives.
differentiation_method = ps.differentiation.FiniteDifference()
sindy_library = ps.PolynomialLibrary(degree=poly_order)
sindy_optimizer = ps.STLSQ(threshold=sindy_threshold, alpha=0.05)

model = ps.SINDy(
    optimizer=sindy_optimizer,
    differentiation_method=differentiation_method,
    feature_library=sindy_library,
)
model.fit(x_train, t=dt)

print("\nDiscovered SINDy equations:")
model.print()

### True Governing Equations

**FitzHugh-Nagumo Model:**
$$\dot{v} = v - \frac{1}{3}v^3 - w + 0.65$$
$$\dot{w} = \frac{1}{\tau}(v + 0.7 - 0.8w)$$
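with $\tau = 2$.

Note that the SINDy model above is fitted to the learned latent coordinates $z$. SHRED learns these only up to an unknown change of coordinates of $(v, w)$, and they are additionally min-max normalized to $[-1, 1]$. The discovered coefficients are therefore not expected to match these equations term by term; compare the qualitative dynamics instead.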

In [None]:
# Simulate the discovered model from the first training latent state.
# Build the time grid from integer indices: np.arange(0, N * dt, dt) with a
# float step can return N + 1 points and misalign with x_train.
t_sim = np.arange(len(x_train)) * dt
init_cond = x_train[0, :]
x_sim = model.simulate(init_cond, t_sim)
# simulate() may return fewer samples if the integration stops early; only
# plot the samples that exist.
n_sim = min(len(x_sim), len(t_sim))

# Plot comparison. squeeze=False keeps a 2-D axes array even when latent_dim == 1.
fig, axes = plt.subplots(
    latent_dim, 1, figsize=(10, 2 * latent_dim), sharex=True, squeeze=False
)
for i, ax in enumerate(axes[:, 0]):
    ax.plot(t_sim, x_train[:, i], label="SINDy-SHRED")
    ax.plot(t_sim[:n_sim], x_sim[:n_sim, i], "k--", label="Identified model")
    ax.set_ylabel(rf"$z_{{{i}}}$")
axes[-1, 0].set_xlabel("Time")
axes[-1, 0].legend()

fig.suptitle("Latent Space: SINDy-SHRED vs Identified Model")
fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/latent_comparison.pdf", bbox_inches="tight", dpi=300)
plt.show()

## 7. Sensor-Level Predictions

In [None]:
# Get test latent from the recurrent encoder (same extraction as for training)
gru_outs_test, _ = shred.gru_outputs(test_dataset.X, sindy=True)
gru_outs_test = gru_outs_test[:, 0, :]

# Normalize test latent to [-1, 1] using the TRAINING min/max (vectorized).
# A constant training dimension (max == min) is divided by 1 instead of 0.
latent_range = latent_max - latent_min
safe_range = torch.where(latent_range > 0, latent_range, torch.ones_like(latent_range))
test_normalized = 2 * ((gru_outs_test - latent_min) / safe_range) - 1
test_normalized_np = test_normalized.detach().cpu().numpy()

# Use the first test point as initial condition and simulate SINDy forward.
# Integer-index time grid avoids the extra point np.arange(0, N * dt, dt)
# can produce with a float step.
t_test = np.arange(len(test_normalized_np)) * dt
init_cond_test = test_normalized_np[0, :]
x_sindy_test = model.simulate(init_cond_test, t_test)

# Denormalize SINDy output back to original latent scale
# Reverse of: normalized = 2 * (x - min) / (max - min) - 1
# So: x = (normalized + 1) / 2 * (max - min) + min
# (float64 so the arithmetic matches working with Python floats via .item())
lat_min_np = latent_min.detach().cpu().numpy().astype(np.float64)
lat_max_np = latent_max.detach().cpu().numpy().astype(np.float64)
x_sindy_denorm = (x_sindy_test + 1) / 2 * (lat_max_np - lat_min_np) + lat_min_np

# Decode through SDN (decoder only) to get physical space prediction
x_sindy_tensor = torch.tensor(x_sindy_denorm, dtype=torch.float32).to(device)
sindy_physical_scaled = shred.decode(x_sindy_tensor).detach().cpu().numpy()

# Inverse transform to original data scale
sindy_physical = sc.inverse_transform(sindy_physical_scaled)

print(f"SINDy latent prediction shape: {x_sindy_test.shape}")
print(f"Decoded physical prediction shape: {sindy_physical.shape}")

# Align lengths: the SINDy integration may stop early and return fewer rows
n_common = min(len(test_ground_truth), len(sindy_physical))

# Plot every spatial dimension: 5 panels per row and as many rows as needed
# (a fixed 2x5 grid only has room for 10 of the n_space_dims panels).
grid_cols = 5
grid_rows = int(np.ceil(n_space_dims / grid_cols))
sensor_grid_path = f"{RESULTS_DIR}/sensor_predictions_grid.pdf"
fig, axes = plotting.plot_sensor_predictions(
    test_ground_truth[:n_common],
    sindy_physical[:n_common],
    sensor_locations=np.arange(n_space_dims),  # All spatial dims
    rows=grid_rows,
    cols=grid_cols,
    save_path=sensor_grid_path,
)
fig.suptitle("Sensor-Level: Ground Truth vs SINDy Prediction")
fig.tight_layout()
# save_path is handed to the helper before the suptitle/layout above are
# applied; re-save so the file on disk matches the displayed figure.
fig.savefig(sensor_grid_path, bbox_inches="tight", dpi=300)
print(f"Saved sensor predictions plot to {RESULTS_DIR}/sensor_predictions_grid.pdf")
plt.show()

## 8. Save Results

In [None]:
# Save trained model weights (state_dict only: rebuild the network with the
# same constructor arguments before loading them)
torch.save(shred.state_dict(), f"{RESULTS_DIR}/shred_model.pt")
print(f"Saved SHRED model to {RESULTS_DIR}/shred_model.pt")

# Save the normalized training latent trajectories
np.save(f"{RESULTS_DIR}/latent_train.npy", x_train)
print("Saved latent trajectories")

# Save SINDy coefficients
sindy_coefficients = model.coefficients()
np.save(f"{RESULTS_DIR}/sindy_coefficients.npy", sindy_coefficients)
print(f"Saved SINDy coefficients: shape {sindy_coefficients.shape}")

# Save feature names, one per line
feature_names = model.get_feature_names()
with open(f"{RESULTS_DIR}/sindy_feature_names.txt", "w") as f:
    f.writelines(f"{name}\n" for name in feature_names)

# Save original data (space x time)
np.save(f"{RESULTS_DIR}/data_original.npy", data_original)

# Save config. Values are cast to built-in Python types so the dict is
# JSON-serializable.
config = {
    "latent_dim": int(latent_dim),
    "poly_order": int(poly_order),
    "num_sensors": int(num_sensors),
    "lags": int(lags),
    "train_length": int(train_length),
    "dt": float(dt),
    "sindy_threshold": float(sindy_threshold),
    "relative_error": float(relative_error),
}
# config.npy is kept for backward compatibility. A dict is stored as a pickled
# object array, so it can only be read back with np.load(..., allow_pickle=True).
np.save(f"{RESULTS_DIR}/config.npy", config)
# config.json is human-readable and can be loaded without unpickling anything
with open(f"{RESULTS_DIR}/config.json", "w") as f:
    json.dump(config, f, indent=2)

print("\nAll results saved!")