# scratch work

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jplephem.spk import SPK
import astropy.units as u
from astropy.time import Time
from astropy.utils.data import download_file

In [2]:
from jorbit.ephemeris import Ephemeris

In [4]:
import astropy.constants as const
import astropy.units as u

const.G.to(u.au**3 / u.Msun / u.day**2).value

np.float64(0.00029591220819207774)

In [2]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def gr_full2(
    x: jnp.ndarray,  # positions (N,3)
    v: jnp.ndarray,  # velocities (N,3)
    a0: jnp.ndarray,  # initial accelerations (N,3)
    m: jnp.ndarray,  # masses (N,)
    C2: float,
    G: float,
    max_iterations: int = 10,
    gravity_ignore_10: bool = False,
) -> jnp.ndarray:
    N = x.shape[0]
    # Calculate pairwise differences
    dx = x[:, None, :] - x[None, :, :]  # (N,N,3)
    r2 = jnp.sum(dx**2, axis=-1)  # (N,N)
    r = jnp.sqrt(r2)  # (N,N)
    r3 = r2 * r  # (N,N)

    # Mask for i!=j calculations
    mask = ~jnp.eye(N, dtype=bool)  # (N,N)

    # Compute initial Newtonian accelerations
    prefac = G / r3
    prefac = jnp.where(mask, prefac, 0.0)
    a_newt = -jnp.sum(prefac[:, :, None] * dx * m[None, :, None], axis=1)  # (N,3)

    # Move to barycentric frame
    x_com = jnp.sum(x * m[:, None], axis=0) / jnp.sum(m)
    v_com = jnp.sum(v * m[:, None], axis=0) / jnp.sum(m)
    x = x - x_com
    v = v - v_com

    # Compute constant acceleration terms
    v2 = jnp.sum(v**2, axis=1)  # (N,)
    vdotv = jnp.dot(v, v.T)  # (N,N)
    dv = v[:, None, :] - v[None, :, :]  # (N,N,3)
    rdotv = jnp.sum(dx * v[None, :, :], axis=-1)  # (N,N)

    # First constant part calculations
    a1 = (4.0 / C2) * G * jnp.sum(m[None, :] / r, axis=1, where=mask)  # (N,)
    a2 = (1.0 / C2) * G * jnp.sum(m[None, :] / r, axis=1, where=mask)  # (N,)
    a3 = -v2 / C2  # (N,)
    a4 = -2 * v2[None, :] / C2  # (N,N)
    a5 = (4 / C2) * vdotv  # (N,N)
    a6 = (3 / (2 * C2)) * (rdotv**2 / r2)  # (N,N)
    a7 = jnp.sum(dx * a0[None, :, :], axis=-1) / (2 * C2)  # (N,N)

    # Combine all factors
    factor1 = (a1 + a2 + a3)[:, None] + jnp.where(mask, a4 + a5 + a6 + a7, 0.0)  # (N,N)

    # Calculate first part of a_const
    a_const = jnp.sum(
        G * m[None, :, None] * dx * (factor1[:, :, None]) / r3[:, :, None],
        axis=1,
        where=mask[:, :, None],
    )  # (N,3)

    # Second constant part
    factor2 = jnp.sum(dx * (4 * v[:, None, :] - 3 * v[None, :, :]), axis=-1)  # (N,N)

    # Add second part to a_const
    a_const += jnp.sum(
        (G * m[None, :, None] / C2)
        * (
            (factor2[:, :, None] * dv / r3[:, :, None])
            + (7 / 2 * a0[None, :, :] / r[:, :, None])
        ),
        axis=1,
        where=mask[:, :, None],
    )

    def iteration_step(a_curr):
        rdota = jnp.sum(dx * a_curr[None, :, :], axis=-1)  # (N,N)
        non_const = jnp.sum(
            (G * m[None, :, None] / (2 * C2))
            * (
                (dx * rdota[:, :, None] / r3[:, :, None])
                + (7 * a_curr[None, :, :] / r[:, :, None])
            ),
            axis=1,
            where=mask[:, :, None],
        )
        return a_const + non_const

    def do_nothing(carry):
        return carry

    def do_iteration(carry):
        a_prev, a_curr, _ = carry
        a_next = iteration_step(a_curr)
        ratio = jnp.max(jnp.abs((a_next - a_curr) / a_next))
        return (a_curr, a_next, ratio)

    def body_fn(carry, _):
        a_prev, a_curr, ratio = carry

        # Use cond to either continue iteration or return current state
        should_continue = ratio > jnp.finfo(jnp.float64).eps
        new_carry = jax.lax.cond(should_continue, do_iteration, do_nothing, carry)

        return new_carry, None

    # Initialize with constant terms
    init_a = jnp.zeros_like(a_const)
    init_carry = (init_a, a_const, 1.0)

    # Run fixed number of iterations using scan
    final_carry, _ = jax.lax.scan(body_fn, init_carry, None, length=max_iterations)

    # Extract final acceleration
    _, a_final, _ = final_carry

    return a_newt + a_final


import rebound

sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.3, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
sim.integrate(1e-300)

x0 = jnp.array([[p.x, p.y, p.z] for p in sim.particles])
v0 = jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles])
a0 = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

for i in range(3):
    print(
        float(
            gr_full2(
                x=x0,
                v=v0,
                a0=a0,
                m=jnp.array([1.0, 1.3]),
                C2=10**2,
                G=1.0,
                max_iterations=10,
                gravity_ignore_10=False,
            )[1][i]
        )
    )

-0.16860225827621192
0.08615444892210075
-0.004927216533403929


In [3]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import rebound
from jorbit.accelerations.gr import gr_full

sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.3, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
sim.integrate(1e-300)
# sim.move_to_com()

x0 = jnp.array([[p.x, p.y, p.z] for p in sim.particles])
v0 = jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles])
a0 = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

tmp = gr_full(
    x=x0,
    v=v0,
    gms=jnp.array([1.0, 1.3]),
    max_iterations=20,
)[
    0
][1]

for i in range(3):
    print(float(tmp[i]))

doing iteration
doing iteration
doing iteration
doing iteration
doing iteration
doing iteration
doing iteration
doing iteration
doing iteration
doing iteration
-0.16860225827621192
0.08615444892210075
-0.004927216533403929


In [6]:
-0.004919974804683859 / -0.004927216533403929

0.9985302596971383

In [7]:
-0.16836086731887628 / -0.16860225827621192

0.9985682815888494

In [4]:
import rebound
import reboundx

sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.3, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
sim.move_to_com()
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
gr.params["c"] = 10
gr.params["max_iterations"] = 100
rebx.add_force(gr)
sim.integrate(1e-300)
a1 = sim.particles[0].ax
print(sim.particles[1].ax)
print(sim.particles[1].ay)
print(sim.particles[1].az)

-0.16836086731887628
0.08603133953385958
-0.004919974804683859


In [26]:
sim.G

1.0

In [None]:
sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.0, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
sim.move_to_com()
x = jnp.array([[p.x, p.y, p.z] for p in sim.particles])
v = jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles])
gms = jnp.array([p.m for p in sim.particles])
# print(sim.particles

Array([-0.07623188, -0.00086957,  0.02521739], dtype=float64)

In [23]:
x_com

Array([ 0.08695652,  0.28898551, -0.01072464], dtype=float64)

In [4]:
q = jnp.array(
    [
        0.21879616777694208,
        -0.11181317786093832,
        0.006386233904172433,
    ]
) - jnp.array(
    [
        0.21910833725420858,
        -0.11197238429434422,
        -0.00491997480468386,
    ]
)
q / jnp.linalg.norm(q)

Array([-0.0275972 ,  0.01407457,  0.99952004], dtype=float64)