In [7]:
# ruff: noqa
import jax
# Must run before anything (including jorbit) creates JAX arrays.
jax.config.update("jax_enable_x64", True)

import astropy.constants as const
import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from astroquery.jplhorizons import Horizons

from jorbit import Particle, Observations, Ephemeris
from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_static_default_on_sky_acc_function,
)
from jorbit.accelerations.static_helpers import (
    get_all_dynamic_intermediate_dts,
    precompute_perturber_positions,
    generate_perturber_chebyshev_coeffs,
)
from jorbit.integrators import initialize_ias15_integrator_state


In [2]:
# Observing cadence: three nights about a year apart, three exposures per
# night spaced one hour apart.
nights = [Time("2025-01-01 07:00"), Time("2026-01-02 07:00"), Time("2027-01-05 07:00")]
times = Time([night + k * u.hour for night in nights for k in range(3)])

# Astrometry of (274301) from JPL Horizons, as seen from site 695 (Kitt Peak).
obj = Horizons(id="274301", location="695@399", epochs=times.utc.jd)
pts = obj.ephemerides(extra_precision=True, quantities="1")

# Use the coordinates and epochs Horizons actually returned (UTC Julian dates).
coords = SkyCoord(ra=pts["RA"], dec=pts["DEC"], unit=(u.deg, u.deg))
times = Time(pts["datetime_jd"], format="jd", scale="utc")

obs = Observations(
    observed_coordinates=coords,
    times=times,
    observatories="kitt peak",
    astrometric_uncertainties=1 * u.arcsec,
)

# Take the Horizons orbit at the epoch, then build the particle again with the
# observations attached, passing that orbit as the fit seed.
p = Particle.from_horizons("274301", Time("2026-01-01"))
p = Particle.from_horizons(
    "274301",
    Time("2026-01-01"),
    observations=obs,
    fit_seed=p.keplerian_state,
)

In [4]:
# Display the Observations attached to the particle
# (3 nights x 3 exposures = 9 expected from the query cell).
p.observations

Observations with 9 set(s) of observations

In [None]:
def precompute_likelihood_data(
    p: Particle,
    de_ephemeris_version: str = "440",
    time_padding_days: float = 10.0,
) -> dict:
    """Precompute the perturber data needed to evaluate ``p``'s likelihood fast.

    Work in progress. Using the observations attached to ``p`` this builds:

    * Chebyshev coefficients of the perturber positions/velocities around each
      observation time (for light-travel-time corrections), and
    * the step sizes of every intermediate integrator step between the
      particle's epoch and the observation times, plus the perturber states
      at each of those steps and their IAS15 substeps.

    Args:
        p: Particle with ``observations`` attached.
        de_ephemeris_version: JPL DE version used both for the ephemeris and
            for the precomputed perturber positions.
        time_padding_days: Days of padding added on both ends of the
            ephemeris time span.

    Returns:
        dict with keys ``"perturber_pos_chebys"``, ``"perturber_vel_chebys"``,
        ``"dts"``, ``"inds"``, ``"perturber_pos"``, ``"perturber_vel"`` and
        ``"gms"``.
    """
    # presumably TDB Julian dates (they are wrapped with scale="tdb" below)
    obs_times = p.observations.times
    epoch = p.cartesian_state.time

    # Integration runs from the particle's epoch through every observation.
    # NOTE(review): if some observations precede the epoch, `times` is not
    # monotonic -- confirm get_all_dynamic_intermediate_dts supports that.
    times = jnp.concatenate([jnp.array([epoch]), obs_times])

    # The ephemeris must cover the full integration span, including the epoch,
    # which is not guaranteed to lie inside the range of the observations.
    padding = time_padding_days * u.day
    ephem = Ephemeris(
        earliest_time=Time(jnp.min(times), format="jd", scale="tdb") - padding,
        latest_time=Time(jnp.max(times), format="jd", scale="tdb") + padding,
        ssos="default solar system",
        de_ephemeris_version=de_ephemeris_version,
    )

    # Chebyshev coefficients for the perturber states around each observation
    # time, used for light-travel-time corrections.
    pos_chebys, vel_chebys = [], []
    for t in obs_times:
        pos, vel = generate_perturber_chebyshev_coeffs(
            obs_time=Time(t, format="jd", scale="tdb"),
            ephem=ephem,
        )
        pos_chebys.append(pos)
        vel_chebys.append(vel)

    # Step sizes of all intermediate steps between consecutive times.
    state = p.cartesian_state.to_system()
    dynamic_acc_func = create_default_ephemeris_acceleration_func(ephem.processor)
    integrator_init = initialize_ias15_integrator_state(dynamic_acc_func(state))
    # `times` is already a JD array; its first gap is epoch -> first observation.
    integrator_init.dt = jnp.diff(times)[0]
    dts, inds = get_all_dynamic_intermediate_dts(
        initial_system_state=state,
        acceleration_func=dynamic_acc_func,
        times=times,
        initial_integrator_state=integrator_init,
    )

    # Positions/velocities of all perturbers at each intermediate step and at
    # each IAS15 substep.
    planet_pos, planet_vel, asteroid_pos, asteroid_vel, gms = precompute_perturber_positions(
        t0=Time(epoch, format="jd", scale="tdb"),
        dts=dts,
        de_ephemeris_version=de_ephemeris_version,
    )

    return {
        "perturber_pos_chebys": jnp.array(pos_chebys),
        "perturber_vel_chebys": jnp.array(vel_chebys),
        "dts": dts,
        "inds": inds,
        "perturber_pos": jnp.concatenate((planet_pos, asteroid_pos), axis=2),
        "perturber_vel": jnp.concatenate((planet_vel, asteroid_vel), axis=2),
        "gms": gms,
    }


In [4]:
# Build a System for (274301) at t0 and attach the perturber Chebyshev
# coefficients consumed by the static on-sky acceleration function.
t0 = Time("2026-01-01")
p = Particle.from_horizons("274301", t0)
eph = Ephemeris(ssos="default solar system")

state = p.cartesian_state.to_system()
x_coeffs, v_coeffs = generate_perturber_chebyshev_coeffs(t0, eph)
state.fixed_perturber_log_gms = eph.processor.log_gms
state.acceleration_func_kwargs.update(
    {
        "perturber_position_cheby_coeffs": x_coeffs,
        "perturber_velocity_cheby_coeffs": v_coeffs,
        # presumably the validity window of the Chebyshev fit (TDB JD) --
        # confirm against generate_perturber_chebyshev_coeffs
        "cheby_t0": t0.tdb.jd - 1.5,
        "cheby_t1": t0.tdb.jd + 0.1,
    }
)

acc_func = create_static_default_on_sky_acc_function()
a_new = acc_func(state)
a_new

Array([[4.03408192e-05, 2.83907818e-05, 8.81375997e-06]], dtype=float64)

In [5]:
# Reference result from the existing dynamic ephemeris acceleration function,
# evaluated on the same `state` that was passed to the static version above.
acc_func_old = create_default_ephemeris_acceleration_func(eph.processor)
a_old = acc_func_old(state)
a_old

Array([[4.03408192e-05, 2.83907818e-05, 8.81375997e-06]], dtype=float64)

In [14]:
# Light-travel time across 1000 au, for scale (needs `astropy.constants as const`).
(1000 * u.au / const.c).to(u.day)

<Quantity 5.77551833 d>

In [6]:
# The static (Chebyshev-based) and dynamic accelerations should agree to
# floating-point round-off; assert it rather than relying on eyeballing.
assert jnp.allclose(a_new, a_old, rtol=1e-12, atol=0.0), "static and dynamic accelerations disagree"
a_old - a_new

Array([[-8.13151629e-20, -9.82558219e-20, -3.04931861e-20]], dtype=float64)

In [7]:
# JIT-compile both acceleration functions for the benchmark below. They reuse
# the `state` and `eph` prepared in the setup cell above (no second Horizons
# query), so both functions are timed on identical inputs.
old_acc_func = jax.jit(create_default_ephemeris_acceleration_func(eph.processor))
new_acc_func = jax.jit(create_static_default_on_sky_acc_function())

# Warm-up: the first call triggers compilation, which must stay out of
# %timeit. block_until_ready() waits for JAX's asynchronous dispatch.
_ = old_acc_func(state).block_until_ready()
_ = new_acc_func(state).block_until_ready()

In [8]:
%timeit old_acc_func(state)

73.2 μs ± 5.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
%timeit new_acc_func(state)

22.8 μs ± 494 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
