# SINDy-SHRED Applied to SST Dataset (Low-Level API)

This notebook demonstrates the detailed workflow of SINDy-SHRED using the low-level API.
For a simpler high-level interface, see `sst_sindy_shred_refactor.ipynb`.

**SHRED** (SHallow REcurrent Decoder) models combine a recurrent layer (GRU) with a shallow decoder network to reconstruct high-dimensional spatio-temporal fields from sensor measurements.

**SINDy-SHRED** extends this by integrating Sparse Identification of Nonlinear Dynamics (SINDy):

$$\dot{z} = \Theta(z) \xi$$

where $z$ is the latent state, $\Theta(z)$ is a library of candidate functions evaluated at $z$, and $\xi$ is a sparse coefficient matrix whose nonzero entries select the active terms.

## 1. Setup and Imports

In [None]:
import json
import os
import random
import sys
import warnings
from contextlib import redirect_stdout
from io import StringIO

import matplotlib.pyplot as plt
import numpy as np
import pysindy as ps
import torch
from scipy.io import loadmat
from sklearn.preprocessing import MinMaxScaler

# Local modules
import plotting
import sindy
import sindy_shred_net
from processdata import load_data
from utils import TimeSeriesDataset, get_device

# Suppress every warning category to keep notebook output readable.
# NOTE(review): this also hides convergence / deprecation warnings that may matter.
warnings.filterwarnings(action="ignore")

# Every figure, array and JSON summary written below goes into this directory.
RESULTS_DIR = "results/sst"
os.makedirs(name=RESULTS_DIR, exist_ok=True)
print("Results will be saved to: " + RESULTS_DIR)

### Device and Seed Configuration

In [None]:
# Pick the compute backend via the project helper.
device = get_device()
print("Using device: {}".format(device))

# Seed every RNG this notebook touches (stdlib, NumPy, torch) for reproducibility.
SEED = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if device.type == "cuda":
    torch.cuda.manual_seed(SEED)

## 2. Data Loading and Configuration

In [None]:
# Load the SST matrix: axis 0 indexes time samples, axis 1 spatial points.
load_X = load_data("SST")
n, m = load_X.shape[:2]  # (number of time samples, spatial dimension)

for _label, _value in (
    ("Data shape", load_X.shape),
    ("Number of time samples", n),
    ("Spatial dimension", m),
):
    print(f"{_label}: {_value}")

In [None]:
# --- Sensors ---------------------------------------------------------------
# Distinct random spatial locations, drawn from the (seeded) global NumPy RNG.
num_sensors = 250
sensor_locations = np.random.choice(m, num_sensors, replace=False)

# --- Temporal setup -------------------------------------------------------
lags = 52  # trajectory length: 52 weekly samples = 1 year of measurements
train_length, validate_length = 1000, 30
dt = 1.0 / 52  # weekly sampling step

# --- Model / SINDy library ---------------------------------------------------
latent_dim = 3
poly_order = 1
include_sine = False

# Number of candidate functions in the SINDy library (constant term included).
library_dim = sindy.library_size(
    latent_dim, poly_order, include_sine, include_constant=True
)

for _label, _value in (
    ("Number of sensors", num_sensors),
    ("Trajectory length (lags)", lags),
    ("Training length", train_length),
    ("Latent dimension", latent_dim),
    ("Library dimension", library_dim),
):
    print(f"{_label}: {_value}")

## 3. Data Preprocessing

Split the data into train/validation/test sets, then apply MinMax scaling fitted on the training split only (so no validation/test statistics leak into preprocessing).

In [None]:
# Window start indices: the first `train_length` windows train the model; the
# remaining windows are split into validation (first `validate_length`) and test.
n_windows = n - lags  # number of complete length-`lags` windows
train_indices = np.arange(train_length)

is_held_out = np.ones(n_windows, dtype=bool)
is_held_out[train_indices] = False
valid_test_indices = np.flatnonzero(is_held_out)

valid_indices = valid_test_indices[:validate_length]
test_indices = valid_test_indices[validate_length:]

print(f"Train samples: {len(train_indices)}")
print(f"Validation samples: {len(valid_indices)}")
print(f"Test samples: {len(test_indices)}")

In [None]:
# Fit the MinMax scaler on training rows only, then scale the whole series.
sc = MinMaxScaler().fit(load_X[train_indices])
transformed_X = sc.transform(load_X)

# Input windows: for each start index s, `lags` consecutive scaled readings of
# every sensor -> array of shape (n - lags, lags, num_sensors).
sensor_series = transformed_X[:, sensor_locations]
all_data_in = np.zeros((n - lags, lags, num_sensors))
for start in range(n - lags):
    all_data_in[start] = sensor_series[start : start + lags]


def _as_device_tensor(array):
    """Convert a NumPy array to a float32 tensor on the notebook's device."""
    return torch.tensor(array, dtype=torch.float32).to(device)


train_data_in = _as_device_tensor(all_data_in[train_indices])
valid_data_in = _as_device_tensor(all_data_in[valid_indices])
test_data_in = _as_device_tensor(all_data_in[test_indices])

# Targets: the full scaled field at each window's last time step (s + lags - 1),
# i.e. aligned with the final sensor measurement in the window.
train_data_out = _as_device_tensor(transformed_X[train_indices + lags - 1])
valid_data_out = _as_device_tensor(transformed_X[valid_indices + lags - 1])
test_data_out = _as_device_tensor(transformed_X[test_indices + lags - 1])

train_dataset, valid_dataset, test_dataset = (
    TimeSeriesDataset(inputs, targets)
    for inputs, targets in (
        (train_data_in, train_data_out),
        (valid_data_in, valid_data_out),
        (test_data_in, test_data_out),
    )
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Input shape: {train_data_in.shape}")
print(f"Output shape: {train_data_out.shape}")

### Visualize Sensor Data

In [None]:
# Plot the raw (unscaled) series of the first few sensors.
n_sensors_to_plot = min(num_sensors, 10)

# squeeze=False keeps `axes` a 2-D array even when only one sensor is plotted;
# otherwise plt.subplots returns a bare Axes and `axes[i]` fails.
fig, axes = plt.subplots(
    n_sensors_to_plot,
    1,
    figsize=(10, 2 * n_sensors_to_plot),
    sharex=True,
    squeeze=False,
)
axes = axes[:, 0]
for i, ax in enumerate(axes):
    ax.plot(load_X[:, sensor_locations[i]])
    ax.set_ylabel(f"Sensor {i}")
axes[-1].set_xlabel("Time (weeks)")
fig.suptitle("Sample Sensor Time Series")
fig.tight_layout()
plt.show()

## 4. Model Creation and Training

Create the SINDy-SHRED network and train it.

In [None]:
# Architecture / SINDy settings for the SINDy-SHRED network.
shred_config = dict(
    input_size=num_sensors,
    output_size=m,
    hidden_size=latent_dim,
    hidden_layers=2,
    l1=350,
    l2=400,
    dropout=0.1,
    library_dim=library_dim,
    poly_order=poly_order,
    include_sine=include_sine,
    dt=dt,
)
shred = sindy_shred_net.SINDy_SHRED_net(**shred_config).to(device)

print("SINDy-SHRED network created")
n_params = sum(param.numel() for param in shred.parameters())
print(f"Total parameters: {n_params}")

In [None]:
# Training hyperparameters forwarded to sindy_shred_net.fit.
# NOTE(review): the meaning of threshold / thres_epoch / sindy_regularization is
# defined inside sindy_shred_net.fit — confirm there before tuning.
fit_kwargs = {
    "batch_size": 128,
    "num_epochs": 600,
    "lr": 1e-3,
    "verbose": True,
    "threshold": 0.05,
    "patience": 5,
    "sindy_regularization": 10.0,
    "optimizer": "AdamW",
    "thres_epoch": 100,
}
validation_errors = sindy_shred_net.fit(
    shred, train_dataset, valid_dataset, **fit_kwargs
)

In [None]:
# Mean of the SINDy coefficient mask. Assuming the mask is 0/1, this is the
# fraction of library terms still active (lower = sparser model).
# NOTE(review): confirm mask semantics in sindy_shred_net before interpreting.
sparsity_rate = (shred.e_sindy.coefficient_mask * 1.0).mean()
print(f"Coefficient sparsity rate: {sparsity_rate:.4f}")

## 5. Evaluation

Evaluate reconstruction performance on the test set.

In [None]:
# Reconstruct the test fields from sensor windows and report the relative
# (Frobenius-norm) error in the original, unscaled units.
# Switch to inference mode explicitly (disables dropout) rather than relying on
# whatever mode `fit` left the module in, and skip autograd graph construction.
shred.eval()
with torch.no_grad():
    test_pred_scaled = shred(test_dataset.X).detach().cpu().numpy()
test_recons = sc.inverse_transform(test_pred_scaled)
test_ground_truth = sc.inverse_transform(test_dataset.Y.detach().cpu().numpy())

relative_error = np.linalg.norm(test_recons - test_ground_truth) / np.linalg.norm(
    test_ground_truth
)
print(f"Test set relative reconstruction error: {relative_error:.4f}")

## 6. Post-hoc SINDy Discovery

Extract latent trajectories and discover sparse governing equations using PySINDy.

In [None]:
# Latent trajectories for the training windows (GRU outputs with sindy=True).
# Inference only: no autograd graph is needed for the unrolled GRU.
with torch.no_grad():
    gru_outs_train, _ = shred.gru_outputs(train_dataset.X, sindy=True)
# Keep index 0 of axis 1.
# NOTE(review): confirm what this axis indexes in SINDy_SHRED_net.gru_outputs.
gru_outs_train = gru_outs_train[:, 0, :]

# Min-max normalize each latent coordinate (column) to [-1, 1].
latent_min = gru_outs_train.min(dim=0).values
latent_span = gru_outs_train.max(dim=0).values - latent_min
# A constant latent coordinate would divide by zero and feed NaNs into the
# SINDy fit; use a unit span for such columns instead.
latent_span = torch.where(latent_span > 0, latent_span, torch.ones_like(latent_span))
gru_outs_normalized = 2 * (gru_outs_train - latent_min) / latent_span - 1

x_train = gru_outs_normalized.detach().cpu().numpy()
print(f"Latent trajectories shape: {x_train.shape}")

In [None]:
# Fit a sparse polynomial model dz/dt = Theta(z) xi to the normalized latents.
sindy_threshold = 0.05

# Plain finite differences for dz/dt; for noisy latents consider
# ps.differentiation.SmoothedFiniteDifference() instead.
differentiation_method = ps.differentiation.FiniteDifference()
sindy_optimizer = ps.STLSQ(threshold=sindy_threshold, alpha=0.05)
sindy_library = ps.PolynomialLibrary(degree=poly_order)

model = ps.SINDy(
    optimizer=sindy_optimizer,
    feature_library=sindy_library,
    differentiation_method=differentiation_method,
)
model.fit(x_train, t=dt)

print("\nDiscovered SINDy equations:")
model.print()

In [None]:
# Integrate the identified SINDy model from the first training latent state and
# compare it with the latent trajectory produced by SINDy-SHRED.
# Build the time grid from integer sample indices: np.arange with a float stop
# and step can return one extra point through rounding, making t_sim longer
# than x_train.
t_sim = np.arange(len(x_train)) * dt
init_cond = x_train[0, :]
x_sim = model.simulate(init_cond, t_sim)

# squeeze=False keeps `axes` indexable even when latent_dim == 1.
fig, axes = plt.subplots(
    latent_dim, 1, figsize=(10, 2 * latent_dim), sharex=True, squeeze=False
)
axes = axes[:, 0]
for i, ax in enumerate(axes):
    ax.plot(t_sim, x_train[:, i], label="SINDy-SHRED")
    # Slice t_sim in case the integrator stopped early and returned fewer rows.
    ax.plot(t_sim[: len(x_sim)], x_sim[:, i], "k--", label="Identified model")
    ax.set_ylabel(rf"$z_{{{i}}}$")
axes[-1].set_xlabel("Time")
axes[-1].legend()

fig.suptitle("Latent Space: SINDy-SHRED vs Identified Model")
fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/latent_comparison.pdf", bbox_inches="tight", dpi=300)
fig.savefig(f"{RESULTS_DIR}/latent_comparison.png", bbox_inches="tight", dpi=300)
print(f"Saved latent comparison plot to {RESULTS_DIR}/latent_comparison.png")
plt.show()

## 7. SINDy Prediction and Decoding

Use the discovered SINDy model to predict latent trajectories and decode back to physical space.

In [None]:
# Latent trajectories for the test windows.
with torch.no_grad():
    gru_outs_test, _ = shred.gru_outputs(test_dataset.X, sindy=True)
gru_outs_test = gru_outs_test[:, 0, :]
gru_outs_train_np = gru_outs_train.detach().cpu().numpy()

# Normalize test latents to [-1, 1] using the TRAINING min/max, so they share
# coordinates with the data the SINDy model was fitted on.
train_latent_min = gru_outs_train_np.min(axis=0)
train_latent_span = gru_outs_train_np.max(axis=0) - train_latent_min
# Unit span for constant coordinates avoids division by zero.
train_latent_span = np.where(train_latent_span > 0, train_latent_span, 1.0)
# Computed out-of-place: on CPU, .numpy() shares memory with gru_outs_test,
# so in-place column updates would silently modify that tensor as well.
gru_outs_test_np = (
    2 * (gru_outs_test.detach().cpu().numpy() - train_latent_min) / train_latent_span
    - 1
)

# Simulate the SINDy model from the first test latent state.
# Integer-index time grid: float np.arange can produce an extra sample.
t_test = np.arange(len(test_indices)) * dt
init_cond_test = gru_outs_test_np[0, :]
x_predict = model.simulate(init_cond_test, t_test)

print(f"SINDy prediction shape: {x_predict.shape}")

In [None]:
# Decode SINDy latent predictions back to the (scaled) physical space.
# Step 1: invert the [-1, 1] normalization using the training min/max.
train_latent_min = gru_outs_train_np.min(axis=0)
train_latent_span = gru_outs_train_np.max(axis=0) - train_latent_min
# Same unit-span guard as the forward normalization, so this is its exact inverse.
train_latent_span = np.where(train_latent_span > 0, train_latent_span, 1.0)
x_denorm = (x_predict + 1) / 2 * train_latent_span + train_latent_min

# Step 2: run the decoder layers by hand:
# linear1 -> dropout -> ReLU -> linear2 -> dropout -> ReLU -> linear3.
# NOTE(review): this mirrors SINDy_SHRED_net's decoder; keep it in sync.
# eval() makes the dropout calls no-ops (assuming shred.dropout is an
# nn.Dropout module) so the decoded fields are deterministic.
shred.eval()
with torch.no_grad():
    latent_pred = torch.tensor(x_denorm, dtype=torch.float32, device=device)
    output = shred.linear1(latent_pred)
    output = shred.dropout(output)
    output = torch.nn.functional.relu(output)
    output = shred.linear2(output)
    output = shred.dropout(output)
    output = torch.nn.functional.relu(output)
    output = shred.linear3(output)

output_sindy = output.detach().cpu().numpy()
print(f"Decoded output shape: {output_sindy.shape}")

## 8. Visualization

In [None]:
# Indices of grid points whose time-mean is nonzero (presumably the ocean
# points; zero-mean entries look like masked/land cells — TODO confirm).
load_X_full = loadmat("Data/SST_data.mat")["Z"].T
mean_X = load_X_full.mean(axis=0)
sst_locs = np.flatnonzero(mean_X != 0)

In [None]:
# Compare true vs SINDy-decoded fields at a handful of test time indices.
timesteps = [0, 50, 75, 100, 125]
test_Y = test_dataset.Y.detach().cpu().numpy()

# NOTE(review): lat/lon ranges are forwarded unchanged to plotting; confirm
# (0, 180) is intended for both axes.
map_kwargs = {
    "sst_locs": sst_locs,
    "lat_range": (0, 180),
    "lon_range": (0, 180),
    "diff_scale": 10,
}
fig, axes = plotting.plot_reconstruction_comparison(
    test_Y, output_sindy, timesteps, **map_kwargs
)
fig.suptitle("Spatial Reconstruction: Real vs SINDy-Predicted")
fig.savefig(f"{RESULTS_DIR}/spatial_comparison.pdf", bbox_inches="tight", dpi=300)
plt.show()

In [None]:
# Compare true vs SINDy-predicted time series at random spatial locations.
viz_rows, viz_cols = 3, 6

# Sample distinct locations from the full spatial range [0, m). The old
# hard-coded randint(1, 40000) ignored the actual grid size and could repeat
# a location within the panel grid.
sensor_locations_viz = np.random.choice(m, size=viz_rows * viz_cols, replace=False)

# Prediction horizon after the `lags`-long context; never negative.
num_pred = max(0, min(250, len(output_sindy) - lags))

fig, axes = plotting.plot_sensor_predictions(
    test_Y,
    output_sindy,
    sensor_locations=sensor_locations_viz,
    num_context=lags,
    num_pred=num_pred,
    rows=viz_rows,
    cols=viz_cols,
    save_path=f"{RESULTS_DIR}/sensor_predictions_grid.pdf",
)
plt.show()

## 9. Save Results

In [None]:
# Persist latent trajectories, the identified SINDy model, the run
# configuration and summary error metrics to RESULTS_DIR.
# (json, redirect_stdout and StringIO are imported in the top import cell.)

# Latent trajectories (train/test normalized with the training min/max).
np.save(f"{RESULTS_DIR}/latent_train.npy", x_train)
np.save(f"{RESULTS_DIR}/latent_test.npy", gru_outs_test_np)
np.save(f"{RESULTS_DIR}/latent_sindy_predict.npy", x_predict)
print("Saved latent trajectories")

# SINDy coefficients
sindy_coefficients = model.coefficients()
np.save(f"{RESULTS_DIR}/sindy_coefficients.npy", sindy_coefficients)
print(f"Saved SINDy coefficients: shape {sindy_coefficients.shape}")

# Library feature names, one per line
feature_names = model.get_feature_names()
with open(f"{RESULTS_DIR}/sindy_feature_names.txt", "w") as f:
    f.writelines(name + "\n" for name in feature_names)

# model.print() writes to stdout; capture its text. redirect_stdout restores
# sys.stdout even if print raises (a manual swap would leave it hijacked).
with redirect_stdout(StringIO()) as equations_buffer:
    model.print()
equations_str = equations_buffer.getvalue()

with open(f"{RESULTS_DIR}/sindy_equations.txt", "w") as f:
    f.write("Discovered SINDy Equations:\n")
    f.write("=" * 40 + "\n")
    f.write(equations_str)
print(f"Saved SINDy equations to {RESULTS_DIR}/sindy_equations.txt")

# Run configuration (enough to reproduce the experiment)
config = {
    "latent_dim": latent_dim,
    "poly_order": poly_order,
    "include_sine": include_sine,
    "num_sensors": num_sensors,
    "lags": lags,
    "train_length": train_length,
    "validate_length": validate_length,
    "dt": dt,
    "sindy_threshold": sindy_threshold,
    "seed": SEED,
}
with open(f"{RESULTS_DIR}/config.json", "w") as f:
    json.dump(config, f, indent=2)
print(f"Saved configuration to {RESULTS_DIR}/config.json")

# SINDy latent prediction error on the training trajectory
sindy_latent_error = np.linalg.norm(x_sim - x_train[: len(x_sim)]) / np.linalg.norm(
    x_train[: len(x_sim)]
)

# SINDy physical prediction error on the test data. Lengths normally match;
# truncating to the shorter one covers an integrator that stopped early.
output_sindy_unscaled = sc.inverse_transform(output_sindy)
n_compare = min(len(output_sindy_unscaled), len(test_ground_truth))
sindy_physical_error = np.linalg.norm(
    output_sindy_unscaled[:n_compare] - test_ground_truth[:n_compare]
) / np.linalg.norm(test_ground_truth[:n_compare])

# Summary metrics
results = {
    "reconstruction_error": float(relative_error),
    "sindy_latent_error": float(sindy_latent_error),
    "sindy_prediction_error": float(sindy_physical_error),
}
with open(f"{RESULTS_DIR}/results.json", "w") as f:
    json.dump(results, f, indent=2)
print(f"Saved results to {RESULTS_DIR}/results.json")

print("\nAll results saved!")