In [None]:
import sys
# Add the parent directory to the module search path (presumably so the
# project-local modules imported below resolve when the notebook is run
# from its own folder -- confirm against the repository layout).
sys.path.append('..//')

import jax
import jax.numpy as jnp

# Use 64-bit floating point and pin JAX to the CPU backend.
# These flags are set before any arrays are created in this notebook.
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')

# Gaussian container and model wrapper used to describe the state-space model.
from newton_smoothers.base import MVNStandard, FunctionalModel
# Approximation schemes handed to the smoothers further down.
from newton_smoothers.approximation import extended, quadratize

# The two trust-region iterated smoothers compared in this notebook.
from newton_smoothers import trust_region_iterated_recursive_newton_smoother
from newton_smoothers import trust_region_iterated_recursive_gauss_newton_smoother

# Data generator and model-parameter factory for the bearing example.
from bearing_data import get_data, make_parameters

In [None]:
%%capture

# --- Sensor geometry and measurement noise --------------------------------
s1 = jnp.array([-1.5, 0.5])  # First sensor location
s2 = jnp.array([1., 1.])  # Second sensor location
r = 0.5  # Observation noise (stddev)
x0 = jnp.array([0.1, 0.2, 1, 0])  # initial true state handed to get_data

# --- Discretization and process-noise parameters --------------------------
dt = 0.01  # discretization time step
qc = 0.01  # first process-noise parameter passed to make_parameters
qw = 0.1  # second process-noise parameter passed to make_parameters

T = 500  # number of time steps; the nominal trajectory holds T + 1 entries
nx, ny = 5, 2  # state and observation dimensions

# Simulate a trajectory and its observations (fixed seed for repeatability),
# then build the noise covariances and the transition/observation functions.
_, true_states, observations = get_data(x0, dt, r, T, s1, s2, random_state=7)
Q, R, trans_fcn, obsrv_fcn, _, _ = make_parameters(qc, qw, r, dt, s1, s2)

# Wrap each function together with its zero-mean additive Gaussian noise.
trans_mdl = FunctionalModel(trans_fcn, MVNStandard(jnp.zeros((nx,)), Q))
obsrv_mdl = FunctionalModel(obsrv_fcn, MVNStandard(jnp.zeros((ny,)), R))

# Prior over the initial state.
init_dist = MVNStandard(
    mean=jnp.array([-1., -1., 0., 0., 0.]),
    cov=jnp.eye(nx)
)

# Nominal trajectory used to initialise the smoothers: zero means and identity
# covariances everywhere, except that the first entry is set to the prior.
# JAX arrays are immutable and `.at[...].set(...)` RETURNS the updated copy,
# so the update has to be applied before the arrays are stored in MVNStandard
# (calling it afterwards and discarding the result changes nothing).
nominal_mean = jnp.zeros((T + 1, nx)).at[0].set(init_dist.mean)
nominal_cov = (
    jnp.repeat(jnp.eye(nx).reshape(1, nx, nx), T + 1, axis=0)
    .at[0]
    .set(init_dist.cov)
)

init_nominal = MVNStandard(
    mean=nominal_mean,
    cov=nominal_cov,
)

In [None]:
# Positional inputs shared by both smoothers: initial nominal trajectory,
# observations, prior, transition model and observation model.
smoother_args = (
    init_nominal,
    observations,
    init_dist,
    trans_mdl,
    obsrv_mdl,
)

# Trust-region iterated recursive Newton smoother, using `quadratize`.
(
    recursive_newton_smoothed,
    recursive_newton_costs,
) = trust_region_iterated_recursive_newton_smoother(
    *smoother_args,
    quadratize,
    nb_iter=25,
)

# Trust-region iterated recursive Gauss-Newton smoother, using `extended`.
(
    recursive_gauss_newton_smoothed,
    recursive_gauss_newton_costs,
) = trust_region_iterated_recursive_gauss_newton_smoother(
    *smoother_args,
    extended,
    nb_iter=25,
)

In [None]:
from matplotlib import pyplot as plt

# Draw both cost histories on one explicit Axes object.
fig, ax = plt.subplots()

cost_traces = (
    (recursive_newton_costs, "*--", "newton"),
    (recursive_gauss_newton_costs, ':', "gauss-newton"),
)
for costs, fmt, name in cost_traces:
    ax.plot(costs, fmt, label=name)

# Symmetric-log y axis: linear within +/- 0.01, logarithmic beyond it.
ax.set_yscale('symlog', linthresh=0.01)
ax.legend()
ax.grid()
ax.set_ylabel("log posterior")
ax.set_xlabel("iteration")