In [1]:
import warnings

import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt

from parsmooth._base import MVNStandard, FunctionalModel
from parsmooth.linearization import extended, extended_hessian
from parsmooth.methods import iterated_smoothing
from parsmooth.sequential._filtering import filtering
from parsmooth.sequential._filtering_Newton import filtering as newton_filtering

from bearing_data import get_data, make_parameters

In [8]:
# --- Problem setup: two bearing sensors, simulated data, filter models ---
s1 = jnp.array([-1.5, 0.5])  # first sensor location
s2 = jnp.array([1., 1.])  # second sensor location
r = 0.5  # observation noise standard deviation
# NOTE(review): x0 has 4 entries while the filter state (m0 below) has 5;
# presumably get_data handles the extra component — confirm in bearing_data.
x0 = jnp.array([0.1, 0.2, 1, 0])  # initial true state passed to get_data
dt = 0.01  # discretization time step
qc = 0.01  # process-noise parameter for make_parameters (TODO: document its role)
qw = 0.1  # second process-noise parameter for make_parameters (TODO: document its role)

T = 2  # horizon passed to get_data; the filtered outputs have T + 1 rows (prior + T steps)
_, true_states, ys = get_data(x0, dt, r, T, s1, s2)
Q, R, observation_function, transition_function = make_parameters(qc, qw, r, dt, s1, s2)

chol_Q = jnp.linalg.cholesky(Q)
chol_R = jnp.linalg.cholesky(R)

# Prior on the initial state.
m0 = jnp.array([-4., -1., 2., 7., 3.])

# Derive dimensions once instead of hard-coding 5 and 2 throughout the cell.
dim_x = m0.shape[0]  # state dimension
dim_y = R.shape[0]  # observation dimension

chol_P0 = jnp.eye(dim_x)
P0 = chol_P0 @ chol_P0.T  # keep P0 consistent with its Cholesky factor

init = MVNStandard(m0, P0)

# Constant initial guess for every time step (T + 1 states including the prior).
initial_states = MVNStandard(
    jnp.tile(jnp.array([-1., -1., 6., 4., 2.]), (T + 1, 1)),
    jnp.tile(jnp.eye(dim_x), (T + 1, 1, 1)),
)

# Additive zero-mean Gaussian noise models with covariances Q and R.
transition_model = FunctionalModel(transition_function, MVNStandard(jnp.zeros((dim_x,)), Q))
observation_model = FunctionalModel(observation_function, MVNStandard(jnp.zeros((dim_y,)), R))

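## Filtering

Run the Newton-type filter (`newton_filtering` with `extended_hessian` linearization) on the simulated observations `ys`.

**FIXME:** the filtered means below are NaN at every step after the prior (row 0 is `m0`); the cause is not yet identified. The execution counts are also out of order, so re-run with *Restart & Run All* before trusting these outputs.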

In [7]:
# Newton-type filter with `extended_hessian` linearization. The trailing positional
# `None` is kept as in the original call — presumably the nominal trajectory;
# confirm against parsmooth.sequential._filtering_Newton.
# NOTE(review): this cell has no print statement, so any `[nan ...]` lines shown under it
# presumably come from debug prints inside newton_filtering — confirm.
# NOTE(review): JAX is running in float32 here (see the dtype in the `.mean` output);
# jax.config.update("jax_enable_x64", True) may be worth trying while chasing the NaNs — unverified.
filtered_states_newton = newton_filtering(ys, init, transition_model, observation_model, extended_hessian, None)

# Surface non-finite estimates explicitly instead of letting NaNs flow silently into
# later cells. Warn (not raise) so "Restart & Run All" still completes.
newton_nonfinite_steps = np.flatnonzero(
    ~np.isfinite(np.asarray(filtered_states_newton.mean)).all(axis=1)
)
if newton_nonfinite_steps.size:
    warnings.warn(
        f"newton_filtering produced non-finite means at time steps {newton_nonfinite_steps.tolist()}"
    )


[nan nan nan nan nan]
[nan nan nan nan nan]


In [4]:
# Filtered means from the Newton filter, one row per time step (row 0 equals the prior mean m0).
# NOTE(review): the saved output below is stale — it was produced by In[4], before the
# filtering cell's In[7] run; re-run this cell after the filtering cell.
filtered_states_newton.mean

DeviceArray([[-4., -1.,  2.,  7.,  3.],
             [nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan]], dtype=float32)

In [5]:
# Baseline: `filtering` with `extended` linearization, for comparison with the Newton run above.
# filtered_states = filtering(ys, init, transition_model, observation_model, extended, None)

In [6]:
# filtered_states.mean