# Scratch work: comparing jorbit's `ppn_gravity` accelerations with REBOUNDx `gr_full`

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.jplhorizons import Horizons

from jorbit.utils.horizons import (
    horizons_bulk_vector_query,
    horizons_bulk_astrometry_query,
)

# Reference epoch, UTC scale. NOTE(review): not used by any of the cells shown here.
t0 = Time(val="2024-12-24T00:00:00", scale="utc")

In [2]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import rebound
import reboundx

# Small four-body test system: two O(1) masses plus two 0.01-mass bodies.
sim = rebound.Simulation()
for _body in (
    dict(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.02, vz=0.4),
    dict(m=1.3, x=1.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6),
    dict(m=0.01, x=10.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6),
    dict(m=0.01, x=10.0, y=-2.01, z=0.4, vx=-0.02, vy=-0.05, vz=0.6),
):
    sim.add(**_body)

# REBOUNDx "gr_full" force. c = 10 is presumably chosen small so the
# relativistic terms are large enough to compare; jorbit must get c2 = c**2.
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
gr.params["c"] = 10
gr.params["max_iterations"] = 100
rebx.add_force(gr)

# A negligible step so forces get evaluated and p.ax/p.ay/p.az are populated
# while positions are effectively unchanged.
sim.integrate(1e-300)

reb_res = jnp.array([(p.ax, p.ay, p.az) for p in sim.particles])
reb_res

Array([[ 0.05419503, -0.08164417,  0.00249506],
       [-0.04134904,  0.06280272, -0.00191058],
       [-0.02331262,  0.00210376,  0.07367954],
       [-0.02348517,  0.00205179, -0.07450105]], dtype=float64)

In [3]:
from jorbit.accelerations.gr import ppn_gravity
from jorbit.utils.states import SystemState

# Rebuild the same four bodies *without* REBOUNDx, just to read off the
# initial positions, velocities and masses to hand to jorbit.
sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.02, vz=0.4)
sim.add(m=1.3, x=1.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6)
sim.add(m=0.01, x=10.0, y=-2.01, z=0.04, vx=-0.02, vy=-0.05, vz=-0.6)
sim.add(m=0.01, x=10.0, y=-2.01, z=0.4, vx=-0.02, vy=-0.05, vz=0.6)
sim.integrate(1e-300)
x0 = jnp.array([[p.x, p.y, p.z] for p in sim.particles])
v0 = jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles])
gms = jnp.array([p.m for p in sim.particles])


s = SystemState(
    tracer_positions=jnp.empty((0, 3)),
    tracer_velocities=jnp.empty((0, 3)),
    massive_positions=x0,
    massive_velocities=v0,
    log_gms=jnp.log(gms),
    time=0.0,
    # Must equal c**2 for the c = 10 given to REBOUNDx in the previous cell.
    acceleration_func_kwargs={"c2": 100.0},
)

jorb_res = ppn_gravity(s)

# Same bodies and same c in both codes: they should agree to round-off
# (recorded max |difference| ~4e-17), so fail loudly instead of eyeballing.
assert jnp.allclose(jorb_res, reb_res, atol=1e-14, rtol=1e-14), (
    "jorbit ppn_gravity disagrees with REBOUNDx gr_full on the 4-body system"
)

reb_res - jorb_res

Array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 3.46944695e-18,  0.00000000e+00, -4.16333634e-17],
       [ 3.46944695e-18, -4.33680869e-19,  2.77555756e-17]],      dtype=float64)

In [4]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed, c=10.0, atol=None, rtol=1e-14):
    """Compare jorbit's ``ppn_gravity`` with REBOUNDx ``gr_full`` on a random system.

    Draws ``n_massive`` massive bodies and ``n_tracer`` massless bodies and adds
    them all to one REBOUND simulation, massive bodies first, with the
    ``gr_full`` force. It then evaluates jorbit's ``ppn_gravity`` on the same
    bodies, passing the massless ones to jorbit as *tracers*.

    In the recorded run (``n_tracer=100, n_massive=10, seed=0``) the massive rows
    agreed to ~1e-22, while the tracer rows differed by up to ~7e-8.
    TODO(review): investigate the tracer-path discrepancy.

    Parameters
    ----------
    n_tracer : int
        Number of massless bodies (passed to jorbit as tracers).
    n_massive : int
        Number of massive bodies. Masses are drawn from ``U(0, 1)``, given to
        REBOUND as ``m`` and to jorbit as ``log_gms = log(m)``.
    seed : int
        Seed for NumPy's legacy *global* RNG. Calling this reseeds it.
    c : float, optional
        Speed of light given to REBOUNDx. jorbit receives ``c2 = c**2``, so the
        two codes cannot be configured inconsistently. The default ``10.0``
        matches the previously hard-coded ``c = 10`` / ``c2 = 100.0``.
    atol : float or None, optional
        If not ``None``, raise ``AssertionError`` unless the results agree within
        ``atol``/``rtol`` (``numpy.testing.assert_allclose``). With the default
        ``None``, only the maximum absolute difference is printed.
    rtol : float, optional
        Relative tolerance used together with ``atol``. Default ``1e-14``.

    Returns
    -------
    reb_res, jorb_res : jax.Array
        REBOUNDx accelerations, with rows in REBOUND particle order (massive,
        then massless), and jorbit's accelerations (presumably in the same
        order; TODO confirm against ``ppn_gravity``).
    """
    np.random.seed(seed)

    def _draw_state():
        # position ~ 1000 * N(0, 1), velocity ~ N(0, 1). The draw order is part
        # of the seeded stream, so keep it: position, velocity (then mass).
        pos = np.random.normal(0, 1, 3) * 1000
        vel = np.random.normal(0, 1, 3)
        return pos, vel

    def _as_rows(vectors):
        # reshape(-1, 3) keeps an empty list as shape (0, 3) instead of (0,).
        return jnp.asarray(np.asarray(vectors, dtype=float).reshape(-1, 3))

    sim = rebound.Simulation()

    massive_x, massive_v, ms = [], [], []
    for _ in range(n_massive):
        pos, vel = _draw_state()
        mass = np.random.uniform(0, 1)
        massive_x.append(pos)
        massive_v.append(vel)
        ms.append(mass)
        sim.add(m=mass, x=pos[0], y=pos[1], z=pos[2], vx=vel[0], vy=vel[1], vz=vel[2])

    tracer_x, tracer_v = [], []
    for _ in range(n_tracer):
        pos, vel = _draw_state()
        tracer_x.append(pos)
        tracer_v.append(vel)
        sim.add(m=0.0, x=pos[0], y=pos[1], z=pos[2], vx=vel[0], vy=vel[1], vz=vel[2])

    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = c
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    # Negligible step so forces are evaluated and p.ax/p.ay/p.az are populated.
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    s = SystemState(
        tracer_positions=_as_rows(tracer_x),
        tracer_velocities=_as_rows(tracer_v),
        massive_positions=_as_rows(massive_x),
        massive_velocities=_as_rows(massive_v),
        log_gms=jnp.log(jnp.asarray(ms, dtype=float)),
        time=0.0,
        acceleration_func_kwargs={"c2": float(c) ** 2},
    )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    if atol is not None:
        np.testing.assert_allclose(
            np.asarray(jorb_res), np.asarray(reb_res), rtol=rtol, atol=atol
        )
    return reb_res, jorb_res


# 10 massive bodies + 100 tracers. a = REBOUNDx accelerations, b = jorbit's.
a, b = _agreement_w_reboundx(seed=0, n_massive=10, n_tracer=100)

7.147368626309915e-08


In [5]:
# First 10 rows: the n_massive=10 massive bodies (added to REBOUND first).
a[:10] - b[:10]

Array([[ 0.00000000e+00,  7.94093388e-23, -5.29395592e-23],
       [ 0.00000000e+00,  3.97046694e-23, -5.29395592e-23],
       [-5.29395592e-23,  0.00000000e+00,  1.05879118e-22],
       [ 2.64697796e-23,  0.00000000e+00,  7.27918939e-23],
       [ 5.29395592e-23, -3.30872245e-23,  1.05879118e-22],
       [ 0.00000000e+00,  4.63221143e-23,  3.97046694e-23],
       [ 0.00000000e+00, -1.05879118e-22,  2.11758237e-22],
       [ 0.00000000e+00,  0.00000000e+00, -2.64697796e-23],
       [-9.26442286e-23, -5.29395592e-23, -2.11758237e-22],
       [ 5.29395592e-23,  0.00000000e+00, -5.29395592e-23]],      dtype=float64)

In [6]:
# Remaining rows: the massless bodies (jorbit tracers in this variant).
a[10:] - b[10:]

Array([[-8.53144643e-11, -4.18565020e-11,  1.11319405e-10],
       [ 1.53126240e-10, -1.50207867e-10, -1.24978430e-10],
       [-1.09300818e-10,  3.45658866e-11,  1.32041505e-10],
       [ 1.15761317e-10, -4.90254912e-11, -8.82760062e-11],
       [ 9.72647830e-11,  1.51501252e-10,  2.18358760e-10],
       [-4.46705226e-11,  3.07411021e-11,  5.12492626e-11],
       [-1.08323697e-10,  1.12002051e-10, -1.40370236e-11],
       [ 6.53541058e-11,  3.04522994e-11, -4.29140457e-11],
       [-9.63718411e-11, -4.40308553e-11,  9.77835546e-11],
       [ 2.72747710e-10,  1.57570428e-10, -1.68741809e-10],
       [ 5.42339883e-10,  4.41772073e-11, -3.59055280e-12],
       [-1.09678375e-10,  1.83371150e-11,  2.09156781e-10],
       [-1.94715152e-10, -4.58208335e-11,  1.48004731e-10],
       [-9.05542892e-11,  7.51897235e-11, -1.46039657e-10],
       [-1.28251309e-10, -1.17008323e-10,  1.05698839e-10],
       [-1.57500340e-10, -9.70675728e-11,  2.03539243e-11],
       [-1.69858924e-10,  7.75434792e-11

In [7]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
import rebound
import reboundx

from jorbit.utils.states import SystemState
from jorbit.accelerations.gr import ppn_gravity


def _agreement_w_reboundx(n_tracer, n_massive, seed, c=10.0, atol=None, rtol=1e-14):
    """Compare ``ppn_gravity`` with REBOUNDx ``gr_full``, massless bodies as massive.

    This variant of the comparison passes the massless bodies to jorbit as
    zero-mass *massive* particles instead of tracers. jorbit's tracer arrays
    are empty. The REBOUND side is unchanged: all bodies, massive first, with
    the ``gr_full`` force.

    In the recorded run (``n_tracer=100, n_massive=10, seed=0``) the maximum
    absolute difference was ~2.7e-20.

    Parameters
    ----------
    n_tracer : int
        Number of massless bodies (given zero mass in jorbit's massive arrays).
    n_massive : int
        Number of massive bodies. Masses are drawn from ``U(0, 1)``.
    seed : int
        Seed for NumPy's legacy *global* RNG. Calling this reseeds it.
    c : float, optional
        Speed of light given to REBOUNDx. jorbit receives ``c2 = c**2``. The
        default ``10.0`` matches the previously hard-coded ``c = 10`` /
        ``c2 = 100.0``.
    atol : float or None, optional
        If not ``None``, raise ``AssertionError`` unless the results agree within
        ``atol``/``rtol`` (``numpy.testing.assert_allclose``). With the default
        ``None``, only the maximum absolute difference is printed.
    rtol : float, optional
        Relative tolerance used together with ``atol``. Default ``1e-14``.

    Returns
    -------
    reb_res, jorb_res : jax.Array
        REBOUNDx accelerations and jorbit's accelerations. Both have rows in the
        order massive bodies, then massless bodies (jorbit's order presumably
        follows its massive arrays; TODO confirm).
    """
    np.random.seed(seed)

    def _draw_state():
        # position ~ 1000 * N(0, 1), velocity ~ N(0, 1). The draw order is part
        # of the seeded stream, so keep it: position, velocity (then mass).
        pos = np.random.normal(0, 1, 3) * 1000
        vel = np.random.normal(0, 1, 3)
        return pos, vel

    def _as_rows(vectors):
        # reshape(-1, 3) keeps an empty list as shape (0, 3), so the
        # concatenations below also work when n_tracer or n_massive is 0.
        return jnp.asarray(np.asarray(vectors, dtype=float).reshape(-1, 3))

    sim = rebound.Simulation()

    massive_x, massive_v, ms = [], [], []
    for _ in range(n_massive):
        pos, vel = _draw_state()
        mass = np.random.uniform(0, 1)
        massive_x.append(pos)
        massive_v.append(vel)
        ms.append(mass)
        sim.add(m=mass, x=pos[0], y=pos[1], z=pos[2], vx=vel[0], vy=vel[1], vz=vel[2])

    tracer_x, tracer_v = [], []
    for _ in range(n_tracer):
        pos, vel = _draw_state()
        tracer_x.append(pos)
        tracer_v.append(vel)
        sim.add(m=0.0, x=pos[0], y=pos[1], z=pos[2], vx=vel[0], vy=vel[1], vz=vel[2])

    rebx = reboundx.Extras(sim)
    gr = rebx.load_force("gr_full")
    gr.params["c"] = c
    gr.params["max_iterations"] = 100
    rebx.add_force(gr)
    # Negligible step so forces are evaluated and p.ax/p.ay/p.az are populated.
    sim.integrate(1e-300)
    reb_res = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

    # The massless bodies get log_gm = log(0) = -inf. This presumably relies on
    # ppn_gravity using exp(log_gms), so their GM is exactly 0 (TODO confirm).
    all_gms = jnp.concatenate([jnp.asarray(ms, dtype=float), jnp.zeros(n_tracer)])
    s = SystemState(
        tracer_positions=jnp.empty((0, 3)),
        tracer_velocities=jnp.empty((0, 3)),
        massive_positions=jnp.concatenate([_as_rows(massive_x), _as_rows(tracer_x)]),
        massive_velocities=jnp.concatenate([_as_rows(massive_v), _as_rows(tracer_v)]),
        log_gms=jnp.log(all_gms),
        time=0.0,
        acceleration_func_kwargs={"c2": float(c) ** 2},
    )
    jorb_res = ppn_gravity(s)

    print(jnp.max(jnp.abs(jorb_res - reb_res)))
    if atol is not None:
        np.testing.assert_allclose(
            np.asarray(jorb_res), np.asarray(reb_res), rtol=rtol, atol=atol
        )
    return reb_res, jorb_res


# Same system as before, with massless bodies passed to jorbit as zero-mass massive ones.
a, b = _agreement_w_reboundx(seed=0, n_massive=10, n_tracer=100)

2.710505431213761e-20
