In [1]:
# ruff: noqa
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
from astropy.time import Time

from jorbit import Ephemeris
from jorbit.accelerations.newtonian import newtonian_gravity
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import SystemState

from jorbit.accelerations.newtonian import newtonian_gravity
from jorbit.accelerations.newtonian_old import newtonian_gravity as newtonian_gravity_old

In [3]:
# Mixed test system: tracers, massive particles and fixed perturbers, all drawn
# at random from a single base seed so the comparison is reproducible.
seed = 0
n_tracer, n_massive, n_perturber = 100, 10, 20

# One PRNG key per randomly drawn field, at offsets 0..7 from `seed`.
(
    key_tracer_pos,
    key_tracer_vel,
    key_massive_pos,
    key_massive_vel,
    key_massive_gm,
    key_perturber_pos,
    key_perturber_vel,
    key_perturber_gm,
) = (jax.random.PRNGKey(seed + offset) for offset in range(8))

# Positions are standard normals scaled by 1000; velocities are unscaled
# standard normals. GMs are Uniform[0, 1) draws stored as logs (a draw of
# exactly 0 would give -inf, vanishingly unlikely in float64).
s = SystemState(
    tracer_positions=jax.random.normal(key_tracer_pos, (n_tracer, 3)) * 1000,
    tracer_velocities=jax.random.normal(key_tracer_vel, (n_tracer, 3)),
    massive_positions=jax.random.normal(key_massive_pos, (n_massive, 3)) * 1000,
    massive_velocities=jax.random.normal(key_massive_vel, (n_massive, 3)),
    log_gms=jnp.log(jax.random.uniform(key_massive_gm, (n_massive,))),
    time=0.0,
    fixed_perturber_positions=jax.random.normal(key_perturber_pos, (n_perturber, 3)) * 1000,
    fixed_perturber_velocities=jax.random.normal(key_perturber_vel, (n_perturber, 3)),
    fixed_perturber_log_gms=jnp.log(jax.random.uniform(key_perturber_gm, (n_perturber,))),
    acceleration_func_kwargs={"c2": 100.0},
)

# Forward-mode Jacobians of each implementation with respect to every leaf of
# the state, old implementation first, then new.
old_grad, new_grad = (
    jax.jacfwd(implementation)(s)
    for implementation in (newtonian_gravity_old, newtonian_gravity)
)

In [6]:
# Largest absolute element-wise difference between the old and new Jacobians,
# one line per state field. A blank line separates the tracer/massive-particle
# fields from the fixed-perturber fields.
# NOTE(review): these are absolute differences; judge any nonzero value against
# the magnitude of the corresponding Jacobian entries.
particle_fields = (
    "tracer_positions",
    "tracer_velocities",
    "massive_positions",
    "massive_velocities",
    "log_gms",
)
perturber_fields = (
    "fixed_perturber_positions",
    "fixed_perturber_velocities",
    "fixed_perturber_log_gms",
)

for group_number, field_group in enumerate((particle_fields, perturber_fields)):
    if group_number > 0:
        print()
    for field_name in field_group:
        difference = getattr(old_grad, field_name) - getattr(new_grad, field_name)
        print(jnp.max(jnp.abs(difference)))

0.0
0.0
0.0
0.0
0.0

2.0897017474006962e-07
0.0
1.7233824155362678e-05


In [7]:
%timeit jax.jacfwd(newtonian_gravity)(s).tracer_positions.block_until_ready()

6.77 ms ± 46.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit jax.jacfwd(newtonian_gravity_old)(s).tracer_positions.block_until_ready()

7.19 ms ± 237 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# Second test system: one tracer, no massive particles and many fixed
# perturbers. This presumably mirrors propagating a single small body through a
# fixed ephemeris; confirm with the author.
seed = 0
n_tracer = 1
n_massive = 0
n_perturber = 100

# Map each randomly drawn field to its PRNG key (offsets 0..7 from `seed`).
keys = {
    field_name: jax.random.PRNGKey(seed + offset)
    for offset, field_name in enumerate(
        (
            "tracer_positions",
            "tracer_velocities",
            "massive_positions",
            "massive_velocities",
            "log_gms",
            "fixed_perturber_positions",
            "fixed_perturber_velocities",
            "fixed_perturber_log_gms",
        )
    )
}

# With n_massive == 0 the massive-particle arrays are simply empty.
s = SystemState(
    tracer_positions=jax.random.normal(keys["tracer_positions"], (n_tracer, 3)) * 1000,
    tracer_velocities=jax.random.normal(keys["tracer_velocities"], (n_tracer, 3)),
    massive_positions=jax.random.normal(keys["massive_positions"], (n_massive, 3)) * 1000,
    massive_velocities=jax.random.normal(keys["massive_velocities"], (n_massive, 3)),
    log_gms=jnp.log(jax.random.uniform(keys["log_gms"], (n_massive,))),
    time=0.0,
    fixed_perturber_positions=jax.random.normal(keys["fixed_perturber_positions"], (n_perturber, 3)) * 1000,
    fixed_perturber_velocities=jax.random.normal(keys["fixed_perturber_velocities"], (n_perturber, 3)),
    fixed_perturber_log_gms=jnp.log(jax.random.uniform(keys["fixed_perturber_log_gms"], (n_perturber,))),
    acceleration_func_kwargs={"c2": 100.0},
)

# Jacobians of both implementations for this configuration.
old_grad = jax.jacfwd(newtonian_gravity_old)(s)
new_grad = jax.jacfwd(newtonian_gravity)(s)

In [13]:
%timeit jax.jacfwd(newtonian_gravity)(s).tracer_positions.block_until_ready()

1.96 ms ± 16.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
%timeit jax.jacfwd(newtonian_gravity_old)(s).tracer_positions.block_until_ready()

1.94 ms ± 26.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
