In [1]:
#ruff: noqa
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
from astropy.coordinates import SkyCoord
import astropy.units as u

from jorbit import Particle
from jorbit.accelerations.static_helpers import precompute_perturber_positions
from jorbit.integrators import ias15_evolve, ias15_static_evolve, initialize_ias15_integrator_state, ias15_static_step
from jorbit.accelerations import create_static_default_acceleration_func

In [None]:
# get dts from a standard ias15_evolve run, use those in static_helpers
# then, create a tiny function around each observation time to handle continuous-time
# integrations for light travel time

In [23]:
@jax.jit
def test(x):
    def cond_func(x):
        return x < 1_000

    def body_func(x):
        return x + 1

    x = jax.lax.while_loop(cond_func, body_func, x)
    return x

test(0)

%timeit test(0).block_until_ready()

31 μs ± 1.46 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [24]:
%timeit jax.jacfwd(test)(0.0).block_until_ready()

716 μs ± 30.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [25]:
n = 10_000
@jax.jit
def test(x):
    def take_step(x):
        return x + 1

    def do_nothing(x):
        return x

    def cond_func(carry):
        x = carry
        return x < 1_000

    def scan_func(carry, scan_over):
        x = carry
        x = jax.lax.cond(cond_func(x), take_step, do_nothing, x)
        return x, None

    x, _ = jax.lax.scan(scan_func, x, None, length=n)
    return x
test(0)
%timeit test(0).block_until_ready()

443 μs ± 8.57 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [26]:
%timeit jax.jacfwd(test)(0.0).block_until_ready()

1.17 ms ± 21.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [2]:
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()
x, v = p.integrate(tf)

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=25,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()
a0 = acc_func(state)
state.fixed_perturber_positions *= 0.0 # just making sure these initial values aren't being used
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)

positions[-1] - x

Array([[4.21973567e-12, 4.91118257e-12, 1.43129952e-12]], dtype=float64)

In [3]:
%%timeit
positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)
positions.block_until_ready()

35.4 ms ± 523 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
%%timeit
x, v = p.integrate(tf)
x.block_until_ready()

84.8 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import KeplerianState

t0_num = t0.tdb.jd
def grad_wrapper(a,e,nu,Omega,inc,omega):
    s = KeplerianState(
        semi=a,
        ecc=e,
        nu=nu,
        Omega=Omega,
        inc=inc,
        omega=omega,
        acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
        time=t0_num,
    )
    state = s.to_system()
    state.fixed_perturber_positions = jnp.empty((27,3))
    state.fixed_perturber_velocities = jnp.empty((27,3))
    state.fixed_perturber_log_gms = jnp.empty((27,))

    positions, velocities, _, _ = ias15_static_evolve(
        state,
        acc_func,
        jnp.diff(all_times),
        perturber_pos,
        perturber_vel,
        gms,
        integrator_init,
    )
    return positions[-1]

t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=25,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms
k = p.keplerian_state

j1 = jax.jit(jax.jacfwd(grad_wrapper))
j1(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega)

Array([[[15.20479439],
        [15.48402737],
        [ 4.75829435]]], dtype=float64)

In [6]:
%timeit j1(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega).block_until_ready()

265 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import KeplerianState

t0_num = t0.tdb.jd
def grad_wrapper(a,e,nu,Omega,inc,omega):
    s = KeplerianState(
        semi=a,
        ecc=e,
        nu=nu,
        Omega=Omega,
        inc=inc,
        omega=omega,
        acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
        time=t0_num,
    )
    state = s.to_system()
    state.fixed_perturber_positions = jnp.empty((27,3))
    state.fixed_perturber_velocities = jnp.empty((27,3))
    state.fixed_perturber_log_gms = jnp.empty((27,))

    positions, velocities, _, _ = ias15_evolve(
        state,
        acc_func,
        jnp.array([tf.tdb.jd]),
        integrator_init,
    )
    return positions

t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=10,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms
k = p.keplerian_state

j2 = jax.jit(jax.jacfwd(grad_wrapper))
j2(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega)

Array([[[[nan],
         [nan],
         [nan]]]], dtype=float64)

In [8]:
%timeit j2(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega).block_until_ready()

1.46 s ± 46.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    positions, velocities, _, _ = ias15_evolve(
        state,
        acc_func,
        jnp.array([tf.tdb.jd]),
        integrator_init,
    )
    positions.block_until_ready()

2026-02-03 12:58:31.760848: E external/xla/xla/python/profiler/internal/python_hooks.cc:400] Can't import tensorflow.python.profiler.trace


In [4]:
import jax
import jax.numpy as jnp
from jax.profiler import TraceAnnotation

@jax.jit
def my_computation(x):
    result = jnp.dot(x, x.T)
    result /= 10.0
    result += 5.0
    return result

@jax.jit
def preprocessing(x):
    return x * 2

x = jnp.ones((1000, 1000))
# Warm up the JIT compilation
_ = my_computation(x)
_ = preprocessing(x)

# Annotate at the CALL site, not inside the function
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    with TraceAnnotation("preprocessing"):
        x = preprocessing(x)
    with TraceAnnotation("my_computation"):
        result = my_computation(x)

2026-02-03 13:08:10.196209: E external/xla/xla/python/profiler/internal/python_hooks.cc:400] Can't import tensorflow.python.profiler.trace
2026-02-03 13:08:10.197358: E external/xla/xla/python/profiler/internal/python_hooks.cc:400] Can't import tensorflow.python.profiler.trace


Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


127.0.0.1 - - [03/Feb/2026 13:08:14] code 404, message File not found
127.0.0.1 - - [03/Feb/2026 13:08:14] "POST /status HTTP/1.1" 404 -
127.0.0.1 - - [03/Feb/2026 13:08:14] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -


In [5]:
import jax
import jax.numpy as jnp
import cProfile
import pstats

@jax.jit
def my_computation(x):
    result = jnp.dot(x, x.T)
    return result

@jax.jit  
def preprocessing(x):
    return x * 2

x = jnp.ones((1000, 1000))
_ = my_computation(x)
_ = preprocessing(x)

# Profile with cProfile
profiler = cProfile.Profile()
profiler.enable()

x = preprocessing(x)
result = my_computation(x)

profiler.disable()
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20)  # Top 20 functions

         65 function calls in 0.000 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        3    0.000    0.000    0.000    0.000 /Users/cassese/Documents/virtual_envs/jorbit/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        3    0.000    0.000    0.000    0.000 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        3    0.000    0.000    0.000    0.000 /Users/cassese/.local/share/uv/python/cpython-3.13.2-macos-aarch64-none/lib/python3.13/contextlib.py:145(__exit__)
        3    0.000    0.000    0.000    0.000 /Users/cassese/.local/share/uv/python/cpython-3.13.2-macos-aarch64-none/lib/python3.13/codeop.py:113(__call__)
        3    0.000    0.000    0.000    0.000 {built-in method builtins.compile}
        1    0.000    0.000    0.000    0.000 /var/folders/mj/qxz5chg95r53_2nlv9f86qhm0000gn/T/ipykernel_5174/3

<pstats.Stats at 0x112333a10>