In [1]:
import sys
sys.path.append('..//')

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from functools import partial

from parsmooth._base import MVNStandard, FunctionalModelX, MVNSqrt
from parsmooth.linearization import cubature, extended, gauss_hermite
from parsmooth.methods import iterated_smoothing, filtering, smoothing
from parsmooth.linearization._cubature import _get_sigma_points
from parsmooth.linearization._sigma_points import linearize_functional
from parsmooth._utils import cholesky_update_many, tria

from parsmooth.parallel._filtering import _standard_associative_params
from parsmooth.parallel._filtering import _sqrt_associative_params

from parsmooth.parallel._operators import standard_filtering_operator
from parsmooth.parallel._operators import sqrt_filtering_operator
from parsmooth.parallel._operators import standard_smoothing_operator
from parsmooth.parallel._operators import sqrt_smoothing_operator

from parsmooth.parallel._smoothing import _associative_params

from parsmooth.parallel._operators import standard_smoothing_operator
from parsmooth.parallel._operators import sqrt_smoothing_operator

from parsmooth.sequential._filtering import _sqrt_predict, _sqrt_update, _standard_predict, _standard_update

from bearings.bearings_utils import make_parameters

import matplotlib.pyplot as plt

In [2]:
linearization_method = cubature
jax.config.update("jax_enable_x64", True)  # enable float64 before any jnp array is created below
s1 = jnp.array([-1.5, 0.5])  # First sensor location
s2 = jnp.array([1., 1.])  # Second sensor location
r = 0.5  # Observation noise (stddev)
dt = 0.01  # discretization time step
qc = 0.01  # discretization noise
qw = 0.1  # discretization noise

ys = np.load("bearings/ys.npy")
# if linearization_method is extended:
#     with np.load("tests/bearings//ieks.npz") as loaded:
#         expected_mean, expected_cov = loaded["arr_0"], loaded["arr_1"]
# elif linearization_method is cubature:
#     with np.load("./bearings//icks.npz") as loaded:
#         expected_mean, expected_cov = loaded["arr_0"], loaded["arr_1"]
# else:
#     pytest.skip("We don't have regression data for this linearization")

# Regression data from previous runs (only stored for the extended and cubature methods).
if linearization_method is extended:
    with np.load("bearings//previous_results_new.npz") as loaded:
        expected_mean, expected_cov = loaded["expected_mean_extended"], loaded["expected_cov_extended"]
elif linearization_method is cubature:
    with np.load("bearings//previous_results_new.npz") as loaded:
        expected_mean, expected_cov = loaded["expected_mean_cubature"], loaded["expected_cov_cubature"]
else:
    # No regression data for this linearization method; keep the names defined.
    expected_mean, expected_cov = None, None


# The first return value is the process-noise covariance. It is bound to `Q`
# because `chol_Q` and `transition_model` below read that name (binding it to
# `Q_` left `Q` undefined on a fresh kernel).
Q, R, observation_function, transition_function = make_parameters(qc, qw, r, dt, s1, s2)

m0 = jnp.array([-4., -1., 2., 7., 3.])
chol_P0 = jnp.eye(5)
P0 = jnp.eye(5)  # equals chol_P0 @ chol_P0.T

chol_Q = jnp.linalg.cholesky(Q)
chol_R = jnp.linalg.cholesky(R)

T = ys.shape[0]
# Initial nominal trajectory passed to the filters: a constant mean with identity
# covariance. The identity is its own Cholesky factor, so both forms describe the same states.
initial_states = MVNStandard(jnp.repeat(jnp.array([[-1., -1., 6., 4., 2.]]), T, axis=0),
                             jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T, axis=0))


initial_states_sqrt = MVNSqrt(jnp.repeat(jnp.array([[-1., -1., 6., 4., 2.]]), T, axis=0),
                              jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T, axis=0))

init = MVNStandard(m0, P0)
chol_init = MVNSqrt(m0, chol_P0)


# Each model exists in a standard (covariance) and a square-root (Cholesky) form.
sqrt_transition_model = FunctionalModelX(transition_function, MVNSqrt(jnp.zeros((5,)), chol_Q))
transition_model = FunctionalModelX(transition_function, MVNStandard(jnp.zeros((5,)), Q))

sqrt_observation_model = FunctionalModelX(observation_function, MVNSqrt(jnp.zeros((2,)), chol_R))
observation_model = FunctionalModelX(observation_function, MVNStandard(jnp.zeros((2,)), R))





# Check parallel-sqrt vs. parallel-standard (initialization and combination steps in filtering and smoothing)

In [3]:
# Initialization params in Parallel-Filtering

params_str = _standard_associative_params(linearization_method, transition_model, observation_model,
                                          initial_states, init, ys)
params_sqrt = _sqrt_associative_params(linearization_method, sqrt_transition_model, sqrt_observation_model,
                                       initial_states_sqrt, chol_init, ys)

In [4]:
# (A, b, C, eta , J) v.s (A, b, U, eta, Z)
np.testing.assert_array_almost_equal(params_str[0][0], params_sqrt[0][0], decimal = 7)

np.testing.assert_array_almost_equal(params_str[0][1], params_sqrt[0][1], decimal = 7)

np.testing.assert_array_almost_equal(params_str[0][2],
                                     params_sqrt[0][2] @ np.transpose(params_sqrt[0][2], [0, 2, 1]),
                                     decimal = 7)

np.testing.assert_array_almost_equal(params_str[0][3], params_sqrt[0][3], decimal = 7)

np.testing.assert_array_almost_equal(params_str[0][4],
                                     params_sqrt[0][4] @ np.transpose(params_sqrt[0][4], [0, 2, 1]),
                                     decimal = 7)

In [5]:
# (F, Q, b, H ,R, c) v.s (F, cholQ, b, H ,cholR, c)
np.testing.assert_array_almost_equal(params_str[1][0], params_sqrt[1][0], decimal = 7)

np.testing.assert_array_almost_equal(params_str[1][1],
                                     params_sqrt[1][1] @ np.transpose(params_sqrt[1][1], [0, 2, 1]),
                                     decimal = 7)

np.testing.assert_array_almost_equal(params_str[1][2], params_sqrt[1][2], decimal = 7)

np.testing.assert_array_almost_equal(params_str[1][3], params_sqrt[1][3], decimal = 7)

np.testing.assert_array_almost_equal(params_str[1][4],
                                     params_sqrt[1][4] @ np.transpose(params_sqrt[1][4], [0, 2, 1]),
                                     decimal = 7)

np.testing.assert_array_almost_equal(params_str[1][5], params_sqrt[1][5], decimal = 7)


In [6]:
# Combination params in Parallel-Filtering
i = 4
j = 5

# standard init params
As = params_str[0][0]
bs = params_str[0][1]
Cs = params_str[0][2]
etas = params_str[0][3]
Js = params_str[0][4]
elem_i = (As[i], bs[i], Cs[i], etas[i], Js[i])
elem_j = (As[j], bs[j], Cs[j], etas[j], Js[j])

# sqrt init params
As_sqrt = params_sqrt[0][0]
bs_sqrt = params_sqrt[0][1]
Us_sqrt = params_sqrt[0][2]
etas_sqrt = params_sqrt[0][3]
Zs_sqrt = params_sqrt[0][4]
elem_i_sqrt = (As_sqrt[i], bs_sqrt[i], Us_sqrt[i], etas_sqrt[i], Zs_sqrt[i])
elem_j_sqrt = (As_sqrt[j], bs_sqrt[j], Us_sqrt[j], etas_sqrt[j], Zs_sqrt[j])

str_filter_op = standard_filtering_operator(elem_i, elem_j)
sqrt_filter_op = sqrt_filtering_operator(elem_i_sqrt, elem_j_sqrt)

In [7]:
# (A_ij, b_ij, C_ij, eta_ij , J_ij) v.s (A_ij, b_ij, U_ij, eta_ij, Z_ij)
np.testing.assert_array_almost_equal(str_filter_op[0], sqrt_filter_op[0], decimal = 7)

np.testing.assert_array_almost_equal(str_filter_op[1], sqrt_filter_op[1], decimal = 7)

np.testing.assert_array_almost_equal(str_filter_op[2],
                                     sqrt_filter_op[2] @ sqrt_filter_op[2].T,
                                     decimal = 7)

np.testing.assert_array_almost_equal(str_filter_op[3], sqrt_filter_op[3], decimal = 7)

np.testing.assert_array_almost_equal(str_filter_op[4],
                                     sqrt_filter_op[4] @ sqrt_filter_op[4].T,
                                     decimal = 7)

In [8]:
# Initialization params in Parallel-Smoothing

filtering_trajectory =  MVNStandard(jnp.repeat(jnp.array([[-1., -1., 2., 4., 6.]]),T, axis=0),
                                                     jnp.repeat(4*jnp.eye(5).reshape(1, 5, 5), T, axis=0))


filtering_trajectory_sqrt = MVNSqrt(jnp.repeat(jnp.array([[-1., -1., 2., 4., 6.]]),T, axis=0),
                              jnp.repeat(2*jnp.eye(5).reshape(1, 5, 5), T, axis=0))



str_params_smooth = _associative_params(linearization_method, transition_model,
                                        initial_states, filtering_trajectory, False)
sqrt_params_smooth = _associative_params(linearization_method, sqrt_transition_model,
                                        initial_states_sqrt, filtering_trajectory_sqrt, True)

In [9]:
# (g, E, L) v.s (g, E, D)
np.testing.assert_array_almost_equal(str_params_smooth[0], sqrt_params_smooth[0], decimal = 7)

np.testing.assert_array_almost_equal(str_params_smooth[1], sqrt_params_smooth[1], decimal = 7)

np.testing.assert_array_almost_equal(str_params_smooth[2],
                                     sqrt_params_smooth[2] @ np.transpose(sqrt_params_smooth[2], [0, 2, 1]),
                                     decimal = 7)

In [10]:
# Combination params in Parallel - Smoothing
i = 4
j = 5

# standard init params
gs = str_params_smooth[0]
Es = str_params_smooth[1]
Ls = str_params_smooth[2]
smooth_elem_i = (gs[i], Es[i], Ls[i])
smooth_elem_j = (gs[j], Es[j], Ls[j])

# sqrt init params
gs_sqrt = sqrt_params_smooth[0]
Es_sqrt = sqrt_params_smooth[1]
Ds_sqrt = sqrt_params_smooth[2]
smooth_elem_i_sqrt = (gs_sqrt[i], Es_sqrt[i], Ds_sqrt[i])
smooth_elem_j_sqrt = (gs_sqrt[j], Es_sqrt[j], Ds_sqrt[j])

str_smoother_op = standard_smoothing_operator(smooth_elem_i, smooth_elem_j)
sqrt_smoother_op = sqrt_smoothing_operator(smooth_elem_i_sqrt, smooth_elem_j_sqrt)

In [11]:
# (g_ij, E_ij, L_ij) v.s (g_ij, E_ij, U_ij, D_ij)
np.testing.assert_array_almost_equal(str_smoother_op[0], sqrt_smoother_op[0], decimal = 7)

np.testing.assert_array_almost_equal(str_smoother_op[1], sqrt_smoother_op[1], decimal = 7)

np.testing.assert_array_almost_equal(str_smoother_op[2],
                                     sqrt_smoother_op[2] @ sqrt_smoother_op[2].T,
                                     decimal = 7)

# Check sequential-sqrt vs. sequential-standard (predict, update, and smooth)

In [12]:
# filtering
dim_x= 5
dim_y = 2
y = np.random.randn(dim_y)

x1 = np.random.randn(dim_x)
cholx1 = np.random.rand(dim_x, dim_x)
cholx1[np.triu_indices(dim_x, 1)] = 0
x_nominal_sqrt1 = MVNSqrt(x1, cholx1)
x_nominal_std1 = MVNStandard(x1, cholx1 @ cholx1.T)

F, Q1, b = linearization_method(transition_model, x_nominal_std1)
F_sqrt, cholQ1, b_sqrt = linearization_method(sqrt_transition_model, x_nominal_sqrt1)
predict_std =  _standard_predict(F, Q1, b, x_nominal_std1)
predict_sqrt =  _sqrt_predict(F_sqrt, cholQ1, b_sqrt, x_nominal_sqrt1)

x2 = np.random.randn(dim_x)
cholx2 = np.random.rand(dim_x, dim_x)
cholx2[np.triu_indices(dim_x, 1)] = 0
x_nominal_sqrt2 = MVNSqrt(x2, cholx2)
x_nominal_std2 = MVNStandard(x2, cholx2 @ cholx2.T)

H, R1, c = linearization_method(observation_model, x_nominal_std2)
H_sqrt, cholR1, c_sqrt = linearization_method(sqrt_observation_model, x_nominal_sqrt2)
update_std = _standard_update(H, R1, c_sqrt, predict_std, y)
update_sqrt = _sqrt_update(H_sqrt, cholR1, c, predict_sqrt, y)



In [13]:
#linearization used in prediction (F, b, Q) v.s (F_sqrt, b_sqrt, cholQ)

np.testing.assert_array_almost_equal(F, F_sqrt, decimal=10)
np.testing.assert_array_almost_equal(b, b_sqrt, decimal=10)
np.testing.assert_array_almost_equal(Q1, cholQ1@cholQ1.T, decimal=4)  #?? NOT OK

In [14]:
#linearization used in update (H, c, R) v.s (H_sqrt, c_sqrt, cholR)
np.testing.assert_array_almost_equal(H, H_sqrt, decimal=10)
np.testing.assert_array_almost_equal(c, c_sqrt, decimal=10)
np.testing.assert_array_almost_equal(R1, cholR1@cholR1.T, decimal=10)  #??? this one is OK

In [15]:
# predict
np.testing.assert_array_almost_equal(predict_std.mean, predict_sqrt.mean, decimal=10)

In [16]:
np.testing.assert_array_almost_equal(predict_std.cov,
                                     predict_sqrt.chol @ predict_sqrt.chol.T,
                                     decimal=4)  ##??

In [17]:
# update
np.testing.assert_array_almost_equal(update_std[0].mean, update_sqrt[0].mean, decimal=10)

In [18]:
np.testing.assert_array_almost_equal(update_std[0].cov,
                                     update_sqrt[0].chol @ update_sqrt[0].chol.T,
                                     decimal=4)  ##??

In [19]:
# smooth

In [None]:
_standard_smooth(F, Q, b, xf, xs)

# Check sequential-sqrt vs. sequential-standard (full filtering and smoothing runs)

In [20]:
# Filtering resuts: standard v.s square-root

seq_str_filtering = filtering(ys,
                              init,
                              transition_model,
                              observation_model,
                              linearization_method,
                              initial_states,
                              False,
                              False)

seq_sqrt_filtering = filtering(ys,
                              chol_init,
                              sqrt_transition_model,
                              sqrt_observation_model,
                              linearization_method,
                              initial_states_sqrt,
                              False,
                              False)

In [21]:
np.testing.assert_array_almost_equal(seq_str_filtering.mean, seq_sqrt_filtering.mean, decimal = 6)

np.testing.assert_array_almost_equal(seq_str_filtering.cov,
                                     seq_sqrt_filtering.chol @ np.transpose(seq_sqrt_filtering.chol, [0, 2, 1]),
                                     decimal = 6)



In [22]:
# Smoothing resuts: standard v.s square-root

nominal_trajectory = MVNStandard(jnp.repeat(jnp.array([[-7., -3., 2., 8., 3.]]),T, axis=0),
                                                     jnp.repeat(5.5 * jnp.eye(5).reshape(1, 5, 5), T, axis=0))

sqrt_nominal_trajectory = MVNSqrt(jnp.repeat(jnp.array([[-7., -3., 2., 8., 3.]]),T, axis=0),
                              jnp.repeat(jnp.sqrt(5.5 )* jnp.eye(5).reshape(1, 5, 5), T, axis=0))

seq_std_filtering_bis = MVNStandard(seq_sqrt_filtering.mean, 
                                    jax.vmap(lambda M: M @ M.T)(seq_sqrt_filtering.chol))


seq_str_smoothing = smoothing(transition_model,
                              seq_std_filtering_bis,
                              cubature,
                              nominal_trajectory,
                              False)

seq_sqrt_smoothing = smoothing(sqrt_transition_model,
                               seq_sqrt_filtering,
                               cubature,
                               sqrt_nominal_trajectory,
                               False)

In [25]:
np.testing.assert_array_almost_equal(seq_str_smoothing.mean, 
                                     seq_sqrt_smoothing.mean,
                                     decimal=4)  #???


In [28]:
np.testing.assert_array_almost_equal(seq_str_smoothing.cov, 
                                     seq_sqrt_smoothing.chol @ np.transpose(seq_sqrt_smoothing.chol, [0, 2, 1]),
                                     decimal=6)



In [None]:
### Sequential - Smoothing has problems in sqrt mode 

# Check parallel-sqrt vs. parallel-standard (full filtering and smoothing runs)

In [29]:
# Filtering resuts: standard v.s square-root
par_str_filtering = filtering(ys,
                              init,
                              transition_model,
                              observation_model,
                              linearization_method,
                              initial_states,
                              True,
                              False)

par_sqrt_filtering = filtering(ys,
                              chol_init,
                              sqrt_transition_model,
                              sqrt_observation_model,
                              linearization_method,
                              initial_states_sqrt,
                              True,
                              False)

In [30]:
np.testing.assert_array_almost_equal(par_str_filtering.mean, 
                                     par_sqrt_filtering.mean,
                                     decimal=5)

np.testing.assert_array_almost_equal(par_str_filtering.cov,
                                     par_sqrt_filtering.chol @ np.transpose(par_sqrt_filtering.chol, [0, 2, 1]),
                                     decimal = 5)

In [31]:
# Smoothing resuts: standard v.s square-root
# Same nominal-trajectory as sequential smoothing

par_str_smoothing = smoothing(transition_model,
                              seq_str_filtering,
                              linearization_method,
                              nominal_trajectory,
                              True)

par_sqrt_smoothing = smoothing(sqrt_transition_model,
                               seq_sqrt_filtering,
                               linearization_method,
                               sqrt_nominal_trajectory,
                               True)

In [42]:
np.testing.assert_array_almost_equal(par_str_smoothing.mean, 
                                     par_sqrt_smoothing.mean,
                                     decimal=4)  #??

np.testing.assert_array_almost_equal(par_str_smoothing.cov,
                                     par_sqrt_smoothing.chol @ np.transpose(par_sqrt_smoothing.chol, [0, 2, 1]),
                                     decimal = 6) #??

In [None]:
### Parallel - Smoothing has problems in sqrt mode

# Check sequential-standard vs. parallel-standard / sequential-sqrt vs. parallel-sqrt

In [43]:
# Filtering standard
np.testing.assert_array_almost_equal(seq_str_filtering.mean, 
                                     par_str_filtering.mean,
                                     decimal=7) # OK

np.testing.assert_array_almost_equal(seq_str_filtering.cov,
                                     par_str_filtering.cov, 
                                     decimal = 7) # Ok

In [44]:
# Smoothing standard
np.testing.assert_array_almost_equal(seq_str_smoothing.mean, 
                                     par_str_smoothing.mean,
                                     decimal=7) # Ok

np.testing.assert_array_almost_equal(seq_str_smoothing.cov, 
                                     par_str_smoothing.cov,
                                     decimal=7) # OK

In [59]:
# Filtering square-root
np.testing.assert_array_almost_equal(seq_sqrt_filtering.mean, 
                                     par_sqrt_filtering.mean,
                                     decimal=7) # OK



In [55]:
np.testing.assert_array_almost_equal(seq_sqrt_filtering.chol @ np.transpose(seq_sqrt_filtering.chol, [0,2,1]),
                                     par_sqrt_filtering.chol @ np.transpose(par_sqrt_filtering.chol, [0,2,1]),
                                     decimal=7)  # OK


In [None]:
### Filtering-sqrt: suspected a problem in the Cholesky computation (both parallel and sequential were to be checked).
#  Result: OK — the sequential and parallel square-root filters agree to 7 decimals.

In [65]:
# Smoothing square-root
np.testing.assert_array_almost_equal(seq_sqrt_smoothing.mean, 
                                     par_sqrt_smoothing.mean,
                                     decimal=5) # OK

In [71]:
np.testing.assert_array_almost_equal(seq_sqrt_smoothing.chol @ np.transpose(seq_sqrt_smoothing.chol, [0,2,1]), 
                                     par_sqrt_smoothing.chol @ np.transpose(par_sqrt_smoothing.chol, [0,2,1]),
                                     decimal=7)  # OK

In [None]:
### Smoothing-sqrt: suspected a problem in the mean and Cholesky computations (both parallel and sequential were to be checked).
# Result: almost OK — the means agree to 5 decimals and the covariances to 7.

## Check standard vs. sqrt

In [72]:
# parallel-filtering
np.testing.assert_array_almost_equal(par_str_filtering.mean, 
                                     par_sqrt_filtering.mean,
                                     decimal=5)  # OK

np.testing.assert_array_almost_equal(par_str_filtering.cov, 
                                     par_sqrt_filtering.chol @  np.transpose(par_sqrt_filtering.chol,[0,2,1]),
                                     decimal=5)  # OK

In [82]:
# parallel-smoothing
np.testing.assert_array_almost_equal(par_str_smoothing.mean, 
                                     par_sqrt_smoothing.mean,
                                     decimal=4)  # NOT OK

np.testing.assert_array_almost_equal(par_str_smoothing.cov, 
                                     par_sqrt_smoothing.chol @  np.transpose(par_sqrt_smoothing.chol,[0,2,1]),
                                     decimal=6)  # NOT OK
# changing the initial guess 

In [None]:
# ^Expected, parallel smoothing in sqrt mode has problem

In [87]:
# sequential-filtering
np.testing.assert_array_almost_equal(seq_str_filtering.mean, 
                                     seq_sqrt_filtering.mean,
                                     decimal=7)  # OK


In [102]:
# np.testing.assert_array_almost_equal(seq_str_filtering.cov, 
#                                      seq_sqrt_filtering.chol @  np.transpose(par_sqrt_filtering.chol,[0,2,1]),
#                                      decimal=0)  # NOT OK



In [None]:
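# ^Note: a covariance rebuilt as `seq_sqrt_filtering.chol @ par_sqrt_filtering.chol^T` mixes factors from two different filters and is not a valid check. Compare each factor with its own transpose before concluding that the sequential square-root filter's Cholesky factor is wrong.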

In [96]:
# sequential-smoothing
np.testing.assert_array_almost_equal(seq_str_smoothing.mean, 
                                     seq_sqrt_smoothing.mean,
                                     decimal=4)  # NOT OK

np.testing.assert_array_almost_equal(seq_str_smoothing.cov, 
                                     seq_sqrt_smoothing.chol @  np.transpose(seq_sqrt_smoothing.chol,[0,2,1]),
                                     decimal=6)  # OK
# changing the initial guess

In [None]:
# ^Expected, sequential smoothing in sqrt mode has problem

# Iterated standard and sqrt

In [131]:
iteration = 100
# standard
iterated_res_seq = iterated_smoothing(ys, init, transition_model, observation_model,
                                      linearization_method, initial_states, False,
                                      criterion=lambda i, *_: i < iteration)

In [None]:
iterated_res_par = iterated_smoothing(ys, init, transition_model, observation_model,
                                      linearization_method, initial_states, True,
                                      criterion=lambda i, *_: i < iteration)



In [None]:
# square-root
sqrt_iterated_res_seq = iterated_smoothing(ys, chol_init, sqrt_transition_model, sqrt_observation_model,
                                  linearization_method, initial_states_sqrt, False,
                                      criterion=lambda i, *_: i < iteration)

In [None]:
sqrt_iterated_res_par = iterated_smoothing(ys, chol_init, sqrt_transition_model, sqrt_observation_model,
                                       linearization_method, initial_states_sqrt, True,
                                       criterion=lambda i, *_: i < iteration)

In [None]:
np.testing.assert_array_almost_equal(iterated_res_seq.mean, iterated_res_par.mean, decimal=10)

np.testing.assert_array_almost_equal(iterated_res_seq.cov, iterated_res_par.cov, decimal=10)



In [None]:
np.testing.assert_array_almost_equal(sqrt_iterated_res_seq.mean, sqrt_iterated_res_par.mean, decimal=3)

np.testing.assert_array_almost_equal(sqrt_iterated_res_seq.chol @ np.transpose(sqrt_iterated_res_seq.chol, [0,2,1]),
                                     sqrt_iterated_res_par.chol @ np.transpose(sqrt_iterated_res_par.chol, [0,2,1]), 
                                     decimal=2)


In [None]:
# Overlay the first two components of the iterated smoother means for all four variants.
fig, ax = plt.subplots(figsize=(15, 10))
ax.plot(sqrt_iterated_res_seq.mean[:, 0], sqrt_iterated_res_seq.mean[:, 1], 'g', label="sqrt_seq")
ax.plot(sqrt_iterated_res_par.mean[:, 0], sqrt_iterated_res_par.mean[:, 1], 'k', label="sqrt_par")
ax.plot(iterated_res_seq.mean[:, 0], iterated_res_seq.mean[:, 1], 'b', label="seq")
ax.plot(iterated_res_par.mean[:, 0], iterated_res_par.mean[:, 1], 'r', label="par")
ax.set(title="Iterated smoother means: standard vs. square-root, sequential vs. parallel",
       xlabel="mean[:, 0]", ylabel="mean[:, 1]")
ax.legend();
