# SINDy-SHRED: Toy Data Example

This notebook demonstrates SINDy-SHRED on a synthetic toy dataset using the `SINDySHRED` class. The class handles data preprocessing, model training, and post-hoc SINDy discovery automatically.

## Overview

**SHRED** (SHallow REcurrent Decoder) models combine a recurrent layer (GRU) with a shallow decoder network to reconstruct high-dimensional spatio-temporal fields from sensor measurements.

**SINDy-SHRED** extends this by integrating Sparse Identification of Nonlinear Dynamics (SINDy) to learn interpretable governing equations:

$$\dot{z} = \Theta(z) \xi$$

## Toy Data

The synthetic data combines two dynamical systems:

1. **FitzHugh-Nagumo Model** (slow dynamics):
   $$\dot{v} = v - \frac{1}{3}v^3 - w + I_{ext}$$
   $$\dot{w} = \frac{1}{\tau_1}(v + a - bw)$$

2. **Unforced Duffing Oscillator** (fast dynamics; shown without damping, i.e. with damping coefficient $\eta = 0$ as used below):
   $$\dot{p} = q$$
   $$\dot{q} = -\frac{1}{\tau_2}(p + \epsilon p^3)$$

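Each two-dimensional state is replicated across several spatial channels, and all channels are then mixed by a random orthogonal matrix. The result is multi-scale spatio-temporal data in which every channel generically contains both slow and fast dynamics.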

## Notebook Structure

1. Setup and Imports
2. Data Generation
3. Model Configuration and Training
4. SINDy Discovery
5. Evaluation
6. Save Results

## 1. Setup and Imports

In [None]:
import copy
import os
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats  # explicit: older SciPy releases do not lazy-load submodules
import seaborn as sns
import torch
from scipy.integrate import solve_ivp

# Local modules
from sindy_shred import SINDySHRED
import plotting

# Silence library warnings so the notebook output stays readable.
warnings.simplefilter("ignore")

# Every figure and array produced below is written into this folder.
RESULTS_DIR = "results/toy_data"
os.makedirs(RESULTS_DIR, exist_ok=True)
print(f"Results will be saved to: {RESULTS_DIR}")

### Device and Seed Configuration

In [None]:
# Device selection: prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")

# Seed every random number generator the notebook touches, for reproducibility.
SEED = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
if device == "cuda":
    torch.cuda.manual_seed(SEED)

### Plotting Configuration

In [None]:
sns.set_context("paper")
sns.set_style("whitegrid")

# Shared settings for every space-time pcolormesh below: a diverging colormap
# with limits symmetric about zero, rasterized so that vector (PDF) exports
# embed a bitmap for the mesh instead of one polygon per cell.
pcolor_kwargs = dict(vmin=-3, vmax=3, cmap="RdBu_r", rasterized=True)

## 2. Data Generation

Generate synthetic spatio-temporal data from the FitzHugh-Nagumo and Duffing oscillator systems.

### Define Dynamical Systems

In [None]:
def rhs_FNM(t, x, tau, a, b, Iext):
    """Right-hand side of the FitzHugh-Nagumo model (the slow subsystem).

    Implements::

        dv/dt = v - v**3 / 3 - w + Iext
        dw/dt = (v + a - b * w) / tau

    using the ``f(t, state, *args)`` signature expected by
    ``scipy.integrate.solve_ivp``.

    Parameters
    ----------
    t : float
        Current time; ignored because the system is autonomous.
    x : array-like of length 2
        Current state ``[v, w]``.
    tau : float
        Time constant of the recovery variable ``w``.
    a, b : float
        Parameters of the recovery equation.
    Iext : float
        Constant external input current.

    Returns
    -------
    numpy.ndarray
        ``[dv/dt, dw/dt]``.
    """
    voltage, recovery = x
    dvoltage = voltage - (voltage**3) / 3 - recovery + Iext
    drecovery = (1 / tau) * (voltage + a - b * recovery)
    return np.array([dvoltage, drecovery])


def rhs_UFD(t, y, eta, epsilon, tau):
    """Right-hand side of the unforced Duffing oscillator (the fast subsystem).

    Implements::

        dp/dt = q
        dq/dt = (-2 * eta * q - p - epsilon * p**3) / tau

    using the ``f(t, state, *args)`` signature expected by
    ``scipy.integrate.solve_ivp``. With ``eta = 0`` the oscillator is undamped.

    Parameters
    ----------
    t : float
        Current time; ignored because the system is autonomous.
    y : array-like of length 2
        Current state ``[p, q]``.
    eta : float
        Damping coefficient.
    epsilon : float
        Strength of the cubic nonlinearity.
    tau : float
        Time constant scaling the acceleration.

    Returns
    -------
    numpy.ndarray
        ``[dp/dt, dq/dt]``.
    """
    position, velocity = y
    dposition = velocity
    dvelocity = (1 / tau) * (-2 * eta * velocity - position - epsilon * position**3)
    return np.array([dposition, dvelocity])

### Generate Data

In [None]:
# Time integration parameters: integrate on [0, T] and record the solution
# every `dt` time units (a dense grid that is subsampled in the next cell).
T = 64
dt = 0.0001 * 8
t_solution = np.arange(0, T, dt)

# FitzHugh-Nagumo parameters (slow mode, tau1=2)
x0 = np.array([-1.110, -0.125])
tau1 = 2
a = 0.7
b = 0.8
Iext = 0.65

# Duffing oscillator parameters (fast mode, tau2=0.2); eta=0 -> undamped
y0 = np.array([0, 1])
eta = 0
epsilon = 1
tau2 = 0.2

# Solve the ODEs with solve_ivp's defaults (RK45, rtol=1e-3, atol=1e-6);
# pass tighter rtol/atol here if higher-accuracy trajectories are needed.
solution_fn = solve_ivp(
    rhs_FNM, [0, T], x0, t_eval=t_solution, args=(tau1, a, b, Iext)
)
solution_ufd = solve_ivp(
    rhs_UFD, [0, T], y0, t_eval=t_solution, args=(eta, epsilon, tau2)
)

# A failed integration returns a truncated trajectory instead of raising,
# which would only surface later as a confusing shape mismatch: fail fast.
for _label, _sol in (("FitzHugh-Nagumo", solution_fn), ("Duffing", solution_ufd)):
    if not _sol.success:
        raise RuntimeError(f"{_label} integration failed: {_sol.message}")

print(f"FitzHugh-Nagumo solution shape: {solution_fn.y.shape}")
print(f"Duffing oscillator solution shape: {solution_ufd.y.shape}")

In [None]:
# Create mixed spatio-temporal data.
# `seed` fixes the random orthogonal mixing matrix (independent of SEED above).
seed = 1
num_space_dims = 10  # number of times each 2-D state is replicated

# Tile the solutions: columns are [v, w] repeated num_space_dims times followed
# by [p, q] repeated num_space_dims times (4 * num_space_dims channels total).
uv_tiled = np.hstack([
    np.tile(solution_fn.y.T, num_space_dims),
    np.tile(solution_ufd.y.T, num_space_dims),
])

# Subsample for computational efficiency. `t_solution` is deliberately left
# untouched so that re-running this cell does not subsample time twice.
substep = 50
uv_tiled = uv_tiled[0::substep, :]
time = t_solution[0::substep]
dt_data = time[1] - time[0]

# Get dimensions
n_space_dims = uv_tiled.shape[1]
n_time = uv_tiled.shape[0]
assert time.size == n_time, "time axis and data are out of sync"

# Apply a random orthogonal mixing across all channels (Q is square).
Q = scipy.stats.ortho_group.rvs(n_space_dims, random_state=seed)
x = uv_tiled @ Q

# Final data matrix (space x time)
data_original = x.T

# Slow (FitzHugh-Nagumo) and fast (Duffing) contributions: the first half of
# the tiled columns are slow and the second half fast, so both pieces sum
# exactly to the mixed signal.
half = n_space_dims // 2
slow_modes = uv_tiled[:, :half] @ Q[:half, :]
fast_modes = uv_tiled[:, half:] @ Q[half:, :]
assert np.allclose(slow_modes + fast_modes, x)

print(f"Data shape (space x time): {data_original.shape}")
print(f"Time step: {dt_data:.6f}")
print(f"Number of time samples: {n_time}")
print(f"Spatial dimension: {n_space_dims}")

### Visualize Generated Data

In [None]:
# Space-time plots of the mixed signal and of its slow / fast contributions.
space_dim = np.arange(n_space_dims)
panels = [
    (data_original, r"a) Total signal $\mathbf{x}_{total}$"),
    (slow_modes.T, r"b) Slow component $\mathbf{x}_{slow}$"),
    (fast_modes.T, r"c) Fast component $\mathbf{x}_{fast}$"),
]

fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True, sharey=True)
for ax, (field, title) in zip(axes, panels):
    ax.pcolormesh(time, space_dim, field, **pcolor_kwargs)
    ax.set_title(title, loc="left")
    ax.set_ylabel("Space")
axes[-1].set_xlabel("Time")

fig.tight_layout()

# Save the data visualization plot in vector and raster form.
for ext in ("pdf", "png"):
    fig.savefig(f"{RESULTS_DIR}/data_components.{ext}", bbox_inches="tight", dpi=300)
print(f"Saved data components plot to {RESULTS_DIR}/data_components.pdf")

plt.show()

## 3. Model Configuration and Training

Configure the SINDy-SHRED model using the `SINDySHRED` class.

### Data Configuration

In [None]:
# Sensor configuration (fixed for reproducibility): column indices into load_X
sensor_locations = np.array([10, 28, 14, 11, 23, 27])
num_sensors = len(sensor_locations)

# Model hyperparameters
latent_dim = 4
poly_order = 1
# NOTE(review): include_sine / include_constant are not passed to SINDySHRED
# anywhere in this notebook — confirm whether the model should receive them.
include_sine = False
include_constant = True

# Extra temporal subsampling applied on top of `substep`. `lags` and
# `train_length` are written in samples of the substep-sampled series and are
# converted to the coarser grid by dividing by this factor.
time_subsample = 4

# Data split configuration
lags = 120 // time_subsample
train_length = 750 // time_subsample
validate_length = 0

# Prepare data: transpose to (time x space) and subsample in time. `.copy()`
# leaves data_original untouched and gives a contiguous array.
load_X = data_original.T[::time_subsample].copy()
dt = dt_data * time_subsample

# SINDy threshold (0.0 = plain least squares, no sparsification)
threshold = 0.0

# Cheap sanity checks on the configuration.
assert sensor_locations.max() < load_X.shape[1], "sensor index out of range"
assert 0 < lags < train_length <= load_X.shape[0], "inconsistent data split"

print(f"Data shape after preprocessing: {load_X.shape}")
print(f"Number of sensors: {num_sensors}")
print(f"Latent dimension: {latent_dim}")
print(f"Trajectory length (lags): {lags}")
print(f"Training length: {train_length}")
print(f"Time step: {dt:.6f}")

In [None]:
# Raw sensor signals, with the train/test boundary and the lag (trajectory)
# window length marked as vertical lines.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(load_X[:, sensor_locations], color="b", alpha=0.7)
ax.axvline(train_length, color="k", linestyle="--", label="Train/Test split")
ax.axvline(lags, color="r", linestyle=":", label="Lag window")
ax.set(xlabel="Time step", ylabel="Sensor value", title="Sensor Time Series")
ax.legend()
plt.show()

### Initialize and Train Model

The `SINDySHRED` class handles data preprocessing and model training.

In [None]:
# Initialize the model.
# NOTE(review): include_sine / include_constant are defined in the data
# configuration cell but not passed here — confirm the intended SINDy library.
model = SINDySHRED(
    latent_dim=latent_dim, 
    poly_order=poly_order,
    ode_order=1,  # 1st order ODE: z' = f(z)
    verbose=True
)

# Fit the model; SINDySHRED handles preprocessing and training internally.
# The arguments are positional, so their order must match the signature of
# SINDySHRED.fit — TODO confirm against the class definition.
model.fit(
    num_sensors,
    dt,
    load_X,
    lags,
    train_length,
    validate_length,
    sensor_locations
)

## 4. SINDy Discovery

Discover sparse governing equations from the learned latent space.

In [None]:
# Run SINDy on the learned latent trajectories. With plot_result=True the
# comparison plot is presumably drawn into the current matplotlib figure,
# which is retrieved below with plt.gcf().
model.sindy_identify(threshold=threshold, plot_result=True)

fig = plt.gcf()
fig.suptitle("Latent Space: SINDy-SHRED vs Identified Model")
fig.tight_layout()

# Save the latent comparison plot in vector and raster form.
for ext in ("pdf", "png"):
    fig.savefig(f"{RESULTS_DIR}/latent_comparison.{ext}", bbox_inches="tight", dpi=300)
print(f"Saved latent comparison plot to {RESULTS_DIR}/latent_comparison.pdf")

plt.show()

### Auto-Tune Threshold (Adaptive/Nonparametric)

Alternatively, use `auto_tune_threshold()` to automatically determine the best threshold.
By default it uses a nonparametric approach:
1. First computes the least-squares solution (threshold=0)
2. Uses `scale_factor * max(|coefficients|)` as the max threshold
3. Tests `n_thresholds` evenly spaced values and picks the best stable model

In [None]:
# Auto-tune the SINDy threshold. With adaptive=True the candidate range is
# derived from the least-squares (threshold=0) coefficients, up to
# scale_factor * max(|coefficients|).
best_threshold, tune_results = model.auto_tune_threshold(
    adaptive=True,      # nonparametric threshold range (the default)
    scale_factor=0.3,   # largest candidate = 0.3 * max(|coefficients|)
    n_thresholds=10,    # number of evenly spaced candidates
    metric="bic",       # selection criterion: presumably the Bayesian information criterion (fit vs. complexity)
    verbose=True,
)

print(f"\nBest threshold: {best_threshold:.4f}")
for label, key in (
    ("Tested thresholds", "thresholds"),
    ("Sparsity at each", "sparsity"),
    ("Stability at each", "stable"),
):
    print(f"{label}: {tune_results[key]}")

### True Governing Equations

For reference, the true governing equations are:

**Slow modes (FitzHugh-Nagumo):**
$$\dot{v} = v - \frac{1}{3}v^3 - w + 0.65$$
$$\dot{w} = \frac{1}{\tau_1}(v + 0.7 - 0.8w)$$

with time constant $\tau_1 = 2$.

**Fast modes (Duffing):**
$$\dot{p} = q$$
$$\dot{q} = -\frac{1}{\tau_2}(p + p^3)$$

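with time constant $\tau_2 = 0.2$.

Note that these equations are written in the physical variables $(v, w, p, q)$. SINDy-SHRED instead identifies dynamics for learned latent coordinates, which need not coincide with them. Moreover, with `poly_order = 1` the SINDy library cannot represent the cubic terms, so the identified model can at best be an affine approximation of these dynamics.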

## 5. Evaluation

Evaluate reconstruction performance on the test set.

In [None]:
# Reconstruct the full field on the test split in unscaled units and compare
# it with the ground truth mapped back through the same scaler.
test_recons = model.sensor_recon(data_type="test", return_scaled=False)

test_targets_scaled = model._test_data.Y.detach().cpu().numpy()
test_ground_truth = model._scaler.inverse_transform(test_targets_scaled)

relative_error = model.relative_error(test_recons, test_ground_truth)
print(f"Test set relative reconstruction error: {relative_error:.4f}")

In [None]:
# Space-time view of the test-set ground truth above its reconstruction.
fig, axes = plt.subplots(2, 1, figsize=(10, 5), sharex=True)

comparison = (
    (test_ground_truth, "Ground Truth"),
    (test_recons, "SINDy-SHRED Reconstruction"),
)
for ax, (field, title) in zip(axes, comparison):
    ax.pcolormesh(field.T, **pcolor_kwargs)
    ax.set_title(title)
    ax.set_ylabel("Space")
axes[-1].set_xlabel("Time step")

fig.tight_layout()

# Save the reconstruction comparison plot in vector and raster form.
for ext in ("pdf", "png"):
    fig.savefig(f"{RESULTS_DIR}/reconstruction_comparison.{ext}", bbox_inches="tight", dpi=300)
print(f"Saved reconstruction comparison plot to {RESULTS_DIR}/reconstruction_comparison.pdf")

plt.show()

### Sensor-Level Predictions

Compare real vs predicted at individual spatial locations (sensors).

In [None]:
# Plot sensor-level comparisons for the first rows * cols spatial channels.
# The field has n_space_dims (= 4 * num_space_dims) channels, but the grid has
# only rows * cols panels, so pass exactly as many locations as there are panels.
grid_rows, grid_cols = 2, 5
n_panels = min(grid_rows * grid_cols, n_space_dims)
sensor_grid_path = f"{RESULTS_DIR}/sensor_predictions_grid"

fig, axes = plotting.plot_sensor_predictions(
    test_ground_truth,
    test_recons,
    sensor_locations=np.arange(n_panels),
    rows=grid_rows,
    cols=grid_cols,
    save_path=f"{sensor_grid_path}.pdf",
)
fig.suptitle("Sensor-Level Predictions: Real vs Reconstructed")
fig.tight_layout()

# save_path is written inside plot_sensor_predictions, before the suptitle
# exists; re-save so the files on disk include the title (pdf + png, like the
# other figures in this notebook).
for ext in ("pdf", "png"):
    fig.savefig(f"{sensor_grid_path}.{ext}", bbox_inches="tight", dpi=300)
print(f"Saved sensor predictions plot to {RESULTS_DIR}/sensor_predictions_grid.pdf")
plt.show()

## Summary

This notebook demonstrated SINDy-SHRED on synthetic multi-scale data:

1. Generated toy data from FitzHugh-Nagumo (slow) and Duffing (fast) oscillators
2. Used the `SINDySHRED` class for streamlined model training
3. Identified governing equations for the learned latent dynamics (optionally sparsified with `auto_tune_threshold`)
4. Evaluated reconstruction accuracy on held-out test data
5. **Saved all results** to the `results/toy_data/` folder (see Section 6 below)

### Saved Files

| File | Description |
|------|-------------|
| `shred_model.pt` | Trained SHRED neural network weights |
| `latent_train.npy` | Latent trajectories from training data |
| `latent_test.npy` | Latent trajectories from test data |
| `latent_sindy_predict.npy` | SINDy-predicted latent trajectories |
| `sindy_coefficients.npy` | Learned SINDy coefficient matrix |
| `sindy_feature_names.txt` | Names of SINDy library terms |
| `config.npy` | Model configuration and hyperparameters |
| `data_original.npy` | Original mixed spatio-temporal data |
| `slow_modes.npy` | FitzHugh-Nagumo slow component |
| `fast_modes.npy` | Duffing oscillator fast component |
| `*.pdf/*.png` | Visualization plots |

The `SINDySHRED` class simplifies the workflow compared to manual data preprocessing and model setup.

## 6. Save Results

Save the trained model, latent space values, and learned SINDy model to the results folder.

In [None]:
# Latent-space (GRU-output) trajectories for the train and test splits.
gru_outs_train = model.gru_normalize(data_type="train")
gru_outs_train_np = gru_outs_train.detach().cpu().numpy()

gru_outs_test = model.gru_normalize(data_type="test")
gru_outs_test_np = gru_outs_test.detach().cpu().numpy()

# Latent trajectories predicted by the identified SINDy model.
x_predict = model.predict_latent()

# Save the trained SHRED network weights (state_dict only, not the object).
torch.save(model._shred.state_dict(), f"{RESULTS_DIR}/shred_model.pt")
print(f"Saved SHRED model to {RESULTS_DIR}/shred_model.pt")

# Save latent space trajectories
np.save(f"{RESULTS_DIR}/latent_train.npy", gru_outs_train_np)
np.save(f"{RESULTS_DIR}/latent_test.npy", gru_outs_test_np)
np.save(f"{RESULTS_DIR}/latent_sindy_predict.npy", x_predict)
print(f"Saved latent trajectories to {RESULTS_DIR}/latent_*.npy")

# Save SINDy model coefficients
sindy_coefficients = model._model.coefficients()
np.save(f"{RESULTS_DIR}/sindy_coefficients.npy", sindy_coefficients)
print(f"Saved SINDy coefficients to {RESULTS_DIR}/sindy_coefficients.npy")
print(f"SINDy coefficients shape: {sindy_coefficients.shape}")

# Save SINDy feature names, one per line. An explicit encoding keeps the file
# portable across platforms with different default encodings.
feature_names = model._model.get_feature_names()
with open(f"{RESULTS_DIR}/sindy_feature_names.txt", "w", encoding="utf-8") as fh:
    fh.writelines(f"{name}\n" for name in feature_names)
print(f"Saved SINDy feature names to {RESULTS_DIR}/sindy_feature_names.txt")

# Save configuration. `threshold` is the value passed to sindy_identify.
# `best_threshold` is what auto_tune_threshold selected; it is recorded
# because tuning may have refit the SINDy model. Confirm which threshold the
# saved coefficients correspond to.
config = {
    "latent_dim": latent_dim,
    "poly_order": poly_order,
    "num_sensors": num_sensors,
    "lags": lags,
    "train_length": train_length,
    "validate_length": validate_length,
    "dt": dt,
    "threshold": threshold,
    "relative_error": relative_error,
    "best_threshold": best_threshold,
}
# np.save stores the dict as a pickled 0-d object array. Read it back with
# np.load(path, allow_pickle=True).item(), and only for files you trust.
np.save(f"{RESULTS_DIR}/config.npy", config)
print(f"Saved configuration to {RESULTS_DIR}/config.npy")

# Save original data for reference
np.save(f"{RESULTS_DIR}/data_original.npy", data_original)
np.save(f"{RESULTS_DIR}/slow_modes.npy", slow_modes)
np.save(f"{RESULTS_DIR}/fast_modes.npy", fast_modes)
print(f"Saved original data components to {RESULTS_DIR}/")

# Print summary of saved files
print("\n" + "=" * 50)
print("Saved files summary:")
print("=" * 50)
for fname in sorted(os.listdir(RESULTS_DIR)):
    size = os.path.getsize(os.path.join(RESULTS_DIR, fname))
    print(f"  {fname}: {size/1024:.1f} KB")