# SINDy-SHRED: Second-Order Linear ODE with Noise

This notebook demonstrates SINDy-SHRED on a **second-order linear ODE** (damped harmonic oscillator) with additive noise, projected to a high-dimensional observation space.

## Latent Dynamical System

**Damped Harmonic Oscillator:**
$$\ddot{x} + \gamma \dot{x} + \omega^2 x = 0$$

Equivalently, introducing the velocity $v = \dot{x}$, this can be written as a first-order system:
$$\dot{x} = v$$
$$\dot{v} = -\gamma v - \omega^2 x$$

where:
- $\omega$ = natural frequency
- $\gamma$ = damping coefficient

**High-dimensional projection:** The 2D latent state $(x, v)$ is mapped to a $d$-dimensional observation space by a random linear projection matrix $\mathbf{P} \in \mathbb{R}^{d \times 2}$, and additive Gaussian noise with standard deviation $\sigma$ is then applied to every observation channel.

$$\mathbf{y} = \mathbf{P} \begin{bmatrix} x \\ v \end{bmatrix} + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2 I)$$

## 1. Setup and Imports

In [None]:
import os
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from scipy.integrate import solve_ivp

# Local modules
from sindy_shred import SINDySHRED
import plotting

# Folder that receives every figure written by this notebook; it is created
# on first use and silently reused on later runs.
RESULTS_DIR = "results/second_order_ode"
os.makedirs(RESULTS_DIR, exist_ok=True)
print(f"Results will be saved to: {RESULTS_DIR}")

# Hide all warnings from here on so the cell outputs stay readable.
warnings.filterwarnings(action="ignore")

In [None]:
# Choose the compute backend: Apple MPS is tried first, then CUDA, and the
# CPU is the fallback when neither accelerator is available.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")

# Seed the Python, NumPy and PyTorch generators with one shared value so the
# random draws made later in the notebook are reproducible.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
if device == "cuda":
    torch.cuda.manual_seed(SEED)

In [None]:
# Global seaborn theme used by every figure below.
sns.set_context("paper")
sns.set_style("whitegrid")

# Keyword arguments shared by all pcolormesh heatmaps: a fixed symmetric
# colour range on a diverging colormap, rasterised to keep PDF files small.
pcolor_kwargs = dict(
    vmin=-3,
    vmax=3,
    cmap="RdBu_r",
    rasterized=True,
)

## 2. Data Generation: Damped Harmonic Oscillator

We generate data from a second-order linear ODE (damped harmonic oscillator) and project it to a high-dimensional space.

In [None]:
def damped_harmonic_oscillator(t, state, omega, gamma):
    """
    Right-hand side of the damped harmonic oscillator in first-order form.

    The second-order equation x'' + gamma*x' + omega^2*x = 0 is rewritten
    with v = x' as the pair x' = v, v' = -gamma*v - omega^2*x.

    Parameters
    ----------
    t : time; accepted for ODE-solver compatibility but not used, because
        the system is autonomous.
    state : two-element sequence [x, v] (position, velocity).
    omega : natural frequency.
    gamma : damping coefficient.

    Returns
    -------
    list
        The derivatives [x', v'].
    """
    position, velocity = state
    return [velocity, -gamma * velocity - omega**2 * position]

# Oscillator constants.  The damping ratio zeta = gamma / (2 * omega) is
# below 1 for these values, so the motion is an underdamped (decaying)
# oscillation.
omega = 2.0  # natural frequency
gamma = 0.3  # damping coefficient

print(f"Natural frequency: omega = {omega}")
print(f"Damping coefficient: gamma = {gamma}")
print(f"Damping ratio: zeta = {gamma/(2*omega):.3f} (underdamped since < 1)")

In [None]:
# Time integration parameters
T = 50.0          # end of the integration interval
dt_solve = 0.01   # spacing of the stored trajectory samples
t_span = [0, T]
t_eval = np.arange(0, T, dt_solve)

# Initial conditions
x0 = [2.0, 0.0]  # Start displaced, zero velocity

# Solve the ODE.  This trajectory is the ground truth for everything that
# follows, so the tolerances are set far tighter than SciPy's defaults
# (rtol=1e-3, atol=1e-6); otherwise integration error accumulated over the
# whole interval would be baked into the "true" latent states.
solution = solve_ivp(
    damped_harmonic_oscillator,
    t_span,
    x0,
    t_eval=t_eval,
    args=(omega, gamma),
    method='RK45',
    rtol=1e-8,
    atol=1e-10,
)

# Fail loudly rather than continue with a truncated or invalid trajectory.
if not solution.success:
    raise RuntimeError(f"ODE integration failed: {solution.message}")

latent_states = solution.y.T  # Shape: (n_time, 2), columns are [x, v]
print(f"Latent trajectory shape: {latent_states.shape}")

In [None]:
# Three-panel overview of the latent system: state vs. time, the phase
# portrait, and the governing equations rendered as text.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
ax_series, ax_phase, ax_eqs = axes

# Panel 1: both state components against time.
for column, series_label in enumerate(('x (position)', 'v (velocity)')):
    ax_series.plot(t_eval, latent_states[:, column], label=series_label)
ax_series.set_xlabel('Time')
ax_series.set_ylabel('State')
ax_series.set_title('Latent Dynamics: Damped Harmonic Oscillator')
ax_series.legend()

# Panel 2: velocity against position, with the first and last samples marked.
ax_phase.plot(latent_states[:, 0], latent_states[:, 1])
endpoint_markers = (
    (0, 'g', 'o', 'Start'),
    (-1, 'r', 'x', 'End'),
)
for row, colour, glyph, tag in endpoint_markers:
    ax_phase.scatter(
        [latent_states[row, 0]],
        [latent_states[row, 1]],
        c=colour,
        s=100,
        marker=glyph,
        label=tag,
    )
ax_phase.set_xlabel('x')
ax_phase.set_ylabel('v')
ax_phase.set_title('Phase Portrait')
ax_phase.legend()
ax_phase.axis('equal')

# Panel 3: the true equations, positioned in axes-relative coordinates.
equation_lines = (
    (0.7, r'$\dot{x} = v$'),
    (0.4, rf'$\dot{{v}} = -{gamma}v - {omega**2}x$'),
)
for height, equation in equation_lines:
    ax_eqs.text(0.1, height, equation, fontsize=16, transform=ax_eqs.transAxes)
ax_eqs.set_title('True Governing Equations')
ax_eqs.axis('off')

fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/latent_dynamics.pdf", bbox_inches="tight", dpi=300)
plt.show()

In [None]:
# Lift the 2D latent trajectory into a higher-dimensional observation space.
high_dim = 20    # number of observation channels
noise_std = 0.1  # standard deviation of the additive Gaussian noise

# Draw the projection matrix after re-seeding so it does not depend on how
# many random numbers were consumed earlier, then scale each row to unit
# Euclidean length.
np.random.seed(SEED)
P = np.random.randn(high_dim, 2)
P /= np.linalg.norm(P, axis=1, keepdims=True)

# Clean observations, shape (n_time, high_dim), and their noisy counterpart.
observations = latent_states @ P.T
noise = noise_std * np.random.standard_normal(observations.shape)
observations_noisy = observations + noise

print(f"Observation dimension: {high_dim}")
print(f"Noise std: {noise_std}")
print(f"Observations shape: {observations_noisy.shape}")
# Reported SNR is the ratio of standard deviations (clean signal / noise).
print(f"SNR: {np.std(observations) / noise_std:.2f}")

In [None]:
# Two views of the noisy observations: a channel-by-time heatmap and a few
# individual channels drawn as time series.
data_original = observations_noisy.T  # Shape: (space, time)
space_dim = np.arange(high_dim)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_map, ax_lines = axes

# Left panel: heatmap over (time, channel).
im = ax_map.pcolormesh(t_eval, space_dim, data_original, **pcolor_kwargs)
ax_map.set_title('High-Dimensional Observations (with noise)')
ax_map.set_xlabel('Time')
ax_map.set_ylabel('Observation dimension')
fig.colorbar(im, ax=ax_map)

# Right panel: every fifth channel from 0 to 15.
for i in range(0, 16, 5):
    ax_lines.plot(t_eval, observations_noisy[:, i], alpha=0.7, label=f'dim {i}')
ax_lines.set_xlabel('Time')
ax_lines.set_ylabel('Observation')
ax_lines.set_title('Sample Observation Channels')
ax_lines.legend()

fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/observations.pdf", bbox_inches="tight", dpi=300)
plt.show()

## 3. Model Training with SINDy-SHRED

In [None]:
# Observation channels that act as sensors; `num_sensors` has to agree with
# the number of entries listed in `sensor_locations`.
sensor_locations = np.array([0, 7, 14])
num_sensors = 3

# Keep every `subsample`-th time step to shorten the training series; the
# effective time step grows by the same factor.
subsample = 4
dt = subsample * dt_solve

# Training data: rows are time steps, columns are observation channels.
data = observations_noisy[::subsample]
n_time = len(data)

print(f"Subsampled data shape: {data.shape}")
print(f"Effective dt: {dt}")
print(f"Sensor locations: {sensor_locations}")

In [None]:
# Hyperparameters for the SINDy-SHRED model, collected in one place and
# handed to the constructor as keyword arguments.
shred_settings = dict(
    latent_dim=2,   # matches the dimension of the true latent state (x, v)
    # NOTE(review): the true dynamics are linear, so first-order terms would
    # already express them; presumably order 2 only widens the candidate
    # library -- confirm against the SINDySHRED documentation.
    poly_order=2,
    hidden_layers=2,
    l1=256,
    l2=256,
    dropout=0.1,
    batch_size=64,
    num_epochs=400,
    lr=1e-3,
    threshold=0.1,
    sindy_regularization=5.0,
)
model = SINDySHRED(**shred_settings)

print("SINDy-SHRED model initialized")

In [None]:
# Split sizes: the first 60% of the time steps for training and 10% for
# validation (both truncated to whole steps).
train_fraction, validate_fraction = 0.6, 0.1
train_length = int(n_time * train_fraction)
validate_length = int(n_time * validate_fraction)

# Train on the subsampled observations using the chosen sensor channels.
fit_arguments = dict(
    num_sensors=num_sensors,
    dt=dt,
    x_to_fit=data,
    lags=30,
    train_length=train_length,
    validate_length=validate_length,
    sensor_locations=sensor_locations,
)
model.fit(**fit_arguments)

print("\nModel training complete!")

## 4. SINDy Discovery

In [None]:
# Run the SINDy identification step on the trained model.
model.sindy_identify(threshold=0.1, plot_result=True)

# Print the known ground-truth equations underneath for a side-by-side
# comparison with whatever the identification step reported.
rule = "=" * 50
print(
    "\n" + rule,
    "TRUE EQUATIONS:",
    rule,
    "x' = v",
    f"v' = -{gamma}*v - {omega**2}*x",
    sep="\n",
)

In [None]:
# Ask the model to search for a sparsity threshold itself; the call returns
# the selected threshold together with a second object holding the results.
tuning_outcome = model.auto_tune_threshold(adaptive=True)
best_threshold, results = tuning_outcome
print(f"\nBest threshold: {best_threshold}")

## 5. Evaluation

In [None]:
# Number of steps left after the training split, the validation split and a
# further 30 steps (the same value passed as `lags` to `model.fit`;
# presumably the lag window consumed as model input -- confirm).
steps_before_forecast = train_length + validate_length + 30
n_forecast = n_time - steps_before_forecast

# Roll the identified model forward over that horizon.
forecast = model.forecast(n_steps=n_forecast)

print(f"Forecast shape: {forecast.shape}")

In [None]:
# Slice out the held-out samples that line up with the forecast horizon.
test_start = train_length + validate_length + 30
test_stop = test_start + n_forecast
test_data = data[test_start:test_stop]

# Relative error: norm of the forecast mismatch divided by the norm of the
# held-out data.
mismatch_norm = np.linalg.norm(forecast - test_data)
reference_norm = np.linalg.norm(test_data)
relative_error = mismatch_norm / reference_norm
print(f"Test set relative reconstruction error: {relative_error:.4f}")

In [None]:
# Stacked heatmaps comparing the held-out data (top) with the model forecast
# (bottom) on a shared time axis.
fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

# Time axis measured from the start of the test window.
t_test = np.arange(n_forecast) * dt
channel_axis = np.arange(high_dim)

panels = (
    (test_data, "Ground Truth (Test Set)"),
    (forecast, "SINDy-SHRED Forecast"),
)
for ax, (field, panel_title) in zip(axes, panels):
    im = ax.pcolormesh(t_test, channel_axis, field.T, **pcolor_kwargs)
    ax.set_title(panel_title)
    ax.set_ylabel("Dimension")
    plt.colorbar(im, ax=ax)

# Only the bottom panel (the last one drawn) carries the x-axis label.
ax.set_xlabel("Time")

fig.suptitle(f"Second-Order ODE Reconstruction (Error: {relative_error:.4f})")
fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/reconstruction.pdf", bbox_inches="tight", dpi=300)
plt.show()

## 6. Summary

This notebook demonstrated SINDy-SHRED on a **second-order linear ODE** (damped harmonic oscillator):

- **True system:** $\ddot{x} + 0.3\dot{x} + 4x = 0$
- **Latent dimension:** 2 (position and velocity)
- **Observation dimension:** 20 (linear projection + noise)

SINDy-SHRED successfully:
1. Learned a 2D latent representation from noisy high-dimensional observations
2. Discovered sparse governing equations in the latent space
3. Generated accurate forecasts using the identified dynamics