In [1]:
# ruff: noqa
import jax

jax.config.update("jax_enable_x64", True)

import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astroquery.jplhorizons import Horizons

from jorbit import Ephemeris, Observations, Particle
from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_static_default_acceleration_func,
    create_static_default_on_sky_acc_func,
)
from jorbit.accelerations.static_helpers import (
    generate_perturber_chebyshev_coeffs,
    get_all_dynamic_intermediate_dts,
    precompute_perturber_positions,
)
from jorbit.astrometry.sky_projection import on_sky, tangent_plane_projection
from jorbit.integrators import ias15_static_evolve, initialize_ias15_integrator_state
from jorbit.utils.states import KeplerianState, SystemState

In [2]:
# Build fake astrometry for a single target by querying JPL Horizons, then fit
# a Particle to those observations.

TARGET_ID = "274301"

# Three nights; each night contributes samples at +0 h, +1 h and +2 h.
nights = [Time("2025-01-01 07:00"), Time("2026-01-02 07:00"), Time("2027-01-05 07:00")]
times = Time([night + offset * u.hour for night in nights for offset in range(3)])

# Query Horizons at those epochs (passed as UTC Julian dates).
obj = Horizons(id=TARGET_ID, location="695@399", epochs=times.utc.jd)
pts = obj.ephemerides(extra_precision=True, quantities="1")

# Re-read coordinates and epochs from the returned table so that the
# observations use exactly the values Horizons reported.
coords = SkyCoord(pts["RA"], pts["DEC"], unit=(u.deg, u.deg))
times = Time(pts["datetime_jd"], format="jd", scale="utc")

obs = Observations(
    times=times,
    observed_coordinates=coords,
    astrometric_uncertainties=1 * u.arcsec,
    observatories="kitt peak",
)

# First particle: fetched from Horizons at an earlier epoch, used only as the
# seed for the fit. Second particle: the one the rest of the notebook uses.
seed_particle = Particle.from_horizons(TARGET_ID, Time("2024-01-01"))
p = Particle.from_horizons(
    TARGET_ID,
    Time("2026-01-01"),
    observations=obs,
    fit_seed=seed_particle.keplerian_state,
)

In [3]:
# create a loglike wrapper that uses "static_residuals"
@jax.jit
def loglike(elements_vec):
    """Return the summed log-likelihood for a vector of six orbital elements.

    Args:
        elements_vec: Indexable of at least six entries, read in the order
            (semi, ecc, nu, inc, Omega, omega). Each entry is placed in a
            one-element list and converted with ``jnp.array`` before being
            handed to ``KeplerianState``.
            NOTE(review): units of the elements are not visible here; they
            must match whatever ``KeplerianState`` expects.

    Returns:
        The sum over the leading axis of
        ``-0.5 * (2*log(2*pi) + cov_log_dets + r^T inv_cov r)``, where ``r`` are
        the residuals from ``p.static_residuals``.

    Relies on the module-level particle ``p`` for the epoch (``p._time``), the
    residual function, and the observation covariances.
    """
    element_names = ("semi", "ecc", "nu", "inc", "Omega", "omega")
    elements = {
        name: jnp.array([elements_vec[idx]]) for idx, name in enumerate(element_names)
    }
    kep_state = KeplerianState(
        **elements,
        acceleration_func_kwargs={},
        time=p._time,
    )

    residuals = p.static_residuals(kep_state)
    observations = p.observations

    # Quadratic form r^T M r for every entry along the leading ("b") axis.
    chi2 = jnp.einsum(
        "bi,bij,bj->b", residuals, observations.inv_cov_matrices, residuals
    )

    log_norm = 2 * jnp.log(2 * jnp.pi)
    per_entry = -0.5 * (log_norm + observations.cov_log_dets + chi2)
    return jnp.sum(per_entry)

# Starting vector: the six elements of p's Keplerian state, stacked in the
# same order that loglike indexes them, with size-1 axes squeezed away.
x0 = jnp.array(
    [
        getattr(p.keplerian_state, name)
        for name in ("semi", "ecc", "nu", "inc", "Omega", "omega")
    ]
).squeeze()

# Call once up front and discard the result, so the first (compiling) call of
# the jit-decorated loglike happens here.
_ = loglike(x0)

In [4]:
# timing the new version of the loglike with the precomputation:
# now 2 ms
# loglike is jit-decorated and has already been called once on x0 right after
# x0 was built, so that first-call work is not part of this measurement.
# block_until_ready() waits for the result, so the timer covers the actual
# computation rather than only JAX's asynchronous dispatch.
%timeit loglike(x0).block_until_ready()

2.01 ms ± 9.11 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
# timing the new version of the gradient, reverse mode:
# now 12 ms
# The first line is an untimed warm-up call; its result is discarded.
# NOTE(review): jax.jacrev(loglike) is rebuilt inside the timed statement on
# every loop and is not itself wrapped in jax.jit, so the figure includes the
# cost of applying the transform, not only of evaluating the gradient.
_ = jax.jacrev(loglike)(x0)
%timeit jax.jacrev(loglike)(x0).block_until_ready()

11.9 ms ± 118 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
# timing the new version of the gradient, forward mode:
# now 7 ms
# The first line is an untimed warm-up call; its result is discarded.
# NOTE(review): jax.jacfwd(loglike) is rebuilt inside the timed statement on
# every loop and is not itself wrapped in jax.jit, so the figure includes the
# cost of applying the transform, not only of evaluating the gradient.
_ = jax.jacfwd(loglike)(x0)
%timeit jax.jacfwd(loglike)(x0).block_until_ready()

7.16 ms ± 28 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
