In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
import astropy.units as u

from jorbit import Particle, Ephemeris
from jorbit.accelerations import create_default_ephemeris_acceleration_func
from jorbit.integrators import initialize_ias15_integrator_state

from jorbit.accelerations.static_helpers import (
    precompute_perturber_positions,
    get_all_dynamic_intermediate_dts,
    get_fixed_intermediate_dts,
)

In [3]:
# Test particle (asteroid 274301) at the reference epoch, converted to a
# SystemState for the integrator.
t0 = Time("2026-01-01")
p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

# Placeholder fixed-perturber arrays (only their shapes matter here);
# presumably unused by the ephemeris-based acceleration function — TODO confirm.
N_PLACEHOLDER_PERTURBERS = 27
state.fixed_perturber_positions = jnp.empty((N_PLACEHOLDER_PERTURBERS, 3))
state.fixed_perturber_velocities = jnp.empty((N_PLACEHOLDER_PERTURBERS, 3))
state.fixed_perturber_log_gms = jnp.empty((N_PLACEHOLDER_PERTURBERS,))

eph = Ephemeris(ssos="default solar system")
acc_func = create_default_ephemeris_acceleration_func(eph.processor)
a0 = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0)

# Ten evenly spaced TDB epochs covering 1000 days starting at t0 (t0 included).
times = Time(t0.tdb.jd + jnp.linspace(0, 1000.0, 10), format="jd", scale="tdb")

In [5]:
dts, inds = get_fixed_intermediate_dts(
    t0=t0,
    times=times,
    max_step_size=10.0 * u.day,
)

# The cumulative fixed steps, sampled at `inds`, should land on the requested
# epochs. Tolerance 1e-8 day (~1 ms) leaves headroom over the float64 spacing
# of JD values near 2.46e6, which is 2**-31 ≈ 4.66e-10 day.
fixed_residuals = (t0.tdb.jd + jnp.cumsum(dts))[inds] - times.tdb.jd
assert jnp.allclose(fixed_residuals, 0.0, rtol=0.0, atol=1e-8), fixed_residuals
fixed_residuals

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float64)

In [5]:
dts, inds = get_all_dynamic_intermediate_dts(
    initial_system_state=state,
    acceleration_func=acc_func,
    times=times,
    initial_integrator_state=integrator_init,
)

# The adaptive steps should also land on the requested epochs. Residuals of a
# few multiples of 4.66e-10 day are expected: that is the float64 spacing of JD
# values near 2.46e6 (2**-31 day), so the cumulative sum can only match to ~ulp.
dynamic_residuals = (t0.tdb.jd + jnp.cumsum(dts))[inds] - times.tdb.jd
assert jnp.allclose(dynamic_residuals, 0.0, rtol=0.0, atol=1e-8), dynamic_residuals
dynamic_residuals

Array([0.00000000e+00, 4.65661287e-10, 4.65661287e-10, 9.31322575e-10,
       4.65661287e-10, 4.65661287e-10, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 4.65661287e-10], dtype=float64)

In [15]:
# Precompute perturber states for the step sequence `dts` using DE440.
(
    planet_pos,
    planet_vel,
    asteroid_pos,
    asteroid_vel,
    gms,
) = precompute_perturber_positions(
    t0=t0,
    dts=dts,
    de_ephemeris_version="440",
)

# Compare array shapes against the number of steps and requested epochs.
(planet_pos.shape, dts.shape, times.shape, dts[inds].shape)

((36, 9, 11, 3), (36,), (10,), (10,))

In [1]:
# ruff: noqa
import jax
jax.config.update("jax_enable_x64", True)
import copy
import warnings
from typing import Callable

import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from jorbit import Particle
from jorbit.accelerations import create_static_default_acceleration_func, create_default_ephemeris_acceleration_func
from jorbit.accelerations.static_helpers import precompute_perturber_positions
from jorbit.integrators import (
    ias15_evolve,
    ias15_step,
    ias15_static_evolve,
    ias15_static_step,
    initialize_ias15_integrator_state,
)
from jorbit.utils.states import IAS15IntegratorState, SystemState
from jorbit import Ephemeris

In [2]:
# Reference epoch and a 10-year horizon.
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

# Placeholder fixed-perturber arrays (only their shapes matter here).
N_PLACEHOLDER_PERTURBERS = 27
state.fixed_perturber_positions = jnp.empty((N_PLACEHOLDER_PERTURBERS, 3))
state.fixed_perturber_velocities = jnp.empty((N_PLACEHOLDER_PERTURBERS, 3))
state.fixed_perturber_log_gms = jnp.empty((N_PLACEHOLDER_PERTURBERS,))

eph = Ephemeris(ssos="default solar system")
acc_func = create_default_ephemeris_acceleration_func(eph.processor)
a0 = acc_func(state)

# Zero the placeholders after computing a0, so any later code that wrongly
# reads them would produce obviously different results.
state.fixed_perturber_positions *= 0.0
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

In [14]:
def get_intermediate_dts(
    initial_system_state: "SystemState",
    acceleration_func: Callable,
    final_time: float,
    initial_integrator_state: "IAS15IntegratorState",
    max_iterations: int = 10_000,
) -> "tuple[jnp.ndarray, SystemState, IAS15IntegratorState]":
    """Step IAS15 adaptively up to ``final_time``, recording each step size.

    Each step uses the integrator's suggested ``dt``, clipped so that the last
    step lands exactly on ``final_time``. This runs eagerly as a Python loop
    (it is not jitted).

    Args:
        initial_system_state: State to start from. Its ``time`` is compared
            directly with ``final_time`` (same time scale/units assumed).
        acceleration_func: Acceleration function forwarded to ``ias15_step``.
        final_time: Target time to integrate to.
        initial_integrator_state: IAS15 state to start from. It is NOT
            modified; the clipped ``dt`` is set on a shallow copy.
        max_iterations: Safety cap on the number of steps. A warning is
            issued if the cap is reached before ``final_time``.

    Returns:
        ``(dts, final_system_state, final_integrator_state)``. ``dts`` is a 1-D
        array of the nonzero ``dt_last_done`` values, one per step. It is empty
        if the state is already at ``final_time``.
    """

    def _clipped_step(time, suggested_dt):
        # Signed step toward final_time, no larger in magnitude than the
        # integrator's suggestion; exactly zero once final_time is reached.
        remaining = final_time - time
        return jnp.sign(remaining) * jnp.minimum(
            jnp.abs(remaining), jnp.abs(suggested_dt)
        )

    system_state = initial_system_state
    integrator_state = initial_integrator_state
    dts = []
    n_steps = 0
    while n_steps < max_iterations:
        step_length = _clipped_step(system_state.time, integrator_state.dt)
        if step_length == 0:
            break
        # Copy before overriding dt so the caller's integrator state (e.g. a
        # notebook-level `integrator_init`) is never mutated in place.
        integrator_state = copy.copy(integrator_state)
        integrator_state.dt = step_length
        system_state, integrator_state = ias15_step(
            system_state, acceleration_func, integrator_state
        )
        n_steps += 1
        if integrator_state.dt_last_done != 0:
            dts.append(integrator_state.dt_last_done)

    if (
        n_steps >= max_iterations
        and _clipped_step(system_state.time, integrator_state.dt) != 0
    ):
        warnings.warn(
            f"get_intermediate_dts stopped after {max_iterations} steps "
            f"before reaching final_time={final_time}"
        )

    return jnp.array(dts), system_state, integrator_state

# Single-segment demo: integrate 300 days past t0 and inspect the step sizes.
dts, final_system_state, final_integrator_state = get_intermediate_dts(
    initial_system_state=state,
    acceleration_func=acc_func,
    final_time=t0.tdb.jd + 300,
    initial_integrator_state=integrator_init,
)
dts

Array([26.19087282, 35.44901396, 33.52251236, 31.85992714, 30.43838936,
       29.23612411, 28.23647559, 27.4248067 , 26.78855089, 26.320351  ,
        4.53297605], dtype=float64)

In [15]:
def get_all_intermediate_dts(
    initial_system_state: "SystemState",
    acceleration_func: Callable,
    times: jnp.ndarray,
    initial_integrator_state: "IAS15IntegratorState",
) -> jnp.ndarray:
    """Chain ``get_intermediate_dts`` through successive target times.

    Each segment starts from the system and integrator states returned by the
    previous segment. The integrator's step-size adaptation therefore carries
    across epochs.

    Args:
        initial_system_state: State at the start of the first segment.
        acceleration_func: Acceleration function forwarded to each segment.
        times: Iterable of target times, visited in order.
        initial_integrator_state: IAS15 state for the first segment.

    Returns:
        1-D array of every recorded step size, concatenated across segments.
        It is empty if ``times`` is empty or no steps were needed.
    """
    system_state = initial_system_state
    integrator_state = initial_integrator_state
    segment_dts = []

    for final_time in times:
        dts, system_state, integrator_state = get_intermediate_dts(
            system_state,
            acceleration_func,
            final_time,
            integrator_state,
        )
        segment_dts.append(dts)

    if not segment_dts:
        # jnp.concatenate raises on an empty sequence; no epochs means no steps.
        return jnp.array([])
    return jnp.concatenate(segment_dts)

# Five evenly spaced epochs from t0 to tf. t0 is included, so the first segment
# presumably takes no steps.
checkpoint_jds = jnp.linspace(t0.tdb.jd, tf.tdb.jd, 5)
get_all_intermediate_dts(
    initial_system_state=state,
    acceleration_func=acc_func,
    times=checkpoint_jds,
    initial_integrator_state=integrator_init,
)

Array([26.19087282, 35.44901396, 33.52251236, 31.85992714, 30.43838936,
       29.23612411, 28.23647559, 27.4248067 , 26.78855089, 26.320351  ,
       26.01232189, 25.86052816, 25.86219564, 26.01586526, 26.32180515,
       26.78160503, 27.39900384, 28.17934851, 29.12998971, 30.26325011,
       31.58966912, 33.12514576, 34.8866567 , 36.890833  , 39.15217002,
       41.67933981, 44.45172289, 47.41002607, 36.52149939, 52.34471163,
       54.27560599, 54.66796893, 53.27009371, 50.62580343, 47.48405788,
       44.34356027, 41.41971933, 38.7873276 , 36.46150786, 34.4250948 ,
       32.65851814, 31.13948817, 29.84678871, 28.76162087, 27.86865255,
       27.15464096, 26.61097424, 26.23033682, 26.00651885, 25.93624375,
       26.01808096, 26.25215878, 26.63812909, 27.17908986, 16.59330681,
       28.37824292, 29.34829752, 30.50116343, 31.8481963 , 33.40422768,
       35.18604454, 37.21013902, 39.48984392, 42.02767069, 44.80290964,
       47.74249863, 50.6543503 , 53.1365438 , 54.53382954, 54.23

In [None]:
# NOTE(review): this cell was never executed (In [None]). The same setup is
# repeated at the start of the next cell; consider deleting this one.
# NOTE(review): this precompute_perturber_positions call (times=, max_step_size=,
# seven return values) differs from the dts=-based call earlier in the notebook
# (five return values). One of the two is presumably stale — confirm against
# the current jorbit API.
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()
# Reference result from integrating the particle directly to tf (presumably
# position/velocity).
x, v = p.integrate(tf)

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=25,  # NOTE(review): units not stated here — presumably days; confirm
    de_ephemeris_version="440"
)

# Combine planet and asteroid perturbers along axis 2.
perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

# Seed the fixed-perturber fields with the [0, 0] precomputed snapshot so the
# initial acceleration a0 can be evaluated.
state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()
a0 = acc_func(state)
state.fixed_perturber_positions *= 0.0 # just making sure these initial values aren't being used
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

In [2]:
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

# Reference solution: integrate the particle directly from t0 to tf.
p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()
x, v = p.integrate(tf)

# Precompute perturber states along a step grid between t0 and tf.
# NOTE(review): this call signature (times=, max_step_size=, seven return values)
# differs from the dts=-based precompute_perturber_positions call earlier in the
# notebook (five return values). One of them is presumably stale — confirm
# against the current jorbit API. max_step_size units presumably days — confirm.
planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=25,
    de_ephemeris_version="440"
)

# Combine planet and asteroid perturbers into one array along axis 2.
perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

# Seed the fixed-perturber fields with the [0, 0] snapshot so the initial
# acceleration a0 can be evaluated. Then zero them, so any code that wrongly
# reads these values would produce visibly different results.
state.fixed_perturber_positions = perturber_pos[0, 0]
state.fixed_perturber_velocities = perturber_vel[0, 0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()
a0 = acc_func(state)
state.fixed_perturber_positions *= 0.0
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

# Integrate with the precomputed perturbers over the same step grid.
positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)

# The static integration should reproduce the reference position at tf.
# Observed residuals were ~5e-12 in the integrator's position units.
static_residual = positions[-1] - x
assert jnp.allclose(static_residual, 0.0, rtol=0.0, atol=1e-10), static_residual
static_residual

Array([[4.21973567e-12, 4.91118257e-12, 1.43129952e-12]], dtype=float64)

In [3]:
%%timeit
# Benchmark the static (precomputed-perturber) IAS15 evolution. The function was
# already called in the previous cell, so JIT compilation is presumably not
# included in these timings. block_until_ready() waits for JAX's asynchronously
# dispatched computation to finish, so the measured time covers the actual work.
positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)
positions.block_until_ready()

35.4 ms ± 523 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
%%timeit
# Baseline for comparison: the standard Particle.integrate path over the same
# t0 -> tf span (presumably already compiled by its earlier call).
x, v = p.integrate(tf)
x.block_until_ready()

84.8 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
