In [1]:
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from parsmooth._base import MVNStandard, FunctionalModel
from parsmooth.linearization import extended, extended_hessian
from parsmooth.sequential._filtering import filtering
from parsmooth.sequential._filtering_Newton import filtering as newton_filtering
from bearing_data import get_data, make_parameters

In [7]:
s1 = jnp.array([-1.5, 0.5])  # First sensor location
s2 = jnp.array([1., 1.])  # Second sensor location
r = 0.5  # Observation noise (stddev)
x0 = jnp.array([0.1, 0.2, 1, 0])  # initial true location
dt = 0.01  # discretization time step
qc = 0.01  # discretization noise
qw = 0.1  # discretization noise

T = 100
_, true_states, ys = get_data(x0, dt, r, T, s1, s2)
Q, R, observation_function, transition_function = make_parameters(qc, qw, r, dt, s1, s2)

chol_Q = jnp.linalg.cholesky(Q)
chol_R = jnp.linalg.cholesky(R)

m0 = jnp.array([-4., -1., 2., 7., 3.])
chol_P0 = jnp.eye(5)
P0 = jnp.eye(5)

init = MVNStandard(m0, P0)

initial_states =  MVNStandard(jnp.repeat(jnp.array([[0., 0., 6., 4., 2.]]),T + 1, axis=0),
                                                     jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T + 1, axis=0))

transition_model = FunctionalModel(transition_function, MVNStandard(jnp.zeros((5,)), Q))
observation_model = FunctionalModel(observation_function, MVNStandard(jnp.zeros((2,)), R))

## Filtering

In [8]:
filtered_states_newton = newton_filtering(ys, init, transition_model, observation_model, extended_hessian, None)


In [4]:
filtered_states_newton.mean

DeviceArray([[-4., -1.,  2.,  7.,  3.],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan],


In [9]:
#Filtering
filtered_states = filtering(ys, init, transition_model, observation_model, extended, None)

In [6]:
filtered_states.mean

DeviceArray([[-4.00000000e+00, -1.00000000e+00,  2.00000000e+00,
               7.00000000e+00,  3.00000000e+00],
             [-3.07265175e+00, -2.44115914e+00,  2.21793626e+00,
               6.92160731e+00,  3.00047690e+00],
             [-1.50820964e+00, -2.86163659e+00,  2.45882339e+00,
               6.83468557e+00,  3.00340659e+00],
             [-5.48232115e-01, -2.38805164e+00,  2.69491171e+00,
               6.76604894e+00,  3.00645173e+00],
             [-7.32301385e-01, -1.94530646e+00,  2.88043015e+00,
               6.70797080e+00,  3.00222677e+00],
             [-3.36842418e-01, -1.50132977e+00,  3.10614497e+00,
               6.63860412e+00,  3.00522196e+00],
             [-4.06828549e-01, -1.35985233e+00,  3.28960280e+00,
               6.55367583e+00,  3.00085316e+00],
             [ 1.01957158e-01, -8.42702226e-01,  3.55875161e+00,
               6.48437068e+00,  3.01530255e+00],
             [ 8.32971509e-03, -5.17406640e-01,  3.71260084e+00,
               6.423820