In [None]:
from jax_models import KuramotoSivashinsky, generate_true_states, visualize_observations
import jax.numpy as np
from jax import random

# Initialize parameters
num_steps = 1000  # Number of simulation steps
n = 128  # Dimensionality of the state space for KS model
observation_interval = 5  # Interval at which observations are made
dt = 0.01  # Time step for the KS model

# Initialize the Kuramoto-Sivashinsky model.
# NOTE(review): l=22 and M=16 are forwarded to KuramotoSivashinsky; their
# meaning (presumably domain length and a solver resolution parameter) is
# defined in jax_models — confirm there before changing them.
ks_model = KuramotoSivashinsky(dt=dt, s=n, l=22, M=16)

# Random keys. A JAX PRNG key must not be consumed twice: drawing x0 and
# driving generate_true_states from the same key makes the initial state and
# the simulated randomness statistically dependent. Split once so each
# consumer gets an independent stream (still fully reproducible from seed 0).
key = random.PRNGKey(0)
key, init_key = random.split(key)

# Initial state: a standard-normal draw of length n.
x0 = random.normal(init_key, (n,))
initial_state = x0  # alias kept for later notebook cells that may use it

# Noise covariances
Q = 0.1 * np.eye(n)  # Process noise covariance
R = 5.0 * np.eye(n)  # Observation noise covariance

# Observation matrix (identity matrix for direct observation of all state variables)
H = np.eye(n)

# Generate observations; `key` here is independent of the key used for x0.
# NOTE(review): Q is passed twice in a row; confirm against the
# generate_true_states signature that this is intended (e.g. separate
# process / initial-state covariance arguments) and not a typo.
observations, true_states = generate_true_states(
    key, num_steps, n, x0, Q, Q, R, ks_model.step, observation_interval, H
)

# Visualize the observations
visualize_observations(observations, observation_interval)


: 