# scratch work

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.jplhorizons import Horizons

from jorbit.utils.horizons import (
    horizons_bulk_vector_query,
    horizons_bulk_astrometry_query,
)

t0 = Time("2024-12-24T00:00:00", scale="utc")

In [2]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import rebound
import reboundx

sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.02, vz=0.4)
sim.add(m=1.3, x=1.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6)
sim.add(m=0.01, x=10.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6)
sim.add(m=0.01, x=10.0, y=-2.01, z=0.4, vx=-0.02, vy=-0.05, vz=0.6)
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
gr.params["c"] = 10
gr.params["max_iterations"] = 100
rebx.add_force(gr)
sim.integrate(1e-300)

reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])
reb_res

Array([[ 0.05419503, -0.08164417,  0.00249506],
       [-0.04134904,  0.06280272, -0.00191058],
       [-0.02331262,  0.00210376,  0.07367954],
       [-0.02348517,  0.00205179, -0.07450105]], dtype=float64)

In [3]:
from jorbit.accelerations.gr import ppn_gravity
from jorbit.utils.states import SystemState

sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.02, vz=0.4)
sim.add(m=1.3, x=1.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6)
sim.add(m=0.01, x=10.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6)
sim.add(m=0.01, x=10.0, y=-2.01, z=0.4, vx=-0.02, vy=-0.05, vz=0.6)
sim.integrate(1e-300)
x0 = jnp.array([[p.x, p.y, p.z] for p in sim.particles])
v0 = jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles])
gms = jnp.array([p.m for p in sim.particles])


s = SystemState(
    tracer_positions=jnp.empty((0, 3)),
    tracer_velocities=jnp.empty((0, 3)),
    massive_positions=x0,
    massive_velocities=v0,
    log_gms=jnp.log(gms),
    time=0.0,
    acceleration_func_kwargs={"c2": 100.0},
)

jorb_res = ppn_gravity(s)

reb_res - jorb_res

Array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 3.46944695e-18,  0.00000000e+00, -4.16333634e-17],
       [ 3.46944695e-18, -4.33680869e-19,  2.77555756e-17]],      dtype=float64)

In [4]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=tracer_x,
        tracer_velocities=tracer_v,
        massive_positions=massive_x,
        massive_velocities=massive_v,
        log_gms=jnp.log(ms),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    # assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14)
    return reb_res, jorb_res


a, b = _agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0)

7.147368626309915e-08


In [5]:
(a - b)[:10]

Array([[ 0.00000000e+00,  7.94093388e-23, -5.29395592e-23],
       [ 0.00000000e+00,  3.97046694e-23, -5.29395592e-23],
       [-5.29395592e-23,  0.00000000e+00,  1.05879118e-22],
       [ 2.64697796e-23,  0.00000000e+00,  7.27918939e-23],
       [ 5.29395592e-23, -3.30872245e-23,  1.05879118e-22],
       [ 0.00000000e+00,  4.63221143e-23,  3.97046694e-23],
       [ 0.00000000e+00, -1.05879118e-22,  2.11758237e-22],
       [ 0.00000000e+00,  0.00000000e+00, -2.64697796e-23],
       [-9.26442286e-23, -5.29395592e-23, -2.11758237e-22],
       [ 5.29395592e-23,  0.00000000e+00, -5.29395592e-23]],      dtype=float64)

In [6]:
(a - b)[10:]

Array([[-8.53144643e-11, -4.18565020e-11,  1.11319405e-10],
       [ 1.53126240e-10, -1.50207867e-10, -1.24978430e-10],
       [-1.09300818e-10,  3.45658866e-11,  1.32041505e-10],
       [ 1.15761317e-10, -4.90254912e-11, -8.82760062e-11],
       [ 9.72647830e-11,  1.51501252e-10,  2.18358760e-10],
       [-4.46705226e-11,  3.07411021e-11,  5.12492626e-11],
       [-1.08323697e-10,  1.12002051e-10, -1.40370236e-11],
       [ 6.53541058e-11,  3.04522994e-11, -4.29140457e-11],
       [-9.63718411e-11, -4.40308553e-11,  9.77835546e-11],
       [ 2.72747710e-10,  1.57570428e-10, -1.68741809e-10],
       [ 5.42339883e-10,  4.41772073e-11, -3.59055280e-12],
       [-1.09678375e-10,  1.83371150e-11,  2.09156781e-10],
       [-1.94715152e-10, -4.58208335e-11,  1.48004731e-10],
       [-9.05542892e-11,  7.51897235e-11, -1.46039657e-10],
       [-1.28251309e-10, -1.17008323e-10,  1.05698839e-10],
       [-1.57500340e-10, -9.70675728e-11,  2.03539243e-11],
       [-1.69858924e-10,  7.75434792e-11

In [7]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=jnp.empty((0, 3)),
        tracer_velocities=jnp.empty((0, 3)),
        massive_positions=jnp.concatenate([massive_x, tracer_x]),
        massive_velocities=jnp.concatenate([massive_v, tracer_v]),
        log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    # assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14)
    return reb_res, jorb_res


a, b = _agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0)

2.710505431213761e-20


In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=tracer_x,
        tracer_velocities=tracer_v,
        massive_positions=massive_x,
        massive_velocities=massive_v,
        log_gms=jnp.log(ms),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    # assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14)
    return reb_res, jorb_res


a, b = _agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0)

2.79268421706764e-08
2.7967285075947977e-09
-8.479443502387828e-08

0.00014126397217896714


In [2]:
jnp.eye(10)

Array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float64)

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=jnp.empty((0, 3)),
        tracer_velocities=jnp.empty((0, 3)),
        massive_positions=jnp.concatenate([massive_x, tracer_x]),
        massive_velocities=jnp.concatenate([massive_v, tracer_v]),
        log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    # assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14)
    return reb_res, jorb_res


a, b = _agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0)

2.710505431213761e-20


In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=tracer_x,
        tracer_velocities=tracer_v,
        massive_positions=massive_x,
        massive_velocities=massive_v,
        log_gms=jnp.log(ms),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    return ppn_gravity(s)


# jnp.min(jnp.abs(_agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0) - 0.00013187992071157616))
float(_agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0)[-1])

3.296998017789404e-05

In [None]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=jnp.empty((0, 3)),
        tracer_velocities=jnp.empty((0, 3)),
        massive_positions=jnp.concatenate([massive_x, tracer_x]),
        massive_velocities=jnp.concatenate([massive_v, tracer_v]),
        log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    # jorb_res = ppn_gravity(s)

    # print(jnp.max(jnp.abs(jorb_res - reb_res)))
    # # assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14)
    # return reb_res, jorb_res
    return ppn_gravity(s)


m = 110
mask = ~jnp.eye(m, dtype=bool)
jnp.sum(
    _agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0),
    axis=1,
    where=mask[:, :, None],
)

Array([[-1.00181285e-07, -9.92337600e-08, -1.49696229e-07],
       [-4.68651915e-09,  7.48994349e-09, -2.93650930e-08],
       [ 2.54316679e-08,  1.82982234e-07,  1.95140584e-07],
       [-6.93617574e-09, -1.42213949e-07, -1.11473269e-08],
       [-2.56851714e-08,  5.89171501e-09,  7.73082814e-08],
       [ 1.20792703e-07,  1.80795272e-08,  3.31919884e-09],
       [ 4.80993758e-08,  5.04874913e-08,  1.24916086e-07],
       [-2.19974739e-08,  6.78907585e-08, -2.17211956e-08],
       [ 5.32549818e-09,  4.11352534e-08, -9.14968132e-08],
       [ 1.20520053e-08,  1.08717933e-07, -2.24045282e-08],
       [ 5.34404750e-08,  4.92840679e-09, -3.16358810e-08],
       [-3.49895420e-08,  3.86674876e-08,  3.90187923e-08],
       [ 7.26378998e-08, -6.18252023e-08, -1.13567107e-07],
       [-8.40423780e-08,  4.48138978e-08,  5.82657322e-08],
       [-2.82837824e-08, -8.40230882e-08, -9.40556250e-08],
       [ 6.37801308e-08, -1.75788125e-08, -3.18052609e-08],
       [ 1.96418474e-08, -7.90676430e-08

In [None]:
def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=jnp.empty((0, 3)),
        tracer_velocities=jnp.empty((0, 3)),
        massive_positions=jnp.concatenate([massive_x, tracer_x]),
        massive_velocities=jnp.concatenate([massive_v, tracer_v]),
        log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )

In [10]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity

seed = 0
n_tracer = 599
n_massive = 1

np.random.seed(seed)
n = n_tracer
m = n_massive
massive_x = []
massive_v = []
ms = []
sim = rebound.Simulation()
for i in range(m):
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    massive_x.append(xs)
    massive_v.append(vs)
    m = np.random.uniform(0, 1)
    ms.append(m)
    sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
tracer_x = []
tracer_v = []
for i in range(n):
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    tracer_x.append(xs)
    tracer_v.append(vs)
    sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
gr.params["c"] = 10
gr.params["max_iterations"] = 100
rebx.add_force(gr)
sim.integrate(1e-300)
reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

tracer_x = jnp.array(tracer_x)
tracer_v = jnp.array(tracer_v)
massive_x = jnp.array(massive_x)
massive_v = jnp.array(massive_v)
ms = jnp.array(ms)
s = SystemState(
    tracer_positions=jnp.empty((0, 3)),
    tracer_velocities=jnp.empty((0, 3)),
    massive_positions=jnp.concatenate([massive_x, tracer_x]),
    massive_velocities=jnp.concatenate([massive_v, tracer_v]),
    log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
    time=0.0,
    acceleration_func_kwargs={"c2": 100.0},
)

jorb_res = ppn_gravity(s)

In [11]:
%timeit ppn_gravity(s).block_until_ready()

5.07 ms ± 10.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity

seed = 0
n_tracer = 100
n_massive = 10

np.random.seed(seed)
n = n_tracer
m = n_massive
massive_x = []
massive_v = []
ms = []
sim = rebound.Simulation()
for i in range(m):
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    massive_x.append(xs)
    massive_v.append(vs)
    m = np.random.uniform(0, 1)
    ms.append(m)
    sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
tracer_x = []
tracer_v = []
for i in range(n):
    xs = np.random.normal(0, 1, 3) * 1000
    vs = np.random.normal(0, 1, 3)
    tracer_x.append(xs)
    tracer_v.append(vs)
    sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
gr.params["c"] = 10
gr.params["max_iterations"] = 100
rebx.add_force(gr)
sim.integrate(1e-300)
reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

tracer_x = jnp.array(tracer_x)
tracer_v = jnp.array(tracer_v)
massive_x = jnp.array(massive_x)
massive_v = jnp.array(massive_v)
ms = jnp.array(ms)


con_state = SystemState(
    tracer_positions=jnp.empty((0, 3)),
    tracer_velocities=jnp.empty((0, 3)),
    massive_positions=jnp.concatenate([massive_x, tracer_x]),
    massive_velocities=jnp.concatenate([massive_v, tracer_v]),
    log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
    time=0.0,
    acceleration_func_kwargs={"c2": 100.0},
)

sep_state = SystemState(
    tracer_positions=tracer_x,
    tracer_velocities=tracer_v,
    massive_positions=massive_x,
    massive_velocities=massive_v,
    log_gms=jnp.log(ms),
    time=0.0,
    acceleration_func_kwargs={"c2": 100.0},
)

sep_args = ppn_gravity(sep_state)
con_args = ppn_gravity(con_state)

In [33]:
def velocity_terms(
    massive_x,
    massive_v,
    tracer_x,
    tracer_v,
    a_newt_massive,
    a_newt_tracer,
    gms,
    c2,
    r_massive,
    r2_massive,
    r3_massive,
    r_tracer,
    r2_tracer,
    r3_tracer,
    mask_massive,
    dv_massive,
    dv_tracer,
    dx_massive,
    dx_tracer,
):
    M = massive_x.shape[0]
    T = tracer_x.shape[0]
    # Compute velocity-dependent terms
    v2_massive = jnp.sum(massive_v * massive_v, axis=-1)  # (M,)
    v2_tracer = jnp.sum(tracer_v * tracer_v, axis=-1)  # (T,)
    vdot_mm = jnp.sum(massive_v[:, None, :] * massive_v[None, :, :], axis=-1)  # (M,M)
    vdot_tm = jnp.sum(tracer_v[:, None, :] * massive_v[None, :, :], axis=-1)  # (T,M)

    # Compute GR correction terms for massive particles
    a1_mm = jnp.sum((4.0 / c2) * gms / r_massive, axis=1, where=mask_massive)
    a1_mm = jnp.broadcast_to(a1_mm, (M, M)).T

    a2_mm = jnp.sum((1.0 / c2) * gms / r_massive, axis=1, where=mask_massive)
    a2_mm = jnp.broadcast_to(a2_mm, (M, M))

    a3_mm = jnp.broadcast_to(-v2_massive / c2, (M, M)).T
    a4_mm = -2.0 * jnp.broadcast_to(v2_massive, (M, M)) / c2
    a5_mm = (4.0 / c2) * vdot_mm

    a6_0_mm = jnp.sum(dx_massive * massive_v[None, :, :], axis=-1)
    a6_mm = (3.0 / (2 * c2)) * (a6_0_mm**2) / r2_massive

    a7_mm = jnp.sum(dx_massive * a_newt_massive[None, :, :], axis=-1) / (2 * c2)

    factor1_mm = a1_mm + a2_mm + a3_mm + a4_mm + a5_mm + a6_mm + a7_mm
    if factor1_mm.shape[0] != 0:
        print(f"factor1_mm: {factor1_mm[-1,0]}")
    print(f"factor1_mm shape: {factor1_mm.shape}")
    part1_mm = (
        jnp.broadcast_to(gms, (M, M))[:, :, None]
        * dx_massive
        * factor1_mm[:, :, None]
        / r3_massive[:, :, None]
    )

    factor2_massive = jnp.sum(
        dx_massive * (4 * massive_v[:, None, :] - 3 * massive_v[None, :, :]), axis=-1
    )
    part2_mm = (
        jnp.broadcast_to(gms, (M, M))[:, :, None]
        * (
            factor2_massive[:, :, None] * dv_massive / r3_massive[:, :, None]
            + 7.0 / 2.0 * a_newt_massive[None, :, :] / r_massive[:, :, None]
        )
        / c2
    )

    a_const_massive = jnp.sum(
        part1_mm + part2_mm, axis=1, where=mask_massive[:, :, None]
    )

    # Compute GR correction terms for tracer particles
    a1_tm = (4.0 / c2) * gms[None, :] / r_tracer
    print(f"a1_tm shape: {a1_tm.shape}")
    a2_tm = (1.0 / c2) * gms[None, :] / r_tracer
    a3_tm = -v2_tracer[:, None] / c2
    a4_tm = -2.0 * v2_massive[None, :] / c2
    a5_tm = (4.0 / c2) * vdot_tm

    a6_0_tm = jnp.sum(dx_tracer * massive_v[None, :, :], axis=-1)
    a6_tm = (3.0 / (2 * c2)) * (a6_0_tm**2) / r2_tracer

    a7_tm = jnp.sum(dx_tracer * a_newt_tracer[:, None, :], axis=-1) / (2 * c2)

    factor1_tm = a1_tm + a2_tm + a3_tm + a4_tm + a5_tm + a6_tm + a7_tm
    # if factor1_tm.shape[0] != 0:
    #     print(f"factor1_tm: {factor1_tm[-1,0]}")
    # print(f"factor1_tm shape: {factor1_tm.shape}")
    # part1_tm = (
    #     gms[None, :, None] * dx_tracer * factor1_tm[:, :, None] / r3_tracer[:, :, None]
    # )

    # factor2_tracer = jnp.sum(
    #     dx_tracer * (4 * tracer_v[:, None, :] - 3 * massive_v[None, :, :]), axis=-1
    # )
    # part2_tm = (
    #     gms[None, :, None]
    #     * (
    #         factor2_tracer[:, :, None] * dv_tracer / r3_tracer[:, :, None]
    #         + 7.0 / 2.0 * a_newt_tracer[:, None, :] / r_tracer[:, :, None]
    #     )
    #     / c2
    # )
    # Correct masking for tracer particles
    combined_x = jnp.concatenate([massive_x, tracer_x], axis=0)
    num_combined = combined_x.shape[0]
    mask_combined = jnp.ones((num_combined, num_combined), dtype=bool)
    mask_combined = mask_combined.at[jnp.diag_indices(num_combined)].set(False)
    mask_tm = mask_combined[M:, :M]
    print("***")
    print(mask_tm.shape)

    part1_tm = (
        gms[None, :, None] * dx_tracer * factor1_tm[:, :, None] / r3_tracer[:, :, None]
    )

    factor2_tracer = jnp.sum(
        dx_tracer * (4 * tracer_v[:, None, :] - 3 * massive_v[None, :, :]), axis=-1
    )

    part2_tm = (
        gms[None, :, None]
        * (
            factor2_tracer[:, :, None] * dv_tracer / r3_tracer[:, :, None]
            + 7.0 / 2.0 * a_newt_tracer[:, None, :] / r_tracer[:, :, None]
        )
        / c2
    )

    a_const_tracer = jnp.sum(part1_tm + part2_tm, axis=1)

    return jnp.concatenate([a_const_massive, a_const_tracer], axis=0)


sep = velocity_terms(*sep_args)
print()
con = velocity_terms(*con_args)

factor1_mm: -0.23537139158964143
factor1_mm shape: (10, 10)
a1_tm shape: (100, 10)
***
(100, 10)

factor1_mm: -0.2975914195021021
factor1_mm shape: (110, 110)
a1_tm shape: (0, 110)
***
(0, 110)


In [30]:
con[-1], sep[-1]

(Array([7.96239144e-08, 4.48597162e-08, 2.90424477e-08], dtype=float64),
 Array([7.92793152e-08, 4.50868662e-08, 2.88633816e-08], dtype=float64))

In [18]:
con_args[4][-1], sep_args[5][-1]

(Array([-1.47394607e-06,  1.02855664e-06, -8.22896051e-07], dtype=float64),
 Array([-1.47394607e-06,  1.02855664e-06, -8.22896051e-07], dtype=float64))

In [3]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=jnp.empty((0, 3)),
        tracer_velocities=jnp.empty((0, 3)),
        massive_positions=jnp.concatenate([massive_x, tracer_x]),
        massive_velocities=jnp.concatenate([massive_v, tracer_v]),
        log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14)
    return reb_res, jorb_res


a, b = _agreement_w_reboundx(n_tracer=100, n_massive=10, seed=0)

2.710505431213761e-20


In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity
from jorbit.accelerations.newtonian import newtonian_gravity


def _gr_agreement_w_reboundx(n_tracer, n_massive, seed):
    np.random.seed(seed)
    n = n_tracer
    m = n_massive
    massive_x = []
    massive_v = []
    ms = []
    sim = rebound.Simulation()
    for i in range(m):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        massive_x.append(xs)
        massive_v.append(vs)
        m = np.random.uniform(0, 1)
        ms.append(m)
        sim.add(m=m, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    tracer_x = []
    tracer_v = []
    for i in range(n):
        xs = np.random.normal(0, 1, 3) * 1000
        vs = np.random.normal(0, 1, 3)
        tracer_x.append(xs)
        tracer_v.append(vs)
        sim.add(m=0.0, x=xs[0], y=xs[1], z=xs[2], vx=vs[0], vy=vs[1], vz=vs[2])
    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = 10
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    tracer_x = jnp.array(tracer_x)
    tracer_v = jnp.array(tracer_v)
    massive_x = jnp.array(massive_x)
    massive_v = jnp.array(massive_v)
    ms = jnp.array(ms)
    s = SystemState(
        tracer_positions=tracer_x,
        tracer_velocities=tracer_v,
        massive_positions=massive_x,
        massive_velocities=massive_v,
        log_gms=jnp.log(ms),
        time=0.0,
        acceleration_func_kwargs={"c2": 100.0},
    )
    # s = SystemState(
    #     tracer_positions=jnp.empty((0, 3)),
    #     tracer_velocities=jnp.empty((0, 3)),
    #     massive_positions=jnp.concatenate([massive_x, tracer_x]),
    #     massive_velocities=jnp.concatenate([massive_v, tracer_v]),
    #     log_gms=jnp.log(jnp.concatenate([ms, jnp.zeros(n)])),
    #     time=0.0,
    #     acceleration_func_kwargs={"c2": 100.0},
    # )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14)


_gr_agreement_w_reboundx(n_tracer=10, n_massive=10, seed=0)

5.293955920339377e-22
