In [1]:
#ruff: noqa
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
from astropy.coordinates import SkyCoord
import astropy.units as u

from jorbit import Particle
from jorbit.accelerations.static_helpers import precompute_perturber_positions
from jorbit.integrators import ias15_evolve, ias15_static_evolve, initialize_ias15_integrator_state, ias15_static_step
from jorbit.accelerations import create_static_default_acceleration_func

In [None]:
# Start and end epochs of the integration.
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

# Build the test particle at t0 (presumably fetched from JPL Horizons, going by
# the constructor name) and integrate it to tf with the particle's own
# integrator; (x, v) is the reference result compared against below.
p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()
x, v = p.integrate(tf)

# Precompute the perturber tables for the fixed-step integrator, along with the
# time grid (`all_times`) the steps are taken on.
planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=10,
    de_ephemeris_version="440"
)

# Merge the planet and asteroid tables along axis 2 (presumably the perturber
# axis -- TODO confirm) so the integrator receives a single array of each kind.
perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

# Seed the state with the first entry of the tables so the initial acceleration
# a0 can be evaluated.
state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms

# Order matters here: a0 has to be computed BEFORE the perturber fields are
# zeroed on the next lines.
acc_func = create_static_default_acceleration_func()
a0 = acc_func(state)
state.fixed_perturber_positions *= 0.0 # just making sure these initial values aren't being used
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

# Fixed-step integration over the precomputed grid; the step sizes are the
# differences between consecutive entries of `all_times`.
positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)

# Difference between the final fixed-step position and the reference position;
# displayed as the cell output.
positions[-1] - x

Array([[4.14424051e-12, 4.81770179e-12, 1.40243372e-12]], dtype=float64)

In [3]:
%%timeit
# Benchmark of the fixed-step integration with the inputs prepared in the setup
# cell. block_until_ready() makes the timing wait for the computation to finish
# rather than stopping when the asynchronous call returns.
positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)
positions.block_until_ready()

89.5 ms ± 2.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
%%timeit
# Baseline: the particle's own integration to tf, for comparison with the
# fixed-step timing. block_until_ready() waits for the result before the timer
# stops.
x, v = p.integrate(tf)
x.block_until_ready()

83 ms ± 2.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


np.float64(2461041.50080074)

In [30]:
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import KeplerianState

# Reference epoch as a TDB Julian date. NOTE: this reads the `t0` left in scope
# by an earlier cell; `t0` is only re-created further down in this cell.
t0_num = t0.tdb.jd
def grad_wrapper(a,e,nu,Omega,inc,omega):
    """Final fixed-step position as a function of the Keplerian elements.

    Builds a ``KeplerianState`` at ``t0_num`` from the six elements, converts it
    with ``to_system()``, and integrates it with ``ias15_static_evolve`` using
    the notebook-level ``acc_func``, ``all_times``, ``perturber_pos``,
    ``perturber_vel``, ``gms`` and ``integrator_init``.

    Args:
        a: Value passed to ``KeplerianState`` as ``semi``.
        e: Value passed as ``ecc``.
        nu: Value passed as ``nu``.
        Omega: Value passed as ``Omega``.
        inc: Value passed as ``inc``.
        omega: Value passed as ``omega``.

    Returns:
        ``positions[-1]``: the last entry of the positions returned by
        ``ias15_static_evolve``.
    """
    s = KeplerianState(
        semi=a,
        ecc=e,
        nu=nu,
        Omega=Omega,
        inc=inc,
        omega=omega,
        acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
        time=t0_num,
    )
    system_state = s.to_system()

    # Placeholder perturber fields. The real tables are handed to the integrator
    # explicitly below, so these presumably only need the right shapes. The
    # shapes are taken from the tables themselves (perturber_pos[0, 0] has shape
    # perturber_pos.shape[2:]) instead of hard-coding the number of perturbers.
    # jnp.zeros produces the same values jnp.empty does in JAX.
    system_state.fixed_perturber_positions = jnp.zeros(perturber_pos.shape[2:])
    system_state.fixed_perturber_velocities = jnp.zeros(perturber_vel.shape[2:])
    system_state.fixed_perturber_log_gms = jnp.zeros(jnp.shape(gms))

    positions, _, _, _ = ias15_static_evolve(
        system_state,
        acc_func,
        jnp.diff(all_times),
        perturber_pos,
        perturber_vel,
        gms,
        integrator_init,
    )
    return positions[-1]

# NOTE(review): everything from here down to the `fixed_perturber_log_gms`
# assignment repeats the setup of the first integration cell verbatim, so a
# Run-All rebuilds the particle and the perturber tables a second time.
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
# `state` is rebuilt here, but the Jacobian below never reads it:
# grad_wrapper constructs its own state from the Keplerian elements.
state = p.keplerian_state.to_system()

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=10,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms
k = p.keplerian_state

# jax.jacfwd differentiates with respect to argnums=0 by default, so this is the
# forward-mode derivative of the final position with respect to the first
# argument only (the semi-major axis), not all six elements.
# The first call also triggers jit compilation.
j1 = jax.jit(jax.jacfwd(grad_wrapper))
j1(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega)

Array([[[15.19935031],
        [15.47540925],
        [ 4.75566708]]], dtype=float64)

In [31]:
# Time the jitted forward-mode Jacobian of the fixed-step integration. j1 was
# already called once in the previous cell, so compilation is presumably not
# part of this measurement.
%timeit j1(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega).block_until_ready()

840 ms ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import KeplerianState

# Reference epoch as a TDB Julian date. NOTE: this reads the `t0` left in scope
# by an earlier cell; `t0` is only re-created further down in this cell.
t0_num = t0.tdb.jd
def grad_wrapper(a,e,nu,Omega,inc,omega):
    """Position from ``ias15_evolve`` as a function of the Keplerian elements.

    Same construction as the fixed-step wrapper defined in the earlier cell
    (this definition replaces it), but the integration is done by
    ``ias15_evolve``, which is only given the state, ``acc_func``, the target
    time ``tf`` and ``integrator_init``; no perturber tables are passed in.

    Args:
        a: Value passed to ``KeplerianState`` as ``semi``.
        e: Value passed as ``ecc``.
        nu: Value passed as ``nu``.
        Omega: Value passed as ``Omega``.
        inc: Value passed as ``inc``.
        omega: Value passed as ``omega``.

    Returns:
        The full ``positions`` array returned by ``ias15_evolve`` (not just its
        last entry).
    """
    s = KeplerianState(
        semi=a,
        ecc=e,
        nu=nu,
        Omega=Omega,
        inc=inc,
        omega=omega,
        acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
        time=t0_num,
    )
    system_state = s.to_system()

    # Placeholder perturber fields, shaped from the precomputed tables rather
    # than a hard-coded perturber count. jnp.zeros produces the same values
    # jnp.empty does in JAX.
    # NOTE(review): the Jacobian recorded for this cell is all-NaN. Unlike the
    # fixed-step path, nothing else supplies perturber data to this call, so
    # `acc_func` presumably evaluates these zero-filled fields directly --
    # TODO confirm whether that is the source of the NaNs.
    system_state.fixed_perturber_positions = jnp.zeros(perturber_pos.shape[2:])
    system_state.fixed_perturber_velocities = jnp.zeros(perturber_vel.shape[2:])
    system_state.fixed_perturber_log_gms = jnp.zeros(jnp.shape(gms))

    positions, _, _, _ = ias15_evolve(
        system_state,
        acc_func,
        jnp.array([tf.tdb.jd]),
        integrator_init,
    )
    return positions

# NOTE(review): everything from here down to the `fixed_perturber_log_gms`
# assignment repeats the setup of the first integration cell verbatim, so a
# Run-All rebuilds the particle and the perturber tables yet again.
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
# `state` is rebuilt here, but the Jacobian below never reads it:
# grad_wrapper constructs its own state from the Keplerian elements.
state = p.keplerian_state.to_system()

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=10,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms
k = p.keplerian_state

# jax.jacfwd differentiates with respect to argnums=0 by default, so this is the
# derivative with respect to the first argument only (the semi-major axis).
# NOTE(review): the recorded output of this call is all-NaN, so the timing
# taken in the next cell measures a computation whose result is not yet valid.
j2 = jax.jit(jax.jacfwd(grad_wrapper))
j2(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega)

Array([[[[nan],
         [nan],
         [nan]]]], dtype=float64)

In [35]:
# Time the jitted forward-mode Jacobian of the ias15_evolve path. j2 was already
# called once in the previous cell, so compilation is presumably not part of
# this measurement.
%timeit j2(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega).block_until_ready()

1.36 s ± 9.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
