In [1]:
# ruff: noqa
import jax
jax.config.update("jax_enable_x64", True)
from typing import Callable

import astropy.constants as const
import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from scipy.special import eval_chebyt
import matplotlib.pyplot as plt

from jorbit import Ephemeris, Particle
from jorbit.accelerations import static_ppn_gravity, newtonian_gravity, create_default_ephemeris_acceleration_func
from jorbit.utils.states import SystemState

In [2]:
def generate_perturber_chebyshev_coeffs(
    obs_time: "Time",
    ephem: "Ephemeris",
    window_days: tuple[float, float] = (-1.5, 0.1),
    n_samples: int = 1000,
    deg: int = 8,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Fit local Chebyshev series to every perturber's position and velocity.

    Usually we can just query an ephemeris for a set of coefficients, but here we want
    to cache as much information as possible to make repeated evaluations fast. Since we
    don't know the light travel time correction a priori, we still need a continuous
    function approximating each perturber's state over a small time window. Slicing the
    relevant chunk of Chebyshev coefficients out of the ephemeris is not enough, because
    the correction may straddle two piecewise chunks, so instead we fit a fresh set of
    Chebyshev polynomials over a small window around the observation time.

    Args:
        obs_time: Observation epoch. Only ``obs_time.tdb.jd`` is read.
        ephem: Ephemeris whose ``processor.state`` is sampled (vmapped over TDB Julian
            dates) to obtain perturber positions and velocities.
        window_days: ``(start, end)`` offsets, in days, of the fit window relative to
            ``obs_time.tdb.jd``. The default spans 1.5 days before to 0.1 day after the
            epoch, enough to cover light travel times out to roughly 250 AU. Whoever
            evaluates the series must map time onto ``[-1, 1]`` with
            ``cheby_t0 = obs_time.tdb.jd + start`` and ``cheby_t1 = obs_time.tdb.jd + end``.
        n_samples: Number of equally spaced samples used for the least-squares fit.
        deg: Degree of the Chebyshev series (``deg + 1`` coefficients per component).

    Returns:
        ``(position_coeffs, velocity_coeffs)``, each of shape
        ``(deg + 1, n_perturbers, 3)``. Along axis 0 the highest-degree coefficient comes
        first and the constant term last, i.e. the order consumed by a Clenshaw
        recurrence that scans from high to low degree.

    Raises:
        ValueError: If ``start >= end`` or if ``n_samples <= deg`` (under-determined fit).
    """
    start, end = window_days
    if not end > start:
        raise ValueError("window_days must satisfy start < end")
    if n_samples <= deg:
        raise ValueError("n_samples must exceed deg for a well-posed least-squares fit")

    # Map the small day *offsets* onto [-1, 1] rather than differencing two large Julian
    # dates, which keeps x accurate and makes the window exactly
    # [obs_time.tdb.jd + start, obs_time.tdb.jd + end].
    offsets = jnp.linspace(start, end, n_samples)
    x = 2 * (offsets - start) / (end - start) - 1
    times_jd = obs_time.tdb.jd + offsets

    positions, velocities = jax.vmap(ephem.processor.state)(times_jd)

    # The design matrix depends only on x: build it once, not once per component.
    design = jnp.column_stack([eval_chebyt(k, x) for k in range(deg + 1)])

    def fit(samples: jnp.ndarray) -> jnp.ndarray:
        # samples: (n_samples, n_perturbers, 3). Solve every component in a single
        # multi-right-hand-side least-squares call.
        flat = samples.reshape(samples.shape[0], -1)
        coeffs, *_ = jnp.linalg.lstsq(design, flat, rcond=None)
        # (deg + 1, n_perturbers * 3) -> (deg + 1, n_perturbers, 3), highest degree first.
        return coeffs.reshape((deg + 1,) + samples.shape[1:])[::-1]

    return fit(positions), fit(velocities)

In [3]:
def create_static_default_on_sky_acc_function(num_gr_perturbers: int = 11):
    """Build an acceleration function that reads perturber states from a local Chebyshev fit.

    Alternative to ``create_default_ephemeris_acceleration_func`` for repeated
    evaluations near one epoch. Instead of querying an ephemeris, perturber positions
    and velocities are reconstructed from Chebyshev coefficients stored in
    ``inputs.acceleration_func_kwargs`` under the keys
    ``"perturber_position_cheby_coeffs"``, ``"perturber_velocity_cheby_coeffs"``,
    ``"cheby_t0"`` and ``"cheby_t1"`` (e.g. from ``generate_perturber_chebyshev_coeffs``).
    Coefficient arrays are indexed ``(coefficient, perturber, xyz)`` with the
    highest-degree coefficient first and the constant term last.

    No check is made that ``inputs.time`` lies inside ``[cheby_t0, cheby_t1]``; outside
    that window the series silently extrapolates.

    Args:
        num_gr_perturbers: Number of leading perturbers (the "planets", including the
            sun, moon and pluto) handled with PPN gravity; the rest (the asteroids) are
            handled with Newtonian gravity.

    Returns:
        A ``jax.tree_util.Partial`` mapping a ``SystemState`` to accelerations.
    """

    def _eval_cheby(coefficients: jnp.ndarray, x) -> jnp.ndarray:
        # Clenshaw recurrence for one series: b_k = a_k + 2x b_{k+1} - b_{k+2}, scanned
        # from the highest degree down; the constant term is folded in at the end.
        def step(carry: tuple, a: jnp.ndarray) -> tuple:
            b_i, b_ii = carry
            return (a + 2 * x * b_i - b_ii, b_i), None

        (b_i, b_ii), _ = jax.lax.scan(step, (0.0, 0.0), coefficients[:-1])
        return coefficients[-1] + x * b_i - b_ii

    # Outer vmap walks the perturber axis (1) of (n_coeffs, n_perturbers, 3); the inner
    # one then walks the xyz axis of each (n_coeffs, 3) slice. Result: (n_perturbers, 3).
    _eval_all = jax.vmap(jax.vmap(_eval_cheby, in_axes=(1, None)), in_axes=(1, None))

    def static_on_sky_acc(inputs):
        kwargs = inputs.acceleration_func_kwargs
        cheby_t0 = kwargs["cheby_t0"]
        cheby_t1 = kwargs["cheby_t1"]

        # Map the current time onto the Chebyshev domain [-1, 1].
        x = 2 * (inputs.time - cheby_t0) / (cheby_t1 - cheby_t0) - 1

        perturber_xs = _eval_all(kwargs["perturber_position_cheby_coeffs"], x)
        perturber_vs = _eval_all(kwargs["perturber_velocity_cheby_coeffs"], x)
        perturber_log_gms = inputs.fixed_perturber_log_gms

        # GR pass: the leading perturbers plus all massive particles. inputs.log_gms
        # presumably holds one entry per massive particle (perturber GMs are carried
        # separately in fixed_perturber_log_gms), so it is passed through unsliced.
        gr_state = SystemState(
            massive_positions=inputs.massive_positions,
            massive_velocities=inputs.massive_velocities,
            tracer_positions=inputs.tracer_positions,
            tracer_velocities=inputs.tracer_velocities,
            log_gms=inputs.log_gms,
            time=inputs.time,
            fixed_perturber_positions=perturber_xs[:num_gr_perturbers],
            fixed_perturber_velocities=perturber_vs[:num_gr_perturbers],
            fixed_perturber_log_gms=perturber_log_gms[:num_gr_perturbers],
            acceleration_func_kwargs=inputs.acceleration_func_kwargs,
        )
        gr_acc = static_ppn_gravity(gr_state)

        # Newtonian pass: only the remaining perturbers act. Massive particles are fed in
        # as tracers (no mutual terms) so their interactions, already included in the GR
        # pass, are not counted twice.
        # NOTE(review): assumes both gravity functions return rows ordered massive-first,
        # then tracers, so the two results line up -- confirm against jorbit.
        newtonian_state = SystemState(
            massive_positions=jnp.empty((0, 3)),
            massive_velocities=jnp.empty((0, 3)),
            tracer_positions=jnp.concatenate(
                [inputs.massive_positions, inputs.tracer_positions]
            ),
            tracer_velocities=jnp.concatenate(
                [inputs.massive_velocities, inputs.tracer_velocities]
            ),
            log_gms=jnp.empty((0,)),
            time=inputs.time,
            fixed_perturber_positions=perturber_xs[num_gr_perturbers:],
            fixed_perturber_velocities=perturber_vs[num_gr_perturbers:],
            fixed_perturber_log_gms=perturber_log_gms[num_gr_perturbers:],
            acceleration_func_kwargs=inputs.acceleration_func_kwargs,
        )
        newtonian_acc = newtonian_gravity(newtonian_state)

        return gr_acc + newtonian_acc

    return jax.tree_util.Partial(static_on_sky_acc)

In [4]:
# Offsets (days) of the Chebyshev fit window relative to the observation epoch.
# These must match the window used inside generate_perturber_chebyshev_coeffs
# (currently -1.5 to +0.1 days); otherwise times are mapped onto the wrong part of
# the fitted series.
CHEBY_WINDOW_START_DAYS = -1.5
CHEBY_WINDOW_END_DAYS = 0.1

t0 = Time("2026-01-01")
p = Particle.from_horizons("274301", t0)
state = p.cartesian_state.to_system()

eph = Ephemeris(ssos="default solar system")
x_coeffs, v_coeffs = generate_perturber_chebyshev_coeffs(t0, eph)

# NOTE(review): this mutates the kwargs dict held by `state` in place; if to_system()
# shares a default dict between states, these entries would leak -- confirm.
state.fixed_perturber_log_gms = eph.processor.log_gms
state.acceleration_func_kwargs.update(
    {
        "perturber_position_cheby_coeffs": x_coeffs,
        "perturber_velocity_cheby_coeffs": v_coeffs,
        "cheby_t0": t0.tdb.jd + CHEBY_WINDOW_START_DAYS,
        "cheby_t1": t0.tdb.jd + CHEBY_WINDOW_END_DAYS,
    }
)

acc_func = create_static_default_on_sky_acc_function()
a_new = acc_func(state)
a_new

Array([[4.03408192e-05, 2.83907818e-05, 8.81375997e-06]], dtype=float64)

In [5]:
# Reference result from jorbit's stock acceleration function, built directly from the
# ephemeris processor (presumably querying it at each call) rather than from the local
# Chebyshev fit; compared against `a_new` below.
acc_func_old = create_default_ephemeris_acceleration_func(eph.processor)
a_old = acc_func_old(state)
a_old

Array([[4.03408192e-05, 2.83907818e-05, 8.81375997e-06]], dtype=float64)

In [6]:
# The Chebyshev-based accelerations should agree with the reference to near float64
# round-off; assert it so a Restart & Run All actually checks the result instead of
# relying on eyeballing the displayed difference.
acc_diff = a_old - a_new
assert jnp.allclose(a_new, a_old, rtol=1e-10, atol=1e-16), acc_diff
acc_diff

Array([[-8.13151629e-20, -9.82558219e-20, -3.04931861e-20]], dtype=float64)

In [7]:
# Reuse `state` and `eph` from the setup cell above instead of re-querying Horizons and
# rebuilding the ephemeris; the cells in between only evaluate acceleration functions.
old_acc_func = jax.jit(create_default_ephemeris_acceleration_func(eph.processor))
new_acc_func = jax.jit(create_static_default_on_sky_acc_function())

# Warm up: trigger compilation and wait for the first results so the timings below
# exclude tracing/compile time.
old_acc_func(state).block_until_ready()
new_acc_func(state).block_until_ready();

In [8]:
%timeit old_acc_func(state)

73.2 μs ± 5.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
%timeit new_acc_func(state)

22.8 μs ± 494 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
