# Latent state estimation (filtering) for an irregularly-sampled, continuous-discrete (non-linear) Gaussian dynamical system

We show how to use cd-dynamax to estimate the latent state of a continuous-discrete (non-linear) Gaussian dynamical system.

Specifically, we will showcase the following filtering alternatives:

- The Extended Kalman Filter (EKF)
- The Ensemble Kalman Filter (EnKF)
- The Unscented Kalman Filter (UKF)

## Preliminaries

### Code Setup

In [None]:
# Main imports
import sys
from itertools import count

# Import jax and utils
from jax import numpy as jnp
import jax.random as jr
from jaxtyping import Float, Array
from typing import Callable, NamedTuple, Tuple, Optional, Union
from jaxtyping import Array, Float

# Additional, custom codebase
sys.path.append("../..")
sys.path.append("../../..")

# Import dynamax
from dynamax.parameters import ParameterProperties
# import discrete-time filters
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother, unscented_kalman_smoother

# Our own custom src codebase
# continuous-discrete nonlinear Gaussian SSM codebase
from continuous_discrete_nonlinear_gaussian_ssm import ContDiscreteNonlinearGaussianSSM
from continuous_discrete_nonlinear_gaussian_ssm import cdnlgssm_filter, cdnlgssm_forecast
# Load models
from continuous_discrete_nonlinear_gaussian_ssm.models import *

# Plotting
import matplotlib
%matplotlib inline
# Our own custom plotting codebase
from utils.plotting_utils import *
from lorenz_plotting import *
# Feel free to change the default figure size
#matplotlib.rcParams['figure.figsize'] = [16, 9]

# For pretty print of ndarrays
jnp.set_printoptions(formatter={"float_kind": "{:.2f}".format})

In [2]:
# Root-mean-square error helper
def compute_rmse(y, y_est):
    """Return sqrt(sum((y - y_est)**2) / len(y)).

    The squared error is summed over every entry, while the normaliser is
    len(y), i.e. the size of the leading axis of `y`.
    """
    residual = y - y_est
    total_squared_error = jnp.sum(jnp.square(residual))
    return jnp.sqrt(total_squared_error / len(y))

# Report the RMSE of an estimate next to the measurement-noise
# standard deviation, so the two can be compared at a glance
def compute_and_print_rmse_comparison(y, y_est, R, est_type=""):
    """Print the RMSE of `y_est` against `y`, then sqrt(R).

    Args:
        y: reference values.
        y_est: estimated values, compared against `y` via compute_rmse.
        R: measurement-noise variance; its square root is printed with two
            decimals, so it must be a scalar (or 0-d array).
        est_type: label of the estimator, inserted in the printed text.
    """
    rmse_label = f"The RMSE of the {est_type} estimate is"
    rmse_est = compute_rmse(y, y_est)
    print(f"{rmse_label:<40}: {rmse_est:.2f}")

    noise_label = "The std of measurement noise is"
    noise_std = jnp.sqrt(R)
    print(f"{noise_label:<40}: {noise_std:.2f}")

## True model set up

We generate data from a Lorenz 63 system, from dynamics with the following stochastic differential equations:

\begin{align*}
\frac{d x}{d t} &= a(y-x) + \sigma w_x(t) \\
\frac{d y}{d t} &= x(b-z) - y + \sigma w_y(t) \\
\frac{d z}{d t} &= xy - cz + \sigma w_z(t),
\end{align*}

With parameters $a=10, b=28, c=8/3$, the system gives rise to chaotic behavior, and we choose $\sigma=0.1$ for light diffusion.

To generate data, we numerically approximate random path solutions to this SDE using Heun's method (i.e. improved Euler), as implemented in [Diffrax](https://docs.kidger.site/diffrax/api/solvers/sde_solvers/).


We assume the observation model is
\begin{align*}
y(t) &= H x(t) + r(t) \\
r(t) &\sim N(0,R),
\end{align*}
where we choose $R=I$. 

Namely, we impose partial observability with $H=[1, 0, 0]$ (only the first state coordinate is observed), with noisy observations sampled at irregular time intervals.

### True, data-generating model definition

In [3]:
## Main settings
# Dimension of the latent state (the three Lorenz 63 coordinates)
state_dim = 3
# Dimension of each observation
emission_dim = 1

# Custom drift model for the Lorenz 63 system, built on LearnableFunction
class lorenz63_drift(LearnableFunction):
    """Lorenz 63 drift; `params` holds its three coefficients, indexed 0..2 in `f`."""

    params: Union[Float[Array, "state_dim"], ParameterProperties]

    def f(self, x, u=None, t=None):
        """Evaluate the drift at state `x`; `u` and `t` are accepted but not used."""
        a, b, c = self.params[0], self.params[1], self.params[2]
        dx = a * (x[1] - x[0])
        dy = b * x[0] - x[1] - x[0] * x[2]
        dz = -c * x[2] + x[0] * x[1]
        return jnp.array([dx, dy, dz])

# Every entry below is a {"params": ..., "props": ...} pair: "params" carries
# the value, "props" carries the matching ParameterProperties.

# True drift: Lorenz 63 coefficients, in the order indexed by lorenz63_drift.f
true_l63_drift_params = jnp.array([10.0, 28.0, 8 / 3])
true_drift = {
    "params": lorenz63_drift(params=true_l63_drift_params),
    "props": lorenz63_drift(params=ParameterProperties()),
}

# True diffusion covariance: identity, with the RealToPSDBijector constrainer
true_diffusion_cov = {
    "params": LearnableMatrix(params=jnp.eye(state_dim)),
    "props": LearnableMatrix(
        params=ParameterProperties(constrainer=RealToPSDBijector())
    ),
}

# True diffusion coefficient: a scaled identity
true_diffusion_coefficient_param = 0.1
true_diffusion_coefficient = {
    "params": LearnableMatrix(
        params=true_diffusion_coefficient_param * jnp.eye(state_dim)
    ),
    "props": LearnableMatrix(params=ParameterProperties()),
}

# True emission function: linear, zero bias.
# Partial observability: H selects the first state coordinate only.
H = jnp.array([[1.0, 0.0, 0.0]])
true_emission = {
    "params": LearnableLinear(weights=H, bias=jnp.zeros(emission_dim)),
    "props": LearnableLinear(
        weights=ParameterProperties(), bias=ParameterProperties()
    ),
}

# True emission covariance: identity, with the RealToPSDBijector constrainer
R = jnp.eye(emission_dim)
true_emission_cov = {
    "params": LearnableMatrix(params=R),
    "props": LearnableMatrix(
        params=ParameterProperties(constrainer=RealToPSDBijector())
    ),
}

# True initial state mean: zeros
true_initial_mean = {
    "params": LearnableVector(params=jnp.zeros(state_dim)),
    "props": LearnableVector(params=ParameterProperties()),
}

# True initial state covariance: a scaled identity
true_initial_cov_param = 10.0
true_initial_cov = {
    "params": LearnableMatrix(params=true_initial_cov_param * jnp.eye(state_dim)),
    "props": LearnableMatrix(
        params=ParameterProperties(constrainer=RealToPSDBijector())
    ),
}

# Gather everything under the keyword names used when initializing a model
all_true_params = dict(
    initial_mean=true_initial_mean,
    initial_cov=true_initial_cov,
    dynamics_drift=true_drift,
    dynamics_diffusion_coefficient=true_diffusion_coefficient,
    dynamics_diffusion_cov=true_diffusion_cov,
    dynamics_approx_order=2.,  # Check on this later
    emission_function=true_emission,
    emission_cov=true_emission_cov,
)

# Simulation 1: Fast sample rate and modest initial state covariance

### Simulation set-up

In [23]:
# Simulation parameters
T_total = 50
num_timesteps_total = 5000

# Measurement times are drawn uniformly at random from [0, T_total].
# jnp.unique flattens, sorts and de-duplicates in a single call, so no
# separate Python-level sort is needed; [:, None] restores the column shape.
sim1_t = jnp.unique(
    jr.uniform(
        jr.PRNGKey(0),
        (num_timesteps_total, 1),
        minval=0,
        maxval=T_total
    )
)[:, None]

# Separate filtering and forecasting time points:
# times up to T_filter are filtered, later times are forecast.
# Masking rows with a 1-D boolean keeps the (n, 1) column shape.
T_filter = 40
sim1_t_filter = sim1_t[sim1_t[:, 0] <= T_filter]
sim1_t_forecast = sim1_t[sim1_t[:, 0] > T_filter]

# Count number of time points
sim1_num_timesteps = len(sim1_t)
sim1_num_timesteps_filter = len(sim1_t_filter)
sim1_num_timesteps_forecast = len(sim1_t_forecast)

# Set up seed for simulation: an endless stream PRNGKey(0), PRNGKey(1), ...
keys = map(jr.PRNGKey, count())

### Model creation: object instantiation, with modest initial state covariance

In [5]:
# Initial state covariance for Simulation 1: (sim1_state_sd ** 2) * identity
sim1_state_sd = 10.0
sim1_initial_cov_param = {
    "params": LearnableMatrix(params=sim1_state_sd**2 * jnp.eye(state_dim)),
    "props": LearnableMatrix(
        params=ParameterProperties(constrainer=RealToPSDBijector())
    ),
}

# Same entries as the true model, except for the initial covariance
all_sim1_params = dict(
    initial_mean=true_initial_mean,
    initial_cov=sim1_initial_cov_param,
    dynamics_drift=true_drift,
    dynamics_diffusion_coefficient=true_diffusion_coefficient,
    dynamics_diffusion_cov=true_diffusion_cov,
    dynamics_approx_order=2.,  # Check on this later
    emission_function=true_emission,
    emission_cov=true_emission_cov,
)

# Create the CD-NLGSSM model and initialize its parameters
# (the second value returned by initialize is not needed here)
sim1_model = ContDiscreteNonlinearGaussianSSM(state_dim, emission_dim)
sim1_params, _ = sim1_model.initialize(next(keys), **all_sim1_params)

### Simulate data: sample from model

In [None]:
# Draw latent states and noisy emissions from the model at the times in sim1_t.
# transition_type="path" solves the dynamics SDE along a sampled path.
# NOTE(review): the original inline note said this "uses the Euler-Maruyama
# method", while the notebook text mentions Heun's method -- confirm which applies.
sim1_states, sim1_emissions = sim1_model.sample(
    sim1_params, next(keys), sim1_num_timesteps, sim1_t, transition_type="path"
)

### Plot generated data

In [None]:
# Plot the sampled states and noisy emissions over the whole time grid
plot_advanced(
    time_grid_all=sim1_t, time_grid_filter=sim1_t_filter, time_grid_forecast=sim1_t_forecast,
    true_states=sim1_states, true_emissions_noisy=sim1_emissions,
    emission_function=sim1_params.emissions.emission_function,
)

### State estimation with Extended Kalman Filter

In [8]:
# Filter: EKF
from continuous_discrete_nonlinear_gaussian_ssm import EKFHyperParams

In [9]:
# Run the EKF on the filtering window only:
# the leading sim1_num_timesteps_filter emissions, observed at sim1_t_filter
sim1_ekf_filtered = cdnlgssm_filter(
    sim1_params, sim1_emissions[:sim1_num_timesteps_filter], sim1_t_filter,
    hyperparams=EKFHyperParams(),
)


In [None]:
# Overlay the EKF filtered means and covariances on the true states and emissions
plot_advanced(
    time_grid_all=sim1_t,
    time_grid_filter=sim1_t_filter,
    time_grid_forecast=sim1_t_forecast,
    true_states=sim1_states,
    true_emissions_noisy=sim1_emissions,
    emission_function=sim1_params.emissions.emission_function,
    model_filtered_states=sim1_ekf_filtered.filtered_means,
    model_filtered_covariances=sim1_ekf_filtered.filtered_covariances,
    t_start=None,
    t_end=None,
)

In [None]:
# RMSE between the true and the EKF-filtered last state coordinate (column -1),
# over the filtering window; printed next to sqrt(R) with R = 1
compute_and_print_rmse_comparison(
    y=sim1_states[:sim1_num_timesteps_filter, -1],
    y_est=sim1_ekf_filtered.filtered_means[:, -1],
    R=1,
    est_type="EKF",
)

In [None]:
# EKF forecasting, distributional version:
# start from a Gaussian built from the last EKF filtered mean and covariance
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
ekf_init_forecast = MVN(
    sim1_ekf_filtered.filtered_means[-1],
    sim1_ekf_filtered.filtered_covariances[-1],
)

# Show the mean and covariance the forecast starts from
print(ekf_init_forecast.mean())
print(ekf_init_forecast.covariance())



In [13]:
# EKF forecast over sim1_t_forecast, started from the last filtered distribution
sim1_ekf_forecast = cdnlgssm_forecast(
    params=sim1_params, init_forecast=ekf_init_forecast,
    t_emissions=sim1_t_forecast, hyperparams=EKFHyperParams(),
)

In [None]:
# Display the forecast time grid
sim1_t_forecast

In [None]:
# Display the shapes of the forecast time grid and of the forecasted state means
sim1_t_forecast.shape, sim1_ekf_forecast.forecasted_state_means.shape

In [None]:
# Plot truth, EKF filtered estimates and EKF forecast twice:
# first with no time limits, then restricted to [T_filter, T_total]
for t_start, t_end in ((None, None), (T_filter, T_total)):
    plot_advanced(
        time_grid_all=sim1_t,
        time_grid_filter=sim1_t_filter,
        time_grid_forecast=sim1_t_forecast,
        true_states=sim1_states,
        true_emissions_noisy=sim1_emissions,
        emission_function=sim1_params.emissions.emission_function,
        model_filtered_states=sim1_ekf_filtered.filtered_means,
        model_filtered_covariances=sim1_ekf_filtered.filtered_covariances,
        model_forecast_states=sim1_ekf_forecast.forecasted_state_means,
        model_forecast_covariances=sim1_ekf_forecast.forecasted_state_covariances,
        t_start=t_start,
        t_end=t_end,
    )

In [24]:
# EKF forecasting, path-based version:
# start from the last EKF filtered mean (a point, not a distribution)
init_path_forecast = sim1_ekf_filtered.filtered_means[-1]

In [25]:
# Path forecast over sim1_t_forecast from the last EKF filtered mean;
# no filter hyperparameters are passed, and a fresh PRNG key is supplied
sim1_forecast_ekf_path = cdnlgssm_forecast(
    params=sim1_params, init_forecast=init_path_forecast,
    t_emissions=sim1_t_forecast, hyperparams=None, key=next(keys),
)

In [None]:
# Plot truth, EKF filtered estimates and the forecasted path twice:
# first with no time limits, then restricted to [T_filter, T_total]
for t_start, t_end in ((None, None), (T_filter, T_total)):
    plot_advanced(
        time_grid_all=sim1_t,
        time_grid_filter=sim1_t_filter,
        time_grid_forecast=sim1_t_forecast,
        true_states=sim1_states,
        true_emissions_noisy=sim1_emissions,
        emission_function=sim1_params.emissions.emission_function,
        model_filtered_states=sim1_ekf_filtered.filtered_means,
        model_filtered_covariances=sim1_ekf_filtered.filtered_covariances,
        model_forecast_states=sim1_forecast_ekf_path.forecasted_state_path,
        t_start=t_start,
        t_end=t_end,
    )

### State estimation with Ensemble Kalman Filter

In [15]:
# Filter: EnKF
from continuous_discrete_nonlinear_gaussian_ssm import EnKFHyperParams

In [16]:
# Run the EnKF on the filtering window only:
# the leading sim1_num_timesteps_filter emissions, observed at sim1_t_filter
sim1_enkf_filtered = cdnlgssm_filter(
    sim1_params, sim1_emissions[:sim1_num_timesteps_filter], sim1_t_filter,
    hyperparams=EnKFHyperParams(),
)


In [None]:
# Plot the true states and emissions, and the EnKF estimates.
# The filtered means/covariances come from sim1_enkf_filtered, so this
# section shows the EnKF result rather than repeating the EKF one.
plot_advanced(
    time_grid_all=sim1_t,
    time_grid_filter=sim1_t_filter,
    time_grid_forecast=sim1_t_forecast,
    true_states=sim1_states,
    true_emissions_noisy=sim1_emissions,
    emission_function=sim1_params.emissions.emission_function,
    model_filtered_states=sim1_enkf_filtered.filtered_means,
    model_filtered_covariances=sim1_enkf_filtered.filtered_covariances,
    t_start=None,
    t_end=None,
)

In [None]:
# RMSE between the true and the EnKF-filtered last state coordinate (column -1),
# over the filtering window; printed next to sqrt(R) with R = 1.
# Uses the EnKF output and label, so the number reported belongs to this section.
compute_and_print_rmse_comparison(
    sim1_states[:sim1_num_timesteps_filter, -1],
    sim1_enkf_filtered.filtered_means[:,-1],
    1,
    "EnKF"
)

In [None]:
# EnKF forecasting:
# start from a Gaussian built from the last EnKF filtered mean and covariance
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
enkf_init_forecast = MVN(
    sim1_enkf_filtered.filtered_means[-1],
    sim1_enkf_filtered.filtered_covariances[-1],
)

# Show the mean and covariance the forecast starts from
print(enkf_init_forecast.mean())
print(enkf_init_forecast.covariance())



In [20]:
# EnKF forecast over sim1_t_forecast, started from the last filtered distribution
sim1_enkf_forecast = cdnlgssm_forecast(
    params=sim1_params, init_forecast=enkf_init_forecast,
    t_emissions=sim1_t_forecast, hyperparams=EnKFHyperParams(),
)

In [None]:
# Plot truth, EnKF filtered estimates and EnKF forecast twice:
# first with no time limits, then restricted to [T_filter, T_total]
for t_start, t_end in ((None, None), (T_filter, T_total)):
    plot_advanced(
        time_grid_all=sim1_t,
        time_grid_filter=sim1_t_filter,
        time_grid_forecast=sim1_t_forecast,
        true_states=sim1_states,
        true_emissions_noisy=sim1_emissions,
        emission_function=sim1_params.emissions.emission_function,
        model_filtered_states=sim1_enkf_filtered.filtered_means,
        model_filtered_covariances=sim1_enkf_filtered.filtered_covariances,
        model_forecast_states=sim1_enkf_forecast.forecasted_state_means,
        model_forecast_covariances=sim1_enkf_forecast.forecasted_state_covariances,
        t_start=t_start,
        t_end=t_end,
    )

### State estimation with Unscented Kalman Filter

In [22]:
# Filter: UKF
from continuous_discrete_nonlinear_gaussian_ssm import UKFHyperParams

In [23]:
# Execute the UKF filter.
# Only the leading sim1_num_timesteps_filter emissions are passed, so the
# emissions line up one-to-one with the time points in sim1_t_filter
# (same convention as the EKF and EnKF calls in this notebook).
sim1_ukf_filtered = cdnlgssm_filter(
    sim1_params,
    sim1_emissions[:sim1_num_timesteps_filter], # Filter based on the first part of the data
    sim1_t_filter,
    hyperparams=UKFHyperParams()
)


In [None]:
# Overlay the UKF filtered means and covariances on the true states and emissions
plot_advanced(
    time_grid_all=sim1_t,
    time_grid_filter=sim1_t_filter,
    time_grid_forecast=sim1_t_forecast,
    true_states=sim1_states,
    true_emissions_noisy=sim1_emissions,
    emission_function=sim1_params.emissions.emission_function,
    model_filtered_states=sim1_ukf_filtered.filtered_means,
    model_filtered_covariances=sim1_ukf_filtered.filtered_covariances,
    t_start=None,
    t_end=None,
)

In [None]:
# RMSE between the true and the UKF-filtered last state coordinate (column -1),
# over the filtering window; printed next to sqrt(R) with R = 1
compute_and_print_rmse_comparison(
    y=sim1_states[:sim1_num_timesteps_filter, -1],
    y_est=sim1_ukf_filtered.filtered_means[:, -1],
    R=1,
    est_type="UKF",
)

In [None]:
# UKF forecasting:
# start from a Gaussian built from the last UKF filtered mean and covariance
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
ukf_init_forecast = MVN(
    sim1_ukf_filtered.filtered_means[-1],
    sim1_ukf_filtered.filtered_covariances[-1],
)

# Show the mean and covariance the forecast starts from
print(ukf_init_forecast.mean())
print(ukf_init_forecast.covariance())



In [27]:
# UKF forecast over sim1_t_forecast, started from the last filtered distribution
sim1_ukf_forecast = cdnlgssm_forecast(
    params=sim1_params, init_forecast=ukf_init_forecast,
    t_emissions=sim1_t_forecast, hyperparams=UKFHyperParams(),
)

In [None]:
# Display the forecasted emission means of the UKF forecast
sim1_ukf_forecast.forecasted_emission_means

In [None]:
# Plot truth, UKF filtered estimates and UKF forecast twice:
# first with no time limits, then restricted to [T_filter, T_total]
for t_start, t_end in ((None, None), (T_filter, T_total)):
    plot_advanced(
        time_grid_all=sim1_t,
        time_grid_filter=sim1_t_filter,
        time_grid_forecast=sim1_t_forecast,
        true_states=sim1_states,
        true_emissions_noisy=sim1_emissions,
        emission_function=sim1_params.emissions.emission_function,
        model_filtered_states=sim1_ukf_filtered.filtered_means,
        model_filtered_covariances=sim1_ukf_filtered.filtered_covariances,
        model_forecast_states=sim1_ukf_forecast.forecasted_state_means,
        model_forecast_covariances=sim1_ukf_forecast.forecasted_state_covariances,
        t_start=t_start,
        t_end=t_end,
    )

# Simulation 2: Slow sample rate and larger initial state covariance

### Simulation set-up

In [23]:
# Simulation parameters
T = 40
num_timesteps = 400

# Measurement times are drawn uniformly at random from [0, T].
# jnp.unique flattens, sorts and de-duplicates in a single call, so no
# separate Python-level sort is needed; [:, None] restores the column shape.
sim2_t_emissions = jnp.unique(
    jr.uniform(
        jr.PRNGKey(0),
        (num_timesteps, 1),
        minval=0,
        maxval=T
    )
)[:, None]
sim2_num_timesteps = len(sim2_t_emissions)

# Set up seed for simulation: restart the stream PRNGKey(0), PRNGKey(1), ...
keys = map(jr.PRNGKey, count())

### Model creation: object instantiation, with larger initial state covariance

In [24]:
# Initial state covariance for Simulation 2: (sim2_state_sd ** 2) * identity
sim2_state_sd = 20.0
sim2_initial_cov_param = {
    "params": LearnableMatrix(
        params=sim2_state_sd**2 *jnp.eye(state_dim)
    ),
    "props": LearnableMatrix(
        params=ParameterProperties(
            constrainer=RealToPSDBijector()
        )
    ),
}

# Concatenate all parameters in dictionary, for later easy use
all_sim2_params = {
    'initial_mean': true_initial_mean,
    'initial_cov': sim2_initial_cov_param,
    'dynamics_drift': true_drift,
    'dynamics_diffusion_coefficient': true_diffusion_coefficient,
    'dynamics_diffusion_cov': true_diffusion_cov,
    'dynamics_approx_order': 2., # Check on this later
    'emission_function': true_emission,
    'emission_cov': true_emission_cov,
}

# Create CD-NLGSSM model
sim2_model = ContDiscreteNonlinearGaussianSSM(state_dim, emission_dim)
# Initialize from the Simulation 2 model and the Simulation 2 settings, so
# that sim2_initial_cov_param defined above is the covariance actually used.
sim2_params, _ = sim2_model.initialize(
    next(keys),
    **all_sim2_params
)

### Simulate data: sample from model

In [None]:
# Draw latent states and noisy emissions from the model at the times in sim2_t_emissions.
# transition_type="path" solves the dynamics SDE along a sampled path.
# NOTE(review): the original inline note said this "uses the Euler-Maruyama
# method", while the notebook text mentions Heun's method -- confirm which applies.
sim2_states, sim2_emissions = sim2_model.sample(
    sim2_params, next(keys), sim2_num_timesteps, sim2_t_emissions, transition_type="path"
)

### Plot generated data

In [None]:
# Plot the sampled states and noisy emissions against the observation times
plot_lorenz(sim2_t_emissions, sim2_states, sim2_emissions)

### State estimation with Extended Kalman Filter

In [27]:
# Filter: EKF
from continuous_discrete_nonlinear_gaussian_ssm import EKFHyperParams

In [28]:
# Run the EKF on all Simulation 2 emissions and their observation times
sim2_ekf_filtered = cdnlgssm_filter(
    sim2_params, sim2_emissions, sim2_t_emissions, hyperparams=EKFHyperParams()
)


In [None]:
# Plot the true states and emissions, and the EKF estimates.
# NOTE(review): this positional / x_est / x_unc / est_type call shape matches
# none of the keyword arguments plot_advanced is given elsewhere in this
# notebook (time_grid_all=, model_filtered_states=, ...); it follows the
# plot_lorenz usage instead, so plot_lorenz is called here. Confirm that
# plot_lorenz accepts x_est, x_unc and est_type.
plot_lorenz(
    sim2_t_emissions,
    sim2_states,
    sim2_emissions,
    x_est=sim2_ekf_filtered.filtered_means,
    x_unc=sim2_ekf_filtered.filtered_covariances,
    est_type="EKF"
)

In [None]:
# RMSE between the true and the EKF-filtered last state coordinate (column -1);
# printed next to sqrt(R) with R = 1
compute_and_print_rmse_comparison(
    y=sim2_states[:, -1],
    y_est=sim2_ekf_filtered.filtered_means[:, -1],
    R=1,
    est_type="EKF",
)

### State estimation with Ensemble Kalman Filter

In [31]:
# Filter: EnKF
from continuous_discrete_nonlinear_gaussian_ssm import EnKFHyperParams

In [None]:
# Run the EnKF on all Simulation 2 emissions and their observation times
sim2_enkf_filtered = cdnlgssm_filter(
    sim2_params, sim2_emissions, sim2_t_emissions, hyperparams=EnKFHyperParams()
)


In [None]:
# Plot the true states and emissions, and the EnKF estimates.
# NOTE(review): this positional / x_est / x_unc / est_type call shape matches
# none of the keyword arguments plot_advanced is given elsewhere in this
# notebook (time_grid_all=, model_filtered_states=, ...); it follows the
# plot_lorenz usage instead, so plot_lorenz is called here. Confirm that
# plot_lorenz accepts x_est, x_unc and est_type.
plot_lorenz(
    sim2_t_emissions,
    sim2_states,
    sim2_emissions,
    x_est=sim2_enkf_filtered.filtered_means,
    x_unc=sim2_enkf_filtered.filtered_covariances,
    est_type="EnKF"
)

In [None]:
# RMSE between the true and the EnKF-filtered last state coordinate (column -1);
# printed next to sqrt(R) with R = 1
compute_and_print_rmse_comparison(
    y=sim2_states[:, -1],
    y_est=sim2_enkf_filtered.filtered_means[:, -1],
    R=1,
    est_type="EnKF",
)

### State estimation with Unscented Kalman Filter

In [32]:
# Filter: UKF
from continuous_discrete_nonlinear_gaussian_ssm import UKFHyperParams

In [33]:
# Run the UKF on all Simulation 2 emissions and their observation times
sim2_ukf_filtered = cdnlgssm_filter(
    sim2_params, sim2_emissions, sim2_t_emissions, hyperparams=UKFHyperParams()
)


In [None]:
# Plot the true states and emissions, and the UKF estimates.
# NOTE(review): this positional / x_est / x_unc / est_type call shape matches
# none of the keyword arguments plot_advanced is given elsewhere in this
# notebook (time_grid_all=, model_filtered_states=, ...); it follows the
# plot_lorenz usage instead, so plot_lorenz is called here. Confirm that
# plot_lorenz accepts x_est, x_unc and est_type.
plot_lorenz(
    sim2_t_emissions,
    sim2_states,
    sim2_emissions,
    x_est=sim2_ukf_filtered.filtered_means,
    x_unc=sim2_ukf_filtered.filtered_covariances,
    est_type="UKF"
)

In [None]:
# RMSE between the true and the UKF-filtered last state coordinate (column -1);
# printed next to sqrt(R) with R = 1
compute_and_print_rmse_comparison(
    y=sim2_states[:, -1],
    y_est=sim2_ukf_filtered.filtered_means[:, -1],
    R=1,
    est_type="UKF",
)