In [1]:
# ruff: noqa
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx
from astropy.time import Time

from jorbit import Ephemeris
from jorbit.accelerations.newtonian import newtonian_gravity
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import SystemState

from jorbit.accelerations.gr import ppn_gravity, static_ppn_gravity, static_ppn_gravity_tracer
from jorbit.accelerations.gr_old import ppn_gravity as ppn_gravity_old

In [2]:
# Shared configuration for the accuracy/timing comparisons below.
seed = 0
n_massive = 10
n_tracer = 100
# Speed of light in simulation units. REBOUNDx's "gr_full" force takes c, while the
# jorbit PPN kernels take c2 = c**2 through acceleration_func_kwargs; both are derived
# from this single constant so the two implementations cannot silently disagree.
# (Deliberately far below the physical SPEED_OF_LIGHT — presumably so the 1/c^2
# terms are large enough to compare.)
c_light = 10.0

np.random.seed(seed)
sim = rebound.Simulation()

# Massive bodies are added to REBOUND first and tracers second; presumably this matches
# the row order of ppn_gravity's output (the next cell's comparison agrees to ~1e-20).
massive_x = []
massive_v = []
ms = []
for _i in range(n_massive):
    # Keep the draw order (position, velocity, mass): it fixes the values produced
    # by the legacy global RNG seeded above.
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    m = np.random.uniform(0, 1)
    massive_x.append(xs)
    massive_v.append(vs)
    ms.append(m)
    sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])

tracer_x = []
tracer_v = []
for _i in range(n_tracer):
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    tracer_x.append(xs)
    tracer_v.append(vs)
    sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])

# Reference accelerations from REBOUNDx's "gr_full" force.
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
gr.params["c"] = c_light
gr.params["max_iterations"] = 100
rebx.add_force(gr)
# Integrate over a negligible interval only so REBOUND evaluates the forces and
# fills in each particle's (ax, ay, az).
sim.integrate(1e-300)
reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

# The same system as a jorbit SystemState (all bodies dynamic, no fixed perturbers).
tracer_x = jnp.array(tracer_x)
tracer_v = jnp.array(tracer_v)
massive_x = jnp.array(massive_x)
massive_v = jnp.array(massive_v)
ms = jnp.array(ms)
s = SystemState(
    tracer_positions=tracer_x,
    tracer_velocities=tracer_v,
    massive_positions=massive_x,
    massive_velocities=massive_v,
    log_gms=jnp.log(ms),  # REBOUND's default G is 1, so GM == m here
    time=0.0,
    fixed_perturber_positions=jnp.empty((0, 3)),
    fixed_perturber_velocities=jnp.empty((0, 3)),
    fixed_perturber_log_gms=jnp.empty((0,)),
    acceleration_func_kwargs={"c2": c_light**2},
)
new_res = ppn_gravity(s)
old_res = ppn_gravity_old(s)

In [3]:
# Max |Δa| for (new vs old), (REBOUNDx vs new), (REBOUNDx vs old).
tuple(
    jnp.max(jnp.abs(lhs - rhs))
    for lhs, rhs in ((new_res, old_res), (reb_res, new_res), (reb_res, old_res))
)

(Array(2.71050543e-20, dtype=float64),
 Array(5.42101086e-20, dtype=float64),
 Array(2.71050543e-20, dtype=float64))

In [4]:
%timeit ppn_gravity(s).block_until_ready()

81.8 μs ± 401 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
%timeit ppn_gravity_old(s).block_until_ready()

589 μs ± 6.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
# Same bodies as the cell-2 setup, but the massive particles are passed as *fixed
# perturbers* and the dynamic massive set is empty. tracer_x, massive_x, ... are
# already jnp arrays from that cell, so they are used directly (re-wrapping them in
# jnp.array only made redundant copies).
s2 = SystemState(
    tracer_positions=tracer_x,
    tracer_velocities=tracer_v,
    massive_positions=jnp.empty((0, 3)),
    massive_velocities=jnp.empty((0, 3)),
    log_gms=jnp.empty((0,)),
    time=0.0,
    fixed_perturber_positions=massive_x,
    fixed_perturber_velocities=massive_v,
    fixed_perturber_log_gms=jnp.log(ms),
    # Reuse the kwargs (c2) of the cell-2 state so both configurations use the same
    # speed of light instead of duplicating the magic number.
    acceleration_func_kwargs=dict(s.acceleration_func_kwargs),
)
new_res2 = ppn_gravity(s2)
old_res2 = ppn_gravity_old(s2)

In [7]:
# Fixed-perturber configuration: max |Δa| between the new and old kernels on s2.
# (Previously this compared new_res/old_res from cell 2, so s2 was never checked.)
jnp.max(jnp.abs(new_res2 - old_res2))

Array(2.71050543e-20, dtype=float64)

In [8]:
%timeit ppn_gravity(s2).block_until_ready()

81.9 μs ± 213 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
%timeit ppn_gravity_old(s2).block_until_ready()

589 μs ± 7.24 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# static_ppn_gravity vs. the old implementation on the fixed-perturber state s2.
new_static = static_ppn_gravity(s2)
jnp.abs(new_static - old_res2).max()

Array(2.71050543e-20, dtype=float64)

In [11]:
%timeit static_ppn_gravity(s2).block_until_ready()

73.1 μs ± 268 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
# static_ppn_gravity_tracer vs. the old implementation on s2.
new_static_tracer = static_ppn_gravity_tracer(s2)
jnp.abs(new_static_tracer - old_res2).max()

Array(2.71050543e-20, dtype=float64)

In [13]:
%timeit static_ppn_gravity_tracer(s2).block_until_ready()

59.6 μs ± 499 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
# Random state with both dynamic massive bodies AND fixed perturbers, used to compare
# forward-mode Jacobians (w.r.t. every SystemState leaf) of the old and new kernels.
# Stored as `grad_state` rather than `s` so the cell-2 state used by the accuracy and
# timing cells above is not clobbered when this cell runs.
# One key per leaf; the seeds (seed + i) match the original per-leaf PRNGKey calls.
prng_keys = [jax.random.PRNGKey(seed + i) for i in range(8)]
grad_state = SystemState(
    tracer_positions=jax.random.normal(prng_keys[0], (n_tracer, 3)) * 1000,
    tracer_velocities=jax.random.normal(prng_keys[1], (n_tracer, 3)),
    massive_positions=jax.random.normal(prng_keys[2], (n_massive, 3)) * 1000,
    massive_velocities=jax.random.normal(prng_keys[3], (n_massive, 3)),
    log_gms=jnp.log(jax.random.uniform(prng_keys[4], (n_massive,))),
    time=0.0,
    fixed_perturber_positions=jax.random.normal(prng_keys[5], (n_massive, 3)) * 1000,
    fixed_perturber_velocities=jax.random.normal(prng_keys[6], (n_massive, 3)),
    fixed_perturber_log_gms=jnp.log(jax.random.uniform(prng_keys[7], (n_massive,))),
    acceleration_func_kwargs={"c2": 100.0},
)

old_grad = jax.jacfwd(ppn_gravity_old)(grad_state)
new_grad1 = jax.jacfwd(ppn_gravity)(grad_state)
new_grad2 = jax.jacfwd(static_ppn_gravity)(grad_state)

In [6]:
# Max |ΔJ| between the old and new (ppn_gravity) Jacobians, one line per state leaf:
# dynamic-particle leaves first, then (after a blank line) the fixed-perturber leaves.
# NOTE(review): in the recorded output the fixed-perturber leaves differ by up to
# ~2e-5 while the others agree to <1e-21 — confirm whether that gap is expected.
for leaf in ("tracer_positions", "tracer_velocities", "massive_positions", "massive_velocities", "log_gms"):
    print(jnp.max(jnp.abs(getattr(old_grad, leaf) - getattr(new_grad1, leaf))))
print()
for leaf in ("fixed_perturber_positions", "fixed_perturber_velocities", "fixed_perturber_log_gms"):
    print(jnp.max(jnp.abs(getattr(old_grad, leaf) - getattr(new_grad1, leaf))))

5.293955920339377e-23
4.235164736271502e-22
3.308722450212111e-24
2.117582368135751e-22
2.117582368135751e-22

2.889571106881684e-07
2.3131854644223462e-06
2.4137748649070033e-05


In [7]:
# Same per-leaf Jacobian comparison, now against static_ppn_gravity (new_grad2).
# NOTE(review): the recorded fixed-perturber differences match the ppn_gravity case
# exactly (up to ~2e-5) — confirm whether that gap is expected.
for leaf in ("tracer_positions", "tracer_velocities", "massive_positions", "massive_velocities", "log_gms"):
    print(jnp.max(jnp.abs(getattr(old_grad, leaf) - getattr(new_grad2, leaf))))
print()
for leaf in ("fixed_perturber_positions", "fixed_perturber_velocities", "fixed_perturber_log_gms"):
    print(jnp.max(jnp.abs(getattr(old_grad, leaf) - getattr(new_grad2, leaf))))

5.293955920339377e-23
4.235164736271502e-22
3.308722450212111e-24
2.117582368135751e-22
2.117582368135751e-22

2.889571106881684e-07
2.3131854644223462e-06
2.4137748649070033e-05


In [14]:
# Tracer-only state (empty dynamic massive set) with the same fixed perturbers as the
# previous gradient test (same seeds), for Jacobians of static_ppn_gravity_tracer.
# `s` keeps its name because the timing cells below use it.
key_for = {i: jax.random.PRNGKey(seed + i) for i in (0, 1, 5, 6, 7)}
s = SystemState(
    tracer_positions=jax.random.normal(key_for[0], (n_tracer, 3)) * 1000,
    tracer_velocities=jax.random.normal(key_for[1], (n_tracer, 3)),
    massive_positions=jnp.empty((0, 3)),
    massive_velocities=jnp.empty((0, 3)),
    log_gms=jnp.empty((0,)),
    time=0.0,
    fixed_perturber_positions=jax.random.normal(key_for[5], (n_massive, 3)) * 1000,
    fixed_perturber_velocities=jax.random.normal(key_for[6], (n_massive, 3)),
    fixed_perturber_log_gms=jnp.log(jax.random.uniform(key_for[7], (n_massive,))),
    acceleration_func_kwargs={"c2": 100.0},
)

old_grad = jax.jacfwd(ppn_gravity_old)(s)
new_grad3 = jax.jacfwd(static_ppn_gravity_tracer)(s)

In [16]:
# Per-leaf max |ΔJ| between the old kernel and static_ppn_gravity_tracer (new_grad3):
# tracer leaves, a blank line, then fixed-perturber leaves. The massive leaves are
# empty in this state and are skipped.
for leaf in ("tracer_positions", "tracer_velocities"):
    print(jnp.max(jnp.abs(getattr(old_grad, leaf) - getattr(new_grad3, leaf))))
print()
for leaf in ("fixed_perturber_positions", "fixed_perturber_velocities", "fixed_perturber_log_gms"):
    print(jnp.max(jnp.abs(getattr(old_grad, leaf) - getattr(new_grad3, leaf))))

1.0587911840678754e-22
4.235164736271502e-22

2.8757787026811276e-07
2.361786393866913e-06
2.407452246680441e-05


In [17]:
%timeit jax.jacfwd(ppn_gravity_old)(s).tracer_positions.block_until_ready()

482 ms ± 2.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%timeit jax.jacfwd(static_ppn_gravity_tracer)(s).tracer_positions.block_until_ready()

9.24 ms ± 167 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
