# scratch work

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit.utils.states import SystemState

In [8]:
# randomly create m massive and n tracer particles,
# simulate them with reboundx and jorbit, compare results:

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import gr_full, gr_rewrite

# Fixed seed so the random system (and the printed accelerations) are reproducible.
np.random.seed(0)

n = 100  # number of massless tracer particles
m = 10  # number of massive particles

sim = rebound.Simulation()

# --- massive particles -------------------------------------------------------
# The per-particle mass is stored in `mass`, not `m`: reusing `m` inside the
# loop replaced the particle count with a random float once the cell had run.
# The order of the random draws (position, velocity, mass) is kept as-is so the
# generated system is unchanged.
massive_x_rows = []
massive_v_rows = []
mass_values = []
for _ in range(m):
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    mass = np.random.uniform(0, 1)
    massive_x_rows.append(xs)
    massive_v_rows.append(vs)
    mass_values.append(mass)
    sim.add(m=mass, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])

# --- tracer particles (m=0), added after the massive ones ---------------------
tracer_x_rows = []
tracer_v_rows = []
for _ in range(n):
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    tracer_x_rows.append(xs)
    tracer_v_rows.append(vs)
    sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])

# --- REBOUNDx reference accelerations ----------------------------------------
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
# c = 10, i.e. c**2 = 100, which is the c2 value passed to gr_full / gr_rewrite
# in the cells below.
gr.params["c"] = 10
gr.params["max_iterations"] = 100
rebx.add_force(gr)

# Integrate to an effectively zero end time. NOTE(review): presumably this is
# only here to make REBOUND evaluate the forces so that p.ax/p.ay/p.az are
# populated while the particles stay (numerically) where they were -- confirm.
sim.integrate(1e-300)

# One row of (ax, ay, az) per particle, in insertion order: massive, then tracers.
reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

# Array versions of the inputs, under the names the later cells use.
tracer_x = jnp.array(tracer_x_rows)
tracer_v = jnp.array(tracer_v_rows)
massive_x = jnp.array(massive_x_rows)
massive_v = jnp.array(massive_v_rows)
ms = jnp.array(mass_values)
reb_res

Array([[-3.74178066e-07, -2.17574232e-07, -4.62795007e-07],
       [-4.41465119e-08,  1.11621135e-07, -3.23634419e-07],
       [-4.26447795e-07,  3.93328890e-07,  5.65000355e-07],
       [ 9.00409824e-08, -9.69800439e-07,  5.62397494e-08],
       [-4.16905500e-07,  2.68264779e-08,  7.45328513e-07],
       [ 1.04843840e-06,  3.12514510e-08, -8.86802185e-08],
       [ 6.13800659e-07,  4.00539428e-07,  1.19005234e-06],
       [-5.28113169e-07,  8.20408323e-07, -1.85919727e-07],
       [-8.70732614e-08,  1.47359098e-07, -1.01832395e-06],
       [ 2.86274079e-07,  1.05362748e-06, -1.35957826e-07],
       [ 5.20593891e-07,  2.30113732e-07, -6.25808560e-07],
       [-7.84107183e-07,  8.46878745e-07,  6.72666912e-07],
       [ 6.18220391e-07, -1.74405651e-07, -6.17328353e-07],
       [-6.85685696e-07,  3.37120143e-07,  6.09802360e-07],
       [-3.97788883e-07, -7.60868643e-07, -9.87756982e-07],
       [ 3.06396923e-07, -2.81270983e-07, -3.10798808e-07],
       [ 7.62716265e-07, -7.81420136e-07

In [9]:
# Read the post-integration state back out of the REBOUND simulation, one row
# per particle in simulation order, and evaluate gr_full on it.
particles = list(sim.particles)
particle_x = jnp.array([[p.x, p.y, p.z] for p in particles])
particle_v = jnp.array([[p.vx, p.vy, p.vz] for p in particles])
particle_gms = jnp.array([p.m for p in particles])

gr_full(x=particle_x, v=particle_v, gms=particle_gms, c2=100)

Array([[-3.74178066e-07, -2.17574232e-07, -4.62795007e-07],
       [-4.41465119e-08,  1.11621135e-07, -3.23634419e-07],
       [-4.26447795e-07,  3.93328890e-07,  5.65000355e-07],
       [ 9.00409824e-08, -9.69800439e-07,  5.62397494e-08],
       [-4.16905500e-07,  2.68264779e-08,  7.45328513e-07],
       [ 1.04843840e-06,  3.12514510e-08, -8.86802185e-08],
       [ 6.13800659e-07,  4.00539428e-07,  1.19005234e-06],
       [-5.28113169e-07,  8.20408323e-07, -1.85919727e-07],
       [-8.70732614e-08,  1.47359098e-07, -1.01832395e-06],
       [ 2.86274079e-07,  1.05362748e-06, -1.35957826e-07],
       [ 5.20593891e-07,  2.30113732e-07, -6.25808560e-07],
       [-7.84107183e-07,  8.46878745e-07,  6.72666912e-07],
       [ 6.18220391e-07, -1.74405651e-07, -6.17328353e-07],
       [-6.85685696e-07,  3.37120143e-07,  6.09802360e-07],
       [-3.97788883e-07, -7.60868643e-07, -9.87756982e-07],
       [ 3.06396923e-07, -2.81270983e-07, -3.10798808e-07],
       [ 7.62716265e-07, -7.81420136e-07

In [10]:
# Package the same random system for jorbit; masses are passed as log(GM).
s = SystemState(
    tracer_positions=tracer_x,
    tracer_velocities=tracer_v,
    massive_positions=massive_x,
    massive_velocities=massive_v,
    log_gms=jnp.log(ms),
)
jorb_res = gr_rewrite(s, c2=100.0)

# Quantify the agreement with the REBOUNDx reference instead of eyeballing the
# printed arrays: in the recorded outputs the first 10 rows (massive particles)
# match to printed precision, while later rows (tracers) differ in the 4th
# significant digit. Report the worst per-particle relative difference for each
# class so that discrepancy is visible.
# NOTE(review): assumes jorb_res is ordered like reb_res (massive particles
# first, then tracers) with one row per particle -- confirm against gr_rewrite.
n_massive = massive_x.shape[0]
rel_diff = jnp.linalg.norm(jorb_res - reb_res, axis=1) / jnp.linalg.norm(
    reb_res, axis=1
)
print("max relative difference, massive:", float(rel_diff[:n_massive].max()))
print("max relative difference, tracers:", float(rel_diff[n_massive:].max()))

jorb_res

Array([[-3.74178066e-07, -2.17574232e-07, -4.62795007e-07],
       [-4.41465119e-08,  1.11621135e-07, -3.23634419e-07],
       [-4.26447795e-07,  3.93328890e-07,  5.65000355e-07],
       [ 9.00409824e-08, -9.69800439e-07,  5.62397494e-08],
       [-4.16905500e-07,  2.68264779e-08,  7.45328513e-07],
       [ 1.04843840e-06,  3.12514510e-08, -8.86802185e-08],
       [ 6.13800659e-07,  4.00539428e-07,  1.19005234e-06],
       [-5.28113169e-07,  8.20408323e-07, -1.85919727e-07],
       [-8.70732614e-08,  1.47359098e-07, -1.01832395e-06],
       [ 2.86274079e-07,  1.05362748e-06, -1.35957826e-07],
       [ 5.20679206e-07,  2.30155589e-07, -6.25919880e-07],
       [-7.84260309e-07,  8.47028953e-07,  6.72791890e-07],
       [ 6.18329692e-07, -1.74440217e-07, -6.17460394e-07],
       [-6.85801458e-07,  3.37169168e-07,  6.09890636e-07],
       [-3.97886148e-07, -7.61020145e-07, -9.87975340e-07],
       [ 3.06441594e-07, -2.81301724e-07, -3.10850057e-07],
       [ 7.62824589e-07, -7.81532138e-07

In [11]:
# Time gr_rewrite on the prebuilt SystemState `s`. JAX dispatches work
# asynchronously, so block_until_ready() is needed for the measurement to
# include the computation itself rather than just the dispatch.
%timeit gr_rewrite(s, c2=100.0).block_until_ready()

116 μs ± 940 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
%%timeit
gr_full(
    x=jnp.array([[p.x, p.y, p.z] for p in sim.particles]),
    v=jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles]),
    gms=jnp.array([p.m for p in sim.particles]),
    c2=100,
).block_until_ready()

1.78 ms ± 14.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
