In [6]:
# ruff: noqa
import jax

jax.config.update("jax_enable_x64", True)

import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from astroquery.jplhorizons import Horizons

from jorbit import Particle, Observations, Ephemeris
from jorbit.accelerations import (
    create_static_default_on_sky_acc_func,
    create_default_ephemeris_acceleration_func,
    create_static_default_acceleration_func,
)
from jorbit.accelerations.static_helpers import (
    get_all_dynamic_intermediate_dts,
    precompute_perturber_positions,
    generate_perturber_chebyshev_coeffs,
)
from jorbit.integrators import initialize_ias15_integrator_state, ias15_static_evolve
from jorbit.astrometry.sky_projection import on_sky
from jorbit.astrometry.sky_projection import tangent_plane_projection

In [2]:
# Three observing nights, each sampled at three epochs spaced one hour apart.
nights = [Time("2025-01-01 07:00"), Time("2026-01-02 07:00"), Time("2027-01-05 07:00")]
times = Time([night + hour * u.hour for night in nights for hour in range(3)])

# Ask JPL Horizons for the apparent position of object 274301 at those epochs,
# as seen from observatory code 695.
obj = Horizons(id="274301", location="695@399", epochs=times.utc.jd)
pts = obj.ephemerides(extra_precision=True, quantities="1")

# Rebuild the coordinates and epochs from the returned table so both are taken
# row-for-row from the same Horizons response.
coords = SkyCoord(pts["RA"], pts["DEC"], unit=(u.deg, u.deg))
times = Time(pts["datetime_jd"], format="jd", scale="utc")

obs = Observations(
    observed_coordinates=coords,
    times=times,
    observatories="kitt peak",
    astrometric_uncertainties=1 * u.arcsec,
)

# First fetch a particle at an earlier epoch purely to obtain a Keplerian state,
# then use that state as `fit_seed` for the particle that carries the observations.
seed_particle = Particle.from_horizons("274301", Time("2024-01-01"))
p = Particle.from_horizons(
    "274301",
    Time("2026-01-01"),
    observations=obs,
    fit_seed=seed_particle.keplerian_state,
)

In [3]:
def precompute_likelihood_data(p: Particle, de_ephemeris_version: str = "440"):
    """Precompute the state-independent inputs used by the static likelihood.

    Everything returned here depends only on the particle's epoch and on its
    attached observations, so it can be computed once and then closed over by a
    jitted function.

    Args:
        p: Particle with observations attached. Its ``observations.times`` are
            wrapped below as JD in the TDB scale.
        de_ephemeris_version: DE ephemeris version string forwarded both to
            ``Ephemeris`` and to ``precompute_perturber_positions`` so the two
            always agree. Defaults to ``"440"``.

    Returns:
        A 10-tuple, in this order:
            cheby_info: dict holding the per-observation perturber position and
                velocity Chebyshev coefficients plus the ``cheby_t0`` /
                ``cheby_t1`` window edges.
            dts: intermediate step sizes from
                ``get_all_dynamic_intermediate_dts``.
            inds: indices returned alongside ``dts`` by the same call.
            perturber_pos: planet and asteroid positions concatenated on axis 2.
            perturber_vel: planet and asteroid velocities concatenated on axis 2.
            gms: perturber mass parameters as returned by
                ``precompute_perturber_positions`` (callers in this notebook
                pass them on as log-GMs -- TODO confirm which they are).
            observer_positions: ``p.observations.observer_positions``.
            times: ``Time`` holding the particle epoch followed by every
                observation time.
            ra, dec: observed coordinates taken from ``p.observations``.
    """
    observations = p.observations
    obs_times = observations.times

    ephem = Ephemeris(
        earliest_time=Time(jnp.min(obs_times), format="jd", scale="tdb") - 10 * u.day,
        latest_time=Time(jnp.max(obs_times), format="jd", scale="tdb") + 10 * u.day,
        ssos="default solar system",
        de_ephemeris_version=de_ephemeris_version,
    )

    # precompute Chebyshev coefficients to compute the states of the perturbers
    # around each observation time for light travel time corrections
    perturber_pos_chebys, perturber_vel_chebys = [], []
    for t in obs_times:
        pos, vel = generate_perturber_chebyshev_coeffs(
            obs_time=Time(t, format="jd", scale="tdb"),
            ephem=ephem,
        )
        perturber_pos_chebys.append(pos)
        perturber_vel_chebys.append(vel)
    perturber_pos_chebys = jnp.array(perturber_pos_chebys)
    perturber_vel_chebys = jnp.array(perturber_vel_chebys)

    # NOTE(review): the -1.5 / +0.1 offsets presumably have to match the window
    # used inside generate_perturber_chebyshev_coeffs -- confirm before changing.
    cheby_info = {
        "perturber_position_cheby_coeffs": perturber_pos_chebys,
        "perturber_velocity_cheby_coeffs": perturber_vel_chebys,
        "cheby_t0": obs_times - 1.5,
        "cheby_t1": obs_times + 0.1,
    }

    # get the times of all the intermediate steps between observations; the
    # particle's own epoch is prepended so the integration starts from it
    times = Time(
        jnp.concatenate([jnp.array([p.keplerian_state.time]), obs_times]),
        format="jd",
        scale="tdb",
    )
    state = p.keplerian_state.to_system()
    dynamic_acc_func = create_default_ephemeris_acceleration_func(ephem.processor)
    a0_dynamic = dynamic_acc_func(state)
    integrator_init = initialize_ias15_integrator_state(a0_dynamic)
    # seed the first step size with the gap between the epoch and the first observation
    integrator_init.dt = jnp.diff(times.tdb.jd)[0]
    dts, inds = get_all_dynamic_intermediate_dts(
        initial_system_state=state,
        acceleration_func=dynamic_acc_func,
        times=times,
        initial_integrator_state=integrator_init,
    )

    # precompute the positions and velocities of all perturbers at each intermediate
    # step, and at each ias15 substep
    planet_pos, planet_vel, asteroid_pos, asteroid_vel, gms = (
        precompute_perturber_positions(
            t0=Time(p.cartesian_state.time, format="jd", scale="tdb"),
            dts=dts,
            de_ephemeris_version=de_ephemeris_version,
        )
    )
    perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
    perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

    observer_positions = observations.observer_positions

    # Read the observed coordinates from the particle that was passed in, not
    # from a notebook-level variable, so the result never depends on hidden state.
    return (
        cheby_info,
        dts,
        inds,
        perturber_pos,
        perturber_vel,
        gms,
        observer_positions,
        times,
        observations.ra,
        observations.dec,
    )


# Run the precomputation once. `x` is the packed tuple handed to the likelihood
# factory; the individual names are unpacked as well for interactive inspection.
x = precompute_likelihood_data(p)
(
    cheby_info,
    dts,
    inds,
    perturber_pos,
    perturber_vel,
    gms,
    observer_positions,
    times,
    obs_ras,
    obs_decs,
) = x

In [8]:
def create_default_likelihood_ephem(x):
    """Build a jitted function that maps a particle state to on-sky residuals.

    Args:
        x: the 10-tuple produced by ``precompute_likelihood_data``, i.e.
            ``(cheby_info, dts, inds, perturber_pos, perturber_vel, log_gms,
            observer_positions, obs_times, obs_ras, obs_decs)``.

    Returns:
        A ``jax.jit``-wrapped callable taking a single ``state``. It integrates
        the state through the precomputed steps, projects the result onto the
        sky at each observation, and returns the output of
        ``tangent_plane_projection`` for observed versus modelled coordinates.
    """
    (
        cheby_info,
        dts,
        inds,
        perturber_pos,
        perturber_vel,
        log_gms,
        observer_positions,
        obs_times,
        obs_ras,
        obs_decs,
    ) = x

    static_acc_func = create_static_default_acceleration_func()
    on_sky_acc_func = create_static_default_on_sky_acc_func()

    # The first entry of obs_times is the integration start, not an observation.
    times = obs_times[1:].tdb.jd

    def static_residuals_func(state):
        system = state.to_system()
        # The perturber data at index [0, 0] is what the initial acceleration sees.
        system.fixed_perturber_positions = perturber_pos[0, 0]
        system.fixed_perturber_velocities = perturber_vel[0, 0]
        system.fixed_perturber_log_gms = log_gms

        integrator_state = initialize_ias15_integrator_state(static_acc_func(system))
        integrator_state.dt = dts[0]

        # Only the positions and velocities are needed; the final system and
        # integrator states are discarded.
        static_x, static_v, _, _ = ias15_static_evolve(
            initial_system_state=system,
            acceleration_func=static_acc_func,
            dts=dts,
            initial_integrator_state=integrator_state,
            perturber_positions=perturber_pos,
            perturber_velocities=perturber_vel,
            perturber_log_gms=log_gms,
        )

        # Select particle 0 at the recorded indices, then drop the leading row,
        # which is the initial state that was tacked on to compute a0.
        obs_xs = static_x[inds, 0, :][1:]
        obs_vs = static_v[inds, 0, :][1:]

        # Map over observations; the acceleration function is shared (axis None).
        project = jax.vmap(on_sky, in_axes=(0, 0, 0, 0, None, 0))
        model_ras, model_decs = project(
            obs_xs,
            obs_vs,
            times,
            observer_positions,
            on_sky_acc_func,
            cheby_info,
        )

        return jax.vmap(tangent_plane_projection)(
            obs_ras, obs_decs, model_ras, model_decs
        )

    return jax.jit(jax.tree_util.Partial(static_residuals_func))


# Build the jitted residual function from the precomputed tuple, then evaluate it
# once at the particle's own Cartesian state (this first call also triggers the
# JIT compilation, and the result is displayed as the cell output).
func = create_default_likelihood_ephem(x)
func(p.cartesian_state)

Array([[ 5.95733862e-08, -1.11382298e-06],
       [ 1.16549637e-06, -1.91304082e-06],
       [-2.04976570e-07, -1.69924957e-06],
       [ 6.74289942e-07,  9.11332765e-07],
       [ 4.49783062e-07, -1.01510090e-06],
       [-3.05424314e-07, -9.40061223e-07],
       [-2.01414486e-06,  1.05237143e-06],
       [-1.03273046e-06, -1.41413110e-06],
       [-9.14661639e-07,  1.07643725e-07]], dtype=float64)

In [9]:
# Time one evaluation of the already-compiled residual function.
# block_until_ready() waits for the result so the timing covers the computation
# itself rather than just the asynchronous dispatch.
%timeit func(p.cartesian_state).block_until_ready()

22.3 ms ± 635 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Baseline: the particle's built-in log-likelihood at the same state.
# The untimed first call is a warm-up, presumably to keep any one-time
# compilation inside p.loglike out of the measurement (not visible here).
_ = p.loglike(p.cartesian_state)
%timeit p.loglike(p.cartesian_state).block_until_ready()

39.4 ms ± 452 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
# Baseline gradient: reverse-mode derivative of the built-in log-likelihood.
# Only the `.x` leaf of the returned gradient is waited on.
# NOTE(review): jax.grad(p.loglike) is rebuilt inside the timed expression and is
# not wrapped in jax.jit here, so this timing may include per-call transformation
# overhead -- confirm before comparing it against a pre-jitted derivative.
_ = jax.grad(p.loglike)(p.cartesian_state)
%timeit jax.grad(p.loglike)(p.cartesian_state).x.block_until_ready()

322 ms ± 4.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
# Forward-mode derivative of a sum-of-squares objective built on `func`.
# NOTE: this relies on the notebook-level `func` being the residual-returning
# version; once `func` is rebound to a function that returns a tuple of arrays,
# re-running this cell will fail on the `** 2`.
def _tmp(state):
    """Return the sum of the squared outputs of ``func`` evaluated at ``state``."""
    return jnp.sum(func(state)**2)
# jit the jacfwd transform once, run it untimed to compile, then time steady-state calls
tmp = jax.jit(jax.jacfwd(_tmp))
_ = tmp(p.cartesian_state)
%timeit tmp(p.cartesian_state).x.block_until_ready()

264 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
def create_default_likelihood_ephem(x):
    """Build a jitted function that integrates a state to the observation epochs.

    This variant stops after the integration and performs no sky projection, so
    it isolates the cost of the integrator.

    NOTE(review): this reuses the name of the residual-returning factory defined
    earlier in the notebook, so running this cell shadows that definition.
    Consider giving it a distinct name.

    Args:
        x: the 10-tuple produced by ``precompute_likelihood_data``. Only
            ``dts``, ``inds``, ``perturber_pos``, ``perturber_vel`` and the
            perturber log-GMs (positions 1-5) are used; the remaining entries
            are accepted for interface compatibility and ignored.

    Returns:
        A ``jax.jit``-wrapped callable taking a single ``state`` and returning
        ``(obs_xs, obs_vs)``: the integrated positions and velocities of
        particle 0 at the recorded indices, with the leading initial-state row
        removed.
    """
    (
        _cheby_info,
        dts,
        inds,
        perturber_pos,
        perturber_vel,
        log_gms,
        _observer_positions,
        _obs_times,
        _obs_ras,
        _obs_decs,
    ) = x

    static_acc_func = create_static_default_acceleration_func()

    def static_positions_func(state):
        state = state.to_system()
        # The perturber data at index [0, 0] is what the initial acceleration sees.
        state.fixed_perturber_positions = perturber_pos[0, 0]
        state.fixed_perturber_velocities = perturber_vel[0, 0]
        state.fixed_perturber_log_gms = log_gms
        a0 = static_acc_func(state)
        integrator_init = initialize_ias15_integrator_state(a0)
        integrator_init.dt = dts[0]

        # The final system and integrator states are not needed here.
        static_x, static_v, _, _ = ias15_static_evolve(
            initial_system_state=state,
            acceleration_func=static_acc_func,
            dts=dts,
            initial_integrator_state=integrator_init,
            perturber_positions=perturber_pos,
            perturber_velocities=perturber_vel,
            perturber_log_gms=log_gms,
        )

        # cut out the initial state, which we tacked on for a0
        obs_xs = static_x[inds, 0, :][1:]
        obs_vs = static_v[inds, 0, :][1:]

        return obs_xs, obs_vs

    return jax.jit(jax.tree_util.Partial(static_positions_func))


# Rebind `func` to the integration-only variant and evaluate it once (compiling it).
# From here on `func` returns a (positions, velocities) tuple rather than residuals.
func = create_default_likelihood_ephem(x)
func(p.cartesian_state)

(Array([[-2.00572346,  1.77860123,  0.51974075],
        [-2.00600093,  1.77832501,  0.51965579],
        [-2.00627835,  1.77804874,  0.51957082],
        [-1.9496022 , -1.39136776, -0.43223441],
        [-1.94927914, -1.39167602, -0.43232474],
        [-1.94895601, -1.39198424, -0.43241506],
        [ 1.85224854, -0.81520034, -0.23234489],
        [ 1.85248521, -0.81473829, -0.23220442],
        [ 1.85272175, -0.8142762 , -0.23206394]], dtype=float64),
 Array([[-0.00665991, -0.00662871, -0.00203885],
        [-0.00665869, -0.0066298 , -0.00203917],
        [-0.00665748, -0.00663088, -0.00203949],
        [ 0.00775252, -0.00739902, -0.00216809],
        [ 0.0077542 , -0.00739782, -0.00216772],
        [ 0.00775588, -0.00739663, -0.00216735],
        [ 0.00568123,  0.01108848,  0.00337115],
        [ 0.00567852,  0.01108966,  0.00337148],
        [ 0.00567581,  0.01109085,  0.00337182]], dtype=float64))

In [24]:
# Time the integration-only function; waiting on the first returned array (the
# positions) is used as the completion signal for the call.
%timeit func(p.cartesian_state)[0].block_until_ready()

8.57 ms ± 55.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [27]:
# Baseline: the particle's built-in integrate() to the same observation epochs,
# with obs.times wrapped as JD/TDB. The untimed first call is a warm-up; the timed
# expression also includes constructing the astropy Time object on every loop.
_ = p.integrate(Time(obs.times, format="jd", scale="tdb"))
%timeit p.integrate(Time(obs.times, format="jd", scale="tdb"))[0].block_until_ready()

27.7 ms ± 226 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
