In [22]:
import glob

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import pkg_resources
from jax import jit

from jnkepler.jaxttv import JaxTTV
from jnkepler.jaxttv.utils import *
from jnkepler.jaxttv.conversion import *
from jnkepler.jaxttv.findtransit import find_transit_times_single, find_transit_times_all, find_transit_times_kepler_all
from jnkepler.jaxttv.symplectic import integrate_xv, kepler_step_map
from jnkepler.jaxttv.hermite4 import integrate_xv as integrate_xv_hermite4

In [23]:
# Location of the 'data/' resource directory of the jnkepler package.
# The trailing slash is load-bearing: later cells build file names by plain
# string concatenation (path + "<file name>").
path = pkg_resources.resource_filename('jnkepler', 'data/')

In [24]:
# Initial period guesses, one per planet; passed to JaxTTV.set_tcobs below.
p_init = [45.155305, 85.31646, 130.17809]

# Whitespace-separated table without a header row.
# sep=r"\s+" is the documented replacement for delim_whitespace=True, which
# is deprecated since pandas 2.2; both parse the file the same way.
d = pd.read_csv(
    path + "kep51_ttv.txt",
    sep=r"\s+",
    header=None,
    names=['tnum', 'tc', 'tcerr', 'dnum', 'planum'],
)

# One array of 'tc' values per planet; rows are selected by planum = 1, 2, 3.
tcobs = [np.array(d.tc[d.planum == j + 1]) for j in range(3)]

In [25]:
# Arguments handed to JaxTTV(t_start, t_end, dt) -- presumably the start time,
# end time and step of the integration; confirm against the jnkepler docs.
dt = 1.0
t_start = 155.
t_end = 2950.

jttv = JaxTTV(t_start, t_end, dt)
jttv.set_tcobs(tcobs, p_init, print_info=False)

# Test parameter vector: first file in the package data directory that matches
# the glob pattern, converted to (elements, masses) for jttv.nplanet planets.
params_test = np.loadtxt(
    glob.glob(path + "kep51*_params.txt")[0]
)
elements, masses = params_to_elements(params_test, jttv.nplanet)

# The cells below reference `self.<attribute>`; alias the JaxTTV instance so
# those expressions work at notebook top level.
self = jttv

In [27]:
%%timeit
# develop: 3.29 ms ± 63.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
xjac0, vjac0 = jit(initialize_jacobi_xv)(elements, masses, self.t_start)

92.2 µs ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [32]:
@jit
def integrate_xv_jit(xjac0, vjac0, masses):
    """Jitted wrapper around integrate_xv.

    Only xjac0, vjac0 and masses are function arguments. The time grid
    (self.times) and the nitr setting (self.nitr_kepler) are read from the
    JaxTTV instance aliased as `self` in the notebook namespace.

    Returns whatever integrate_xv returns (unpacked as `times, xvjac` by the
    cells that call this wrapper).
    """
    trajectory = integrate_xv(
        xjac0, vjac0, masses, self.times, nitr=self.nitr_kepler
    )
    return trajectory

In [34]:
%%timeit
# develop: 406 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each), 414 ms for nitr=1
times, xvjac = integrate_xv_jit(xjac0, vjac0, masses)

1.15 ms ± 26 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
%%timeit
orbit_idx = self.pidx.astype(int) - 1 # idx for orbit, starting from 0
tcobs1d = self.tcobs_flatten

733 ns ± 8.51 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [38]:
@jit
def find_tt_jit(orbit_idx, tcobs1d, times, xvjac, masses):
    """Jitted wrapper around find_transit_times_all.

    All five arguments are forwarded unchanged. The nitr setting
    (self.nitr_transit) is read from the JaxTTV instance aliased as `self`
    in the notebook namespace.

    Returns whatever find_transit_times_all returns.
    """
    transit_times = find_transit_times_all(
        orbit_idx, tcobs1d, times, xvjac, masses, nitr=self.nitr_transit
    )
    return transit_times

In [39]:
%%timeit
#develop: 351 ms ± 23.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each), 299ms for nitr=1
transit_times = find_transit_times_all(orbit_idx, tcobs1d, times, xvjac, masses, nitr=self.nitr_transit)

353 ms ± 16.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [40]:
%%timeit
tt = find_tt_jit(orbit_idx, tcobs1d, times, xvjac, masses)

722 µs ± 7.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
from functools import partial
#@partial(jit, static_argnums=(0,))
@jit
def func(elements, masses):
    """Compute model transit times from elements and masses in one jitted call.

    Chains initialize_jacobi_xv -> integrate_xv -> find_transit_times_all.
    The start time, time grid, nitr settings and observation arrays are read
    from the JaxTTV instance aliased as `self` in the notebook namespace.

    Returns the result of find_transit_times_all.
    """
    # Observation bookkeeping taken from the JaxTTV instance.
    planet_idx = self.pidx.astype(int) - 1  # idx for orbit, starting from 0
    obs_times = self.tcobs_flatten

    x_init, v_init = initialize_jacobi_xv(elements, masses, self.t_start)
    t_grid, xv_path = integrate_xv(
        x_init, v_init, masses, self.times, nitr=self.nitr_kepler
    )
    return find_transit_times_all(
        planet_idx, obs_times, t_grid, xv_path, masses, nitr=self.nitr_transit
    )

In [45]:
%%timeit
func(elements, masses)

1.76 ms ± 18.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [45]:
%%timeit
transit_times = find_transit_times_all(orbit_idx, tcobs1d, times, xvjac, masses, nitr=1)

299 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Integrate xv

In [52]:
# Direct call to integrate_xv (the jnkepler.jaxttv.symplectic import), without
# the notebook's integrate_xv_jit wrapper; binds times and xvjac at top level.
# NOTE(review): earlier cells already reference times / xvjac, so on a fresh
# kernel this binding would be needed before them -- confirm the intended
# cell order.
times, xvjac = integrate_xv(xjac0, vjac0, masses, self.times, nitr=self.nitr_kepler)