In [None]:
import sys
sys.path.append('..//')

import jax
import pickle
from jax import jit
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')

from newton_smoothers.base import MVNStandard, FunctionalModel
from newton_smoothers.approximation import quadratize

from newton_smoothers import trust_region_iterated_recursive_newton_smoother as tr_recur_newton
from newton_smoothers import trust_region_iterated_batch_newton_smoother as tr_batch_newton

from bearing_data import make_parameters

In [None]:
# --- Sensor geometry and measurement noise -------------------------------
# Sensor coordinates are passed to ``bearing_data.make_parameters``.
s1 = jnp.array([-1.5, 0.5])  # location of the first sensor
s2 = jnp.array([1., 1.])  # location of the second sensor
r = 0.5  # observation-noise standard deviation
# NOTE(review): x0 has 4 entries while the state dimension nx is 5; it is not
# used by any cell in this notebook.
x0 = jnp.array([0.1, 0.2, 1, 0])  # initial true location

# --- Discretisation / process noise --------------------------------------
dt = 0.01  # discretization time step
# qc and qw are both forwarded to make_parameters as process-noise settings;
# presumably qc drives the position/velocity noise and qw the remaining state
# component -- TODO confirm against bearing_data.make_parameters.
qc = 0.01
qw = 0.1

nx = 5  # state dimension
ny = 2  # observation dimension

# Q and R are the process / observation noise covariances that make_parameters
# returns; its last two outputs are not needed here.
Q, R, trans_fcn, obsrv_fcn, _, _ = make_parameters(qc, qw, r, dt, s1, s2)

# Each model pairs a function with a zero-mean Gaussian noise term.
trans_mdl = FunctionalModel(trans_fcn, MVNStandard(mean=jnp.zeros(nx), cov=Q))
obsrv_mdl = FunctionalModel(obsrv_fcn, MVNStandard(mean=jnp.zeros(ny), cov=R))

# Prior over the initial state: unit covariance around (-1, -1, 0, 0, 0).
init_dist = MVNStandard(mean=jnp.array([-1., -1., 0., 0., 0.]), cov=jnp.eye(nx))
)

In [None]:
# Trajectory lengths to benchmark. Entry t must correspond to data_measurements[t],
# because the benchmark loop indexes the data by the position of each length.
Ts = list(range(100, 501, 100)) + [1000, 1500]

In [None]:
# Pre-generated measurements; the benchmark reads data_measurements[t][i]['ys'].
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
with open("outputs/data_measurements.pkl", "rb") as measurements_file:
    data_measurements = pickle.load(measurements_file)

In [None]:
import numpy as np
import time

def _block_until_ready(tree):
    """Wait until every JAX array in ``tree`` has been computed; return ``tree`` unchanged."""
    for leaf in jax.tree_util.tree_leaves(tree):
        if hasattr(leaf, "block_until_ready"):
            leaf.block_until_ready()
    return tree


def func(method, lengths, data, nb_runs=20, nb_iter=30, label='ls_newton_recursive', warmup=True):
    """Time ``method`` over several trajectory lengths and summarise the wall-clock runtimes.

    For each length ``T = lengths[t]`` a constant nominal trajectory of ``T + 1``
    Gaussians (mean ``[-1, -1, 6, 4, 2]``, identity covariance) is built. Then
    ``method(ys, init_nominal, nb_iter)`` is timed on ``ys = data[t][i]['ys']`` for
    ``i in range(nb_runs)``.

    Args:
        method: Callable invoked as ``method(ys, init_nominal, nb_iter)``.
        lengths: Sequence of trajectory lengths; ``lengths[t]`` must correspond to ``data[t]``.
        data: Nested sequence; ``data[t][i]['ys']`` holds the observations of run ``i``
            at length index ``t``.
        nb_runs: Number of timed runs per length.
        nb_iter: Forwarded to ``method`` as its third positional argument.
        label: Identifier for the (currently disabled) per-length result file; otherwise unused.
        warmup: If True (default), call ``method`` once per length before timing. A
            ``jax.jit``-compiled method recompiles for every new input shape, so without
            this the first timed run of each length would include compilation time.

    Returns:
        Tuple ``(mean, median)`` of NumPy arrays holding one runtime in seconds per
        entry of ``lengths``.
    """
    nominal_mean = jnp.array([-1., -1., 6., 4., 2.])
    dim = nominal_mean.shape[0]  # derive the covariance size from the mean instead of hard-coding it

    res_mean = []
    res_median = []
    for t, T in enumerate(lengths):
        print(f"Length {t+1} out of {len(lengths)}")
        init_nominal = MVNStandard(jnp.repeat(nominal_mean[None, :], T + 1, axis=0),
                                   jnp.repeat(jnp.eye(dim)[None, :, :], T + 1, axis=0))

        if warmup and nb_runs > 0:
            # Untimed call so JIT compilation for this shape is excluded from the statistics.
            _block_until_ready(method(data[t][0]['ys'], init_nominal, nb_iter))

        run_times = []
        for i in range(nb_runs):
            args = data[t][i]['ys'], init_nominal, nb_iter

            tic = time.perf_counter()
            # JAX dispatches asynchronously: wait for the result so the timer covers the computation.
            _block_until_ready(method(*args))
            toc = time.perf_counter()
            run_times.append(toc - tic)
            print(f"run {i+1} out of {nb_runs}", end="\r")
        res_mean.append(np.mean(run_times))
        res_median.append(np.median(run_times))
        # np.savez("outputs/TIME-CPU-"+label+"-"+str(t+1), time = np.array(run_times))
    print()

    return np.array(res_mean), np.array(res_median)

In [None]:
# Newton Recursive Iterated Smoother
def iterated_recursive_newton_smoother(observations, nominal_trajectory, iteration):
    """Run the trust-region iterated recursive Newton smoother and return its first output.

    Args:
        observations: Measurements for one trajectory.
        nominal_trajectory: ``MVNStandard`` nominal trajectory, passed through whole.
        iteration: Accepted to match the ``method(ys, init_nominal, nb_iter)`` calling
            convention of the benchmark helper, but NOT forwarded: the iteration count
            is fixed at 30.

    NOTE(review): forwarding ``iteration`` would turn it into a traced value under the
    plain ``jit`` applied to this function (no ``static_argnums``). If the smoother needs
    a concrete ``nb_iter`` that would fail, so wiring it through has to be done together
    with the ``jit`` call. TODO.
    """
    smoother_outputs = tr_recur_newton(
        nominal_trajectory,
        observations,
        init_dist,
        trans_mdl,
        obsrv_mdl,
        quadratize,  # approximation routine from newton_smoothers.approximation
        nb_iter=30,
    )
    return smoother_outputs[0]

# Newton Batch Iterated Smoother
def iterated_batch_newton_smoother(observations, nominal_trajectory, iteration):
    """Run the trust-region iterated batch Newton smoother and return its first output.

    Args:
        observations: Measurements for one trajectory.
        nominal_trajectory: ``MVNStandard`` nominal trajectory; only its ``mean`` is used.
        iteration: Accepted to match the ``method(ys, init_nominal, nb_iter)`` calling
            convention of the benchmark helper, but NOT forwarded: the iteration count
            is fixed at 30.

    NOTE(review): forwarding ``iteration`` would turn it into a traced value under the
    plain ``jit`` applied to this function (no ``static_argnums``). If the smoother needs
    a concrete ``nb_iter`` that would fail, so wiring it through has to be done together
    with the ``jit`` call. TODO.
    """
    nominal_mean = nominal_trajectory.mean
    smoother_outputs = tr_batch_newton(
        nominal_mean,
        observations,
        init_dist,
        trans_mdl,
        obsrv_mdl,
        nb_iter=30,
    )
    return smoother_outputs[0]

In [None]:
# JIT-compile both smoothers for the CPU backend. Compilation happens lazily on
# the first call for each new input shape.
cpu_recursive, cpu_batch = (
    jit(smoother, backend="cpu")
    for smoother in (iterated_recursive_newton_smoother, iterated_batch_newton_smoother)
)

In [None]:
# (mean, median) runtimes per trajectory length for the recursive smoother.
cpu_recursive_runtime = func(
    method=cpu_recursive,
    lengths=Ts,
    data=data_measurements,
    label='tr_newton_recursive',
)
# jnp.savez("outputs/recursive_runtime15", cpu_tr_recursive_runtime = cpu_recursive_runtime)

In [None]:
# (mean, median) runtimes per trajectory length for the batch smoother.
cpu_batch_runtime = func(
    method=cpu_batch,
    lengths=Ts,
    data=data_measurements,
    label='tr_newton_batch',
)
# jnp.savez("outputs/batch_runtime15", cpu_tr_batch_runtime = cpu_batch_runtime)

In [None]:
from matplotlib import pyplot as plt

# Mean runtime (index 0 of each (mean, median) benchmark result) versus trajectory
# length, on log-log axes.
fig, ax = plt.subplots()
ax.loglog(Ts, cpu_batch_runtime[0], '--*', label="batch_runtime")
ax.loglog(Ts, cpu_recursive_runtime[0], '--*', label="recursive_runtime")
ax.grid(True, which="both")
ax.set_xlabel("Trajectory length T")
ax.set_ylabel("Mean runtime [s]")
ax.legend()
ax.set_title("Iterated trust region newton");

In [None]:
import pandas as pd

# One row per trajectory length; the runtime columns hold the mean CPU runtimes
# (index 0 of each (mean, median) benchmark result). Stacking promotes every
# column, including the lengths, to float64.
columns = ["times", "cpu_tr_batch_runtime", "cpu_tr_recursive_runtime"]
data_runtime_batch_seq = np.column_stack((Ts, cpu_batch_runtime[0], cpu_recursive_runtime[0]))

df1 = pd.DataFrame(data_runtime_batch_seq, columns=columns)
# df1.to_csv("outputs/tr_batch_seq_runtime.csv")