In [4]:
import jax

# Run JAX in float64 (the acceleration outputs in this notebook are dtype=float64).
# JAX expects this flag to be set at startup, before any arrays are created; it is
# placed ahead of the jorbit imports, presumably so anything they build at import
# time is float64 as well.
jax.config.update("jax_enable_x64", True)

from pathlib import Path

import jax.numpy as jnp
import pandas as pd
from astropy.time import Time

from jorbit import Ephemeris, Particle
from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_static_generic_acceleration_func,
)
from jorbit.accelerations.static_helpers import precompute_perturber_positions


In [5]:
# TODO: split each particle's integration into two chunks, one integrated forward
# and one integrated backward from the midpoint time.

In [6]:
# Precompute perturber states for an integration starting at T0 and covering the
# target's observation times. Later cells feed single slices of these arrays
# (e.g. planet_pos[0, 0]) to the static acceleration function.
TARGET = "274301"
T0 = Time("1995-01-01T00:00:00", format="isot")

# Observation file for TARGET. Resolved from the user's home directory instead of
# a hardcoded absolute path. NOTE(review): consider a project-relative data dir.
DATA_DIR = Path.home() / "Downloads"
OBS_XML = DATA_DIR / f"{TARGET}.xml"

MAX_STEP_SIZE = 3.2  # NOTE(review): presumably days — confirm against jorbit docs

obs_times = Time(pd.read_xml(OBS_XML)["obsTime"].to_list(), format="isot")

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices = (
    precompute_perturber_positions(
        t0=T0,
        times=obs_times,
        max_step_size=MAX_STEP_SIZE,
        de_ephemeris_version="440",
    )
)

# Same epoch object as the precomputation, so the particle state and the
# perturber grid cannot drift apart.
test = Particle.from_horizons(TARGET, time=T0)

In [7]:
# Reference path: the standard ephemeris-backed acceleration function, evaluated
# on the particle's state at T0. Its result is what the static path must match.
eph = Ephemeris(ssos="default solar system")
old_func = create_default_ephemeris_acceleration_func(
    eph.processor,
)
true_acc = old_func(
    test.cartesian_state.to_system(),
)
true_acc

Array([[ 5.56203643e-07, -4.31812496e-05, -1.29775661e-05]], dtype=float64)

In [9]:
# Static path: the same acceleration, but with perturber states supplied explicitly
# from the precomputed arrays. Index [0, 0] selects the first step, first substep.
state = test.cartesian_state.to_system()
state.acceleration_func_kwargs.update(
    gr_perturber_xs=planet_pos[0, 0],
    gr_perturber_vs=planet_vel[0, 0],
    newtonian_perturber_xs=asteroid_pos[0, 0],
    newtonian_perturber_vs=asteroid_vel[0, 0],
    perturber_log_gms=eph.processor.log_gms,
)

static_acc_func = create_static_generic_acceleration_func()

static_acc = static_acc_func(state)
static_acc

Array([[ 5.56203643e-07, -4.31812496e-05, -1.29775661e-05]], dtype=float64)

In [10]:
# The static path must reproduce the ephemeris-based acceleration. Assert it, so a
# regression fails under Restart & Run All instead of only showing a nonzero diff.
# atol=0: jnp.allclose's default atol (1e-8) exceeds the smallest component (~5e-7),
# so the default tolerance would accept a grossly wrong result.
acc_diff = static_acc - true_acc
assert jnp.allclose(static_acc, true_acc, rtol=1e-12, atol=0.0), acc_diff
acc_diff

Array([[0., 0., 0.]], dtype=float64)

In [11]:
# JIT both acceleration functions and call each once. Compilation then happens here,
# and later timing measures only the compiled calls.
# NOTE(review): this rebinds old_func to its jitted version (re-running wraps it again),
# and rebinds static_acc from the acceleration array to a function. Re-running the
# comparison cell after this one will fail.
s_base = test.cartesian_state.to_system()
old_func = jax.jit(old_func)
_ = old_func(s_base)

static_acc = jax.jit(static_acc_func)
tmp = test.cartesian_state.to_system()
tmp.acceleration_func_kwargs.update(
    gr_perturber_xs=planet_pos[0, 0],  # first step, first substep
    gr_perturber_vs=planet_vel[0, 0],
    newtonian_perturber_xs=asteroid_pos[0, 0],
    newtonian_perturber_vs=asteroid_vel[0, 0],
    perturber_log_gms=eph.processor.log_gms,
)
_ = static_acc(tmp)

In [12]:
%%timeit
old_func(tmp).block_until_ready()

83.9 μs ± 2.14 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
%%timeit
# Time the jitted static acceleration function on tmp, the input it was warmed up on.
# Perturber states come from the precomputed arrays, presumably so no ephemeris
# evaluation happens per call. block_until_ready() makes timeit wait for the result
# rather than just the asynchronous dispatch.
static_acc(tmp).block_until_ready()

43.7 μs ± 1.61 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
