In [None]:
# ruff: noqa
# Benchmark: compare the cost of a single forward step, and of its gradient, for
# `ias15_step` vs. `ias15_static_step` on a small random system.
import jax

# Switch JAX to 64-bit floats. This should run before any JAX arrays are created,
# since the flag is meant to be set at startup.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit.accelerations import newtonian_gravity
from jorbit.integrators.ias15 import ias15_step, initialize_ias15_integrator_state
from jorbit.integrators.ias15_static import ias15_static_step
from jorbit.utils.states import SystemState

In [2]:
# Problem size and spread of the random initial conditions. The PRNG seeds below are
# fixed so the system (and therefore the timings) is reproducible across runs.
N_TRACERS = 5
N_MASSIVE = 10
SCALE = 10  # multiplies the standard-normal positions/velocities

tracer_positions = jax.random.normal(jax.random.PRNGKey(0), (N_TRACERS, 3)) * SCALE
tracer_velocities = jax.random.normal(jax.random.PRNGKey(1), (N_TRACERS, 3)) * SCALE
massive_positions = jax.random.normal(jax.random.PRNGKey(2), (N_MASSIVE, 3)) * SCALE
massive_velocities = jax.random.normal(jax.random.PRNGKey(3), (N_MASSIVE, 3)) * SCALE
# One log_gm entry per massive body.
log_gms = jax.random.normal(jax.random.PRNGKey(4), (N_MASSIVE,))


state = SystemState(
    tracer_positions=tracer_positions,
    tracer_velocities=tracer_velocities,
    massive_positions=massive_positions,
    massive_velocities=massive_velocities,
    log_gms=log_gms,
    acceleration_func_kwargs={},
    time=0.0,
)

# Initial accelerations seed the IAS15 integrator state that the benchmark cells step with.
a0 = newtonian_gravity(state)

integrator_state = initialize_ias15_integrator_state(a0)

In [3]:
# Warm-up: run each stepper once before timing so any one-off tracing/compilation cost
# (assuming the steppers are jit-compiled — not visible here) is excluded from %timeit.
# JAX dispatches asynchronously, so block on the result; otherwise warm-up work may still
# be running when the next cell starts measuring.
for warmup_step in (ias15_step, ias15_static_step):
    warmup_state, _ = warmup_step(state, jax.tree_util.Partial(newtonian_gravity), integrator_state)
    warmup_state.massive_positions.block_until_ready()

In [4]:
%timeit ias15_step(state, jax.tree_util.Partial(newtonian_gravity), integrator_state)[0].massive_positions.block_until_ready()

392 μs ± 8.28 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [5]:
%timeit ias15_static_step(state, jax.tree_util.Partial(newtonian_gravity), integrator_state)[0].massive_positions.block_until_ready()

181 μs ± 1.79 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
@jax.jit
def func1(state):
    new_state, _ = ias15_step(state, jax.tree_util.Partial(newtonian_gravity), integrator_state)
    return jnp.sum(new_state.massive_positions)

j1 = jax.jit(jax.grad(func1))
_ = j1(state)
%timeit j1(state).massive_positions.block_until_ready()

2.05 ms ± 10.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
@jax.jit
def func2(state):
    new_state, _ = ias15_static_step(state, jax.tree_util.Partial(newtonian_gravity), integrator_state)
    return jnp.sum(new_state.massive_positions)

j2 = jax.jit(jax.grad(func2))
_ = j2(state)
%timeit j2(state).massive_positions.block_until_ready()

511 μs ± 5.93 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
