# In [None]:
"""
Implementation of the Bootstrap Filter for discrete-time systems
**This implementation considers the case of multivariate normals**


"""
import jax.numpy as jnp
from jax import random, lax

import chex

from jax.scipy import stats
from jsl.nlds.base import NLDS


# TODO: Extend beyond multivariate-normal transition and observation noise
def filter(params: NLDS,
           key: chex.PRNGKey,
           init_state: chex.Array,
           sample_obs: chex.Array,
           nsamples: int = 2000,
           Vinit: chex.Array = None):
    """
    Run a bootstrap particle filter with multivariate-normal transition
    and observation noise.

    Parameters
    ----------
    params: NLDS
        System definition. ``fz`` and ``fx`` are applied to the whole
        particle cloud at once, i.e. to an array(nsamples, state_size);
        ``Qz`` / ``Rx`` return the transition / observation covariances.
    key: PRNGKey
        Random key driving initialisation, propagation and resampling.
    init_state: array(state_size,)
        Initial state estimate (mean of the initial particle cloud).
    sample_obs: array(nsteps, obs_size)
        Sequence of observations to filter, one row per time step.
    nsamples: int
        Number of particles.
    Vinit: array(state_size, state_size) or None
        Covariance of the initial particle cloud. Defaults to
        ``params.Qz(init_state)``.

    Returns
    -------
    array(nsteps, state_size)
        Filtered estimate of the latent-state mean at every step.
    """
    fx, fz = params.fx, params.fz
    Q, R = params.Qz, params.Rx

    key, key_init = random.split(key, 2)
    V = Q(init_state) if Vinit is None else Vinit
    particles_init = random.multivariate_normal(key_init, init_state, V,
                                                shape=(nsamples,))
    # Candidate indices for resampling; invariant across steps.
    indices = jnp.arange(nsamples)

    def _filter_step(state, obs_t):
        zt_rvs, key_t = state

        key_t, key_reindex, key_next = random.split(key_t, 3)
        # 1. Draw new points from the dynamic model
        zt_rvs = random.multivariate_normal(key_t, fz(zt_rvs), Q(zt_rvs))

        # 2. Unnormalised weights, computed in log space. Evaluating the pdf
        #    directly underflows to exactly 0 for every particle when the
        #    observation is far from the cloud, which leaves resampling with
        #    an all-zero probability vector. Subtracting the max log-weight
        #    keeps the most likely particle at weight 1.
        xt_rvs = fx(zt_rvs)
        log_weights = stats.multivariate_normal.logpdf(obs_t, xt_rvs,
                                                       R(zt_rvs, obs_t))
        weights_t = jnp.exp(log_weights - jnp.max(log_weights))
        weights_t = weights_t / weights_t.sum()

        # 3. Multinomial resampling
        pi = random.choice(key_reindex, indices,
                           p=weights_t, shape=(nsamples,))
        zt_rvs = zt_rvs[pi, ...]

        # 4. After resampling every particle carries weight 1 / nsamples,
        #    so the latent-state estimate is the plain particle average.
        mu_t = zt_rvs.mean(axis=0)

        return (zt_rvs, key_next), mu_t

    _, mu_hist = lax.scan(_filter_step, (particles_init, key), sample_obs)

    return mu_hist

# In [None]:
# Library of nonlinear dynamical systems
# Usage: Every discrete xKF class inherits from NLDS.
# There are two ways to use this library in the discrete case:
# 1) Explicitly initialize a discrete NLDS object with the desired parameters,
#    then pass it to the xKF class of your choice.
# 2) Initialize the xKF object with the desired NLDS parameters using
#    the .from_base constructor.
# Way 1 is preferable whenever you want to use the same NLDS for multiple
# filtering processes. Way 2 is preferable whenever you want to use a single
# NLDS for a single filtering process.

# Author: Gerardo Durán-Martín (@gerdm)

import jax
from jax.random import split, multivariate_normal

import chex

from dataclasses import dataclass
from typing import Callable


@dataclass
class NLDS:
    """
    Base class for the nonlinear dynamical systems' module

    Parameters
    ----------
    fz: function
        Nonlinear state transition function
    fx: function
        Nonlinear observation function; may take extra per-step
        arguments (see ``sample``'s ``obs``)
    Q: array(state_size, state_size) or function
        State transition noise covariance, or a function of the state
        (plus any extra arguments) returning it
    R: array(obs_size, obs_size) or function
        Observation noise covariance, or a function of the state
        (plus any extra arguments) returning it
    alpha, beta, kappa: float
        Additional scalar hyperparameters stored on the object; the methods
        of this class do not use them (presumably consumed by subclasses —
        confirm against callers).
    d: int
        Additional integer parameter; not used by the methods of this class.
    """
    fz: Callable
    fx: Callable
    Q: chex.Array
    R: chex.Array
    alpha: float = 0.
    beta: float = 0.
    kappa: float = 0.
    d: int = 0

    def Qz(self, z, *args):
        """Return the state noise covariance, evaluating Q if it is callable."""
        if callable(self.Q):
            return self.Q(z, *args)
        return self.Q

    def Rx(self, x, *args):
        """Return the observation noise covariance, evaluating R if it is callable."""
        if callable(self.R):
            return self.R(x, *args)
        return self.R

    def __sample_step(self, input_vals, obs):
        """
        One ``lax.scan`` step of ``sample``: draw the next state and the
        observation generated from it.

        input_vals: (key, state_t) carry.
        obs: per-step extra arguments forwarded to ``fx`` and ``R``.
        """
        key, state_t = input_vals
        key_system, key_obs, key = split(key, 3)

        state_t = multivariate_normal(key_system, self.fz(state_t), self.Qz(state_t))
        obs_t = multivariate_normal(key_obs, self.fx(state_t, *obs), self.Rx(state_t, *obs))

        return (key, state_t), (state_t, obs_t)

    def sample(self, key, x0, nsteps, obs=None):
        """
        Sample discrete elements of a nonlinear system

        As a side effect, records ``self.state_size`` and ``self.obs_size``
        (also aliased as ``self.obs_t`` for backward compatibility).

        Parameters
        ----------
        key: jax.random.PRNGKey
        x0: array(state_size)
            Initial state of simulation
        nsteps: int
            Total number of steps to sample from the system
        obs: None, tuple of arrays
            Observed values to pass to fx and R; each array is scanned
            along its leading axis, which must have length nsteps
        Returns
        -------
        * array(nsteps, state_size)
            State-space values
        * array(nsteps, obs_size)
            Observed-space values
        """
        obs = () if obs is None else obs
        state_t = x0.copy()
        # Probe fx once to record the observation dimension. Pass the first
        # time slice of ``obs`` so fx receives the same extra arguments it
        # gets inside the scan (it may require them).
        obs_0 = jax.tree_util.tree_map(lambda leaf: leaf[0], obs)
        obs_t = self.fx(state_t, *obs_0)

        self.state_size, *_ = state_t.shape
        self.obs_size, *_ = obs_t.shape
        # Backward-compatible alias: earlier versions stored the observation
        # dimension under this name.
        self.obs_t = self.obs_size

        init_state = (key, state_t)
        _, hist = jax.lax.scan(self.__sample_step, init_state, obs, length=nsteps)

        return hist

# In [None]:
# Demo of the bootstrap filter under a
# nonlinear discrete system

import jax
from jsl.nlds.base import NLDS
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random


def plot_samples(sample_state, sample_obs, ax=None):
    """
    Plot a hidden trajectory together with its noisy observations.

    Parameters
    ----------
    sample_state: array(nsteps, 2)
        Hidden state trajectory; its columns are plotted as x and y, and
        the first point is highlighted in black.
    sample_obs: array(nsteps, 2)
        Noisy observations, drawn as green crosses.
    ax: matplotlib Axes or None
        Axes to draw on. When None, a new figure and axes are created.

    Returns
    -------
    matplotlib Figure that contains the axes drawn on.
    """
    # Honour a caller-supplied axes instead of always creating a new one.
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    ax.plot(*sample_state.T, label="state space")
    ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+")
    ax.scatter(*sample_state[0], c="black", zorder=3)
    ax.legend()
    ax.set_title("Noisy observations from hidden trajectory")
    # Act on this axes explicitly rather than on pyplot's "current" axes.
    ax.axis("equal")
    return fig


def plot_inference(sample_obs, mean_hist, ax=None):
    """
    Plot the observations together with the filtered mean trajectory.

    Parameters
    ----------
    sample_obs: array(nsteps, 2)
        Noisy observations, drawn as green crosses.
    mean_hist: array(nsteps, 2)
        Filtered state means; columns are plotted as x and y, and the
        first point is highlighted in black.
    ax: matplotlib Axes or None
        Axes to draw on. When None (the default, matching previous
        behaviour), a new figure and axes are created.

    Returns
    -------
    matplotlib Figure that contains the axes drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    ax.scatter(*sample_obs.T, marker="+", color="tab:green", s=60)
    ax.plot(*mean_hist.T, c="tab:orange", label="filtered")
    ax.scatter(*mean_hist[0], c="black", zorder=3)
    # Use axes methods so the call does not depend on pyplot's current axes.
    ax.legend()
    ax.axis("equal")
    return fig

def main():
    """
    Simulate a 2-D nonlinear system, run the bootstrap filter on the
    simulated observations and build the demo figures.

    Returns
    -------
    dict
        Maps figure name ("nlds2d_bootstrap", "nlds2d_data") to the
        matplotlib figure.
    """
    def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])
    def fx(x): return x

    dt = 0.4
    nsteps = 100
    # Initial state vector
    x0 = jnp.array([1.5, 0.0])
    # State noise
    Qt = jnp.eye(2) * 0.001
    # Observed noise
    Rt = jnp.eye(2) * 0.05

    key = random.PRNGKey(314)
    # JAX keys must not be reused: give simulation and filtering their own
    # independent random streams.
    key_sample, key_filter = random.split(key)

    model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt)
    sample_state, sample_obs = model.sample(key_sample, x0, nsteps)

    n_particles = 3_000
    # The filter propagates the whole particle cloud at once, so fz is
    # vectorised over the particle axis while dt is shared.
    fz_vec = jax.vmap(fz, in_axes=(0, None))
    particle_filter = NLDS(lambda x: fz_vec(x, dt), fx, Qt, Rt)
    pf_mean = filter(particle_filter, key_filter, x0, sample_obs, n_particles)

    dict_figures = {}
    fig_bootstrap = plot_inference(sample_obs, pf_mean)
    dict_figures["nlds2d_bootstrap"] = fig_bootstrap

    fig_data = plot_samples(sample_state, sample_obs)
    dict_figures["nlds2d_data"] = fig_data

    return dict_figures

if __name__ == "__main__":
    from jsl.demos.plot_utils import savefig

    # Hide the right and top spines on every figure created below.
    plt.rcParams.update({
        "axes.spines.right": False,
        "axes.spines.top": False,
    })
    figures = main()
    savefig(figures)
    plt.show()