In [1]:
from jax.scipy.linalg import inv, det, svd
import jax.numpy as np
from jax import random, jit
from sklearn.datasets import make_spd_matrix
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import jax.lax as lax
from jax_models import KuramotoSivashinsky
num_steps = 1000  # Number of simulation steps
n = 256 # Dimensionality of the state space for KS model
observation_interval = 5  # Interval at which observations are made
dt = 0.25  # Time step for the KS model

ks_model = KuramotoSivashinsky(dt=dt, s=n, l=22, M=16)

# Initial state
key = random.PRNGKey(0)  # Random key for reproducibility
x0 = random.normal(key, (n,))
initial_state  = x0
# Noise covariances
Q = 0.01 * np.eye(n)  # Process noise covariance
R = 0.5 * np.eye(n)  # Observation noise covariance
# Observation matrix (identity matrix for direct observation of all state variables)
H = np.eye(n)


def generate_true_states(key, num_steps, n, x0, H, Q, R, model_step, observation_interval):
    # Initialize the state with the initial condition based on x0 and C0
    x = np.zeros((num_steps, n))
    obs = np.zeros((num_steps, H.shape[0]))  # Adjust the shape based on H
    x = x.at[0].set(x0)

    for j in range(1, num_steps):
        key, subkey = random.split(key)
        # Update state using the model step function
        x_j = model_step(x[j-1])
        # Add process noise Q only at observation times
        if j % observation_interval == 0:
            x_j = x_j + random.multivariate_normal(subkey, np.zeros(n), Q)
            obs_state = np.dot(H, x_j)
            obs_noise = random.multivariate_normal(subkey, np.zeros(H.shape[0]), R)
            obs = obs.at[j].set(obs_state + obs_noise)
        else: #non observations are nans
            obs = obs.at[j].set(np.nan)
        
        x = x.at[j].set(x_j)

    return obs, x


In [2]:
import time

In [3]:

for i in range(10):
    a = time.time()
    observations, true_states = generate_true_states(key, num_steps, n, x0, H, Q, R, ks_model.step, observation_interval)
    print(time.time() - a)


3.530884265899658
3.4854602813720703
3.0887906551361084
3.2504544258117676
2.8156626224517822
2.6936001777648926
2.8925094604492188
2.8027076721191406
2.607027292251587
2.9955849647521973


In [9]:
from jax import random, lax, jit
import jax.numpy as jnp

# Assume model_step, observation_interval, Q, R, H, and other necessary variables are defined

def step_fn(carry, t):
    key, x_prev = carry
    key, subkey = random.split(key)
    
    x_j = model_step(x_prev)
    
    def add_noise_and_obs(_):
        x_j_noise = x_j + random.multivariate_normal(subkey, jnp.zeros(n), Q)
        obs_state = jnp.dot(H, x_j_noise)
        obs_noise = random.multivariate_normal(subkey, jnp.zeros(H.shape[0]), R)
        return obs_state + obs_noise, x_j_noise

    def no_noise_no_obs(_):
        return jnp.full(H.shape[0], jnp.nan), x_j

    obs_j, x_j_updated = lax.cond((t + 1) % observation_interval == 0,
                                  add_noise_and_obs,  # Function to execute if condition is True
                                  no_noise_no_obs,  # Function to execute if condition is False
                                  None)  # Operand passed to true_fun and false_fun (not used here)

    return (key, x_j_updated), (x_j_updated, obs_j)

@jit
def generate_true_states_jit(key, x0):
    x_init = jnp.zeros((num_steps, n))
    obs_init = jnp.full((num_steps, H.shape[0]), jnp.nan)  # Assuming H is defined
    x_init = x_init.at[0].set(x0)
    
    carry_init = (key, x0)
    _, (x, obs) = lax.scan(step_fn, carry_init, jnp.arange(num_steps - 1))
    
    x = jnp.vstack([x0[None, :], x])
    obs = jnp.vstack([jnp.full((1, H.shape[0]), jnp.nan), obs])  # First observation is NaN
    
    return obs, x


In [10]:
for i in range(10):
    a = time.time()
    observations, true_states = generate_true_states_jit(key, x0)
    print(time.time()-a)


1.2976999282836914
0.3859541416168213
0.31633663177490234
0.31566429138183594
0.47411537170410156
0.3068225383758545
0.32282495498657227
0.28130173683166504
0.2782151699066162
0.300736665725708


Array([[1.       , 0.8948393, 0.6411804, 0.6411804, 0.8948393],
       [0.8948393, 1.       , 0.8948393, 0.6411804, 0.6411804],
       [0.6411804, 0.8948393, 1.       , 0.8948393, 0.6411804],
       [0.6411804, 0.6411804, 0.8948393, 1.       , 0.8948393],
       [0.8948393, 0.6411804, 0.6411804, 0.8948393, 1.       ]],      dtype=float32)