# scratch work

In [1]:
import numpy as np
import math
import sys
from copy import deepcopy
import rebound


def rebx_calculate_gr_full(
    particles, C2, G, max_iterations=10, gravity_ignore_10=False
):
    N = len(particles)
    # Create a copy of particles to avoid modifying the original
    ps_b = deepcopy(particles)

    # Calculate Newtonian accelerations
    for i in range(N):
        ps_b[i].ax = 0.0
        ps_b[i].ay = 0.0
        ps_b[i].az = 0.0

    # Compute initial Newtonian accelerations
    for i in range(N):
        for j in range(i + 1, N):
            dx = ps_b[i].x - ps_b[j].x
            dy = ps_b[i].y - ps_b[j].y
            dz = ps_b[i].z - ps_b[j].z
            r2 = dx * dx + dy * dy + dz * dz
            r = math.sqrt(r2)
            prefac = G / (r2 * r)

            ps_b[i].ax -= prefac * ps_b[j].m * dx
            ps_b[i].ay -= prefac * ps_b[j].m * dy
            ps_b[i].az -= prefac * ps_b[j].m * dz

            ps_b[j].ax += prefac * ps_b[i].m * dx
            ps_b[j].ay += prefac * ps_b[i].m * dy
            ps_b[j].az += prefac * ps_b[i].m * dz

    # Transform to barycentric coordinates (placeholder - would depend on simulation structure)
    # In the original code, this uses reb_simulation_com and reb_particle_isub
    # You might need to implement this based on your specific simulation framework
    _q = rebound.Simulation()
    _q.add(ps_b)
    _q.move_to_com()
    ps_b = list(_q.particles)
    print(f"mass: {ps_b[0].m}")

    # Compute constant acceleration terms
    a_const = np.zeros((N, 3))

    for i in range(N):
        a_constx, a_consty, a_constz = 0.0, 0.0, 0.0

        for j in range(N):
            if j != i:
                dxij = ps_b[i].x - ps_b[j].x
                dyij = ps_b[i].y - ps_b[j].y
                dzij = ps_b[i].z - ps_b[j].z

                rij2 = dxij * dxij + dyij * dyij + dzij * dzij
                rij = math.sqrt(rij2)
                rij3 = rij2 * rij

                # First constant part calculations
                a1 = sum(
                    (4.0 / C2)
                    * G
                    * particles[k].m
                    / math.sqrt(
                        (ps_b[i].x - ps_b[k].x) ** 2
                        + (ps_b[i].y - ps_b[k].y) ** 2
                        + (ps_b[i].z - ps_b[k].z) ** 2
                    )
                    for k in range(N)
                    if k != i
                )

                a2 = sum(
                    (1.0 / C2)
                    * G
                    * particles[l].m
                    / math.sqrt(
                        (ps_b[l].x - ps_b[j].x) ** 2
                        + (ps_b[l].y - ps_b[j].y) ** 2
                        + (ps_b[l].z - ps_b[j].z) ** 2
                    )
                    for l in range(N)
                    if l != j
                )

                vi2 = ps_b[i].vx ** 2 + ps_b[i].vy ** 2 + ps_b[i].vz ** 2
                a3 = -vi2 / C2

                vj2 = ps_b[j].vx ** 2 + ps_b[j].vy ** 2 + ps_b[j].vz ** 2
                a4 = -2 * vj2 / C2

                a5 = (4 / C2) * (
                    ps_b[i].vx * ps_b[j].vx
                    + ps_b[i].vy * ps_b[j].vy
                    + ps_b[i].vz * ps_b[j].vz
                )

                a6_0 = dxij * ps_b[j].vx + dyij * ps_b[j].vy + dzij * ps_b[j].vz
                a6 = (3 / (2 * C2)) * a6_0 * a6_0 / rij2

                a7 = (dxij * ps_b[j].ax + dyij * ps_b[j].ay + dzij * ps_b[j].az) / (
                    2 * C2
                )

                factor1 = a1 + a2 + a3 + a4 + a5 + a6 + a7

                a_constx += G * particles[j].m * dxij * factor1 / rij3
                a_consty += G * particles[j].m * dyij * factor1 / rij3
                a_constz += G * particles[j].m * dzij * factor1 / rij3

                # Second constant part calculations
                dvxij = ps_b[i].vx - ps_b[j].vx
                dvyij = ps_b[i].vy - ps_b[j].vy
                dvzij = ps_b[i].vz - ps_b[j].vz

                factor2 = (
                    dxij * (4 * ps_b[i].vx - 3 * ps_b[j].vx)
                    + dyij * (4 * ps_b[i].vy - 3 * ps_b[j].vy)
                    + dzij * (4 * ps_b[i].vz - 3 * ps_b[j].vz)
                )

                a_constx += (
                    G
                    * particles[j].m
                    / C2
                    * (factor2 * dvxij / rij3 + 7 / 2 * ps_b[j].ax / rij)
                )
                a_consty += (
                    G
                    * particles[j].m
                    / C2
                    * (factor2 * dvyij / rij3 + 7 / 2 * ps_b[j].ay / rij)
                )
                a_constz += (
                    G
                    * particles[j].m
                    / C2
                    * (factor2 * dvzij / rij3 + 7 / 2 * ps_b[j].az / rij)
                )

        a_const[i] = [a_constx, a_consty, a_constz]

    # Set initial accelerations to constant terms
    for i in range(N):
        ps_b[i].ax = a_const[i][0]
        ps_b[i].ay = a_const[i][1]
        ps_b[i].az = a_const[i][2]

    print(ps_b[0].ax)
    # Iterative refinement of accelerations
    for k in range(10):  # Maximum 10 iterations
        # Store old accelerations
        a_old = np.array([[p.ax, p.ay, p.az] for p in ps_b])

        # Add non-constant term
        for i in range(N):
            non_constx, non_consty, non_constz = 0.0, 0.0, 0.0

            for j in range(N):
                if j != i:
                    dxij = ps_b[i].x - ps_b[j].x
                    dyij = ps_b[i].y - ps_b[j].y
                    dzij = ps_b[i].z - ps_b[j].z

                    rij = math.sqrt(dxij * dxij + dyij * dyij + dzij * dzij)
                    rij3 = rij * rij * rij

                    dotproduct = (
                        dxij * ps_b[j].ax + dyij * ps_b[j].ay + dzij * ps_b[j].az
                    )

                    non_constx += (G * particles[j].m * dxij / rij3) * dotproduct / (
                        2 * C2
                    ) + (7 / (2 * C2)) * G * particles[j].m * ps_b[j].ax / rij
                    non_consty += (G * particles[j].m * dyij / rij3) * dotproduct / (
                        2 * C2
                    ) + (7 / (2 * C2)) * G * particles[j].m * ps_b[j].ay / rij
                    non_constz += (G * particles[j].m * dzij / rij3) * dotproduct / (
                        2 * C2
                    ) + (7 / (2 * C2)) * G * particles[j].m * ps_b[j].az / rij

            ps_b[i].ax = a_const[i][0] + non_constx
            ps_b[i].ay = a_const[i][1] + non_consty
            ps_b[i].az = a_const[i][2] + non_constz

        # Check for convergence
        maxdev = 0.0
        for i in range(N):
            dx = (
                abs((ps_b[i].ax - a_old[i][0]) / ps_b[i].ax)
                if abs(ps_b[i].ax) > sys.float_info.epsilon
                else 0
            )
            dy = (
                abs((ps_b[i].ay - a_old[i][1]) / ps_b[i].ay)
                if abs(ps_b[i].ay) > sys.float_info.epsilon
                else 0
            )
            dz = (
                abs((ps_b[i].az - a_old[i][2]) / ps_b[i].az)
                if abs(ps_b[i].az) > sys.float_info.epsilon
                else 0
            )

            maxdev = max(maxdev, dx, dy, dz)

        if maxdev < sys.float_info.epsilon:
            break

        if k == 9:
            print(
                f"Warning: 10 loops in GR calculation did not converge. Fractional Error: {maxdev}"
            )

    # Update original particles with calculated accelerations
    for i in range(N):
        particles[i].ax += ps_b[i].ax
        particles[i].ay += ps_b[i].ay
        particles[i].az += ps_b[i].az

    return particles


sim = rebound.Simulation()
# sim.add(m=1.0, x=-1.0, y=0.01, z=-0.01)
# sim.add(m=1.0, x=1.0, y=-0.01, z=0.01)
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.0, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
sim.integrate(1e-300)
particles = list(sim.particles)

p1, p2 = rebx_calculate_gr_full(
    particles, C2=10**2, G=1.0, max_iterations=10, gravity_ignore_10=False
)
print(p1.ax, p1.ay, p1.az)

mass: 1.0
-0.007200313073905964
0.1694753988552079 -0.08660386274476144 0.004950194457503455


In [2]:
import rebound
import reboundx


sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.0, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
sim.integrate(1e-300)
# print(sim.particles[0].ax, sim.particles[0].ay, sim.particles[1].az)
a0 = sim.particles[0].ax

sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.0, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
rebx = reboundx.Extras(sim)
gr = rebx.load_force("gr_full")
gr.params["c"] = 10
gr.params["max_iterations"] = 100
rebx.add_force(gr)
sim.integrate(1e-300)
a1 = sim.particles[0].ax
print(sim.particles[0].ax, sim.particles[0].ay, sim.particles[1].az)

0.1694753988552079 -0.08660386274476144 -0.004950194457503455


In [4]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def gr_full2(
    x: jnp.ndarray,  # positions (N,3)
    v: jnp.ndarray,  # velocities (N,3)
    a0: jnp.ndarray,  # initial accelerations (N,3)
    m: jnp.ndarray,  # masses (N,)
    C2: float,
    G: float,
    max_iterations: int = 10,
    gravity_ignore_10: bool = False,
) -> jnp.ndarray:
    N = x.shape[0]
    # Calculate pairwise differences
    dx = x[:, None, :] - x[None, :, :]  # (N,N,3)
    r2 = jnp.sum(dx**2, axis=-1)  # (N,N)
    r = jnp.sqrt(r2)  # (N,N)
    r3 = r2 * r  # (N,N)

    # Mask for i!=j calculations
    mask = ~jnp.eye(N, dtype=bool)  # (N,N)

    # Compute initial Newtonian accelerations
    prefac = G / r3
    prefac = jnp.where(mask, prefac, 0.0)
    a_newt = -jnp.sum(prefac[:, :, None] * dx * m[None, :, None], axis=1)  # (N,3)

    # Move to barycentric frame
    x_com = jnp.sum(x * m[:, None], axis=0) / jnp.sum(m)
    v_com = jnp.sum(v * m[:, None], axis=0) / jnp.sum(m)
    x = x - x_com
    v = v - v_com

    # Compute constant acceleration terms
    v2 = jnp.sum(v**2, axis=1)  # (N,)
    vdotv = jnp.dot(v, v.T)  # (N,N)
    dv = v[:, None, :] - v[None, :, :]  # (N,N,3)
    rdotv = jnp.sum(dx * v[None, :, :], axis=-1)  # (N,N)

    # First constant part calculations
    a1 = (4.0 / C2) * G * jnp.sum(m[None, :] / r, axis=1, where=mask)  # (N,)
    a2 = (1.0 / C2) * G * jnp.sum(m[None, :] / r, axis=1, where=mask)  # (N,)
    a3 = -v2 / C2  # (N,)
    a4 = -2 * v2[None, :] / C2  # (N,N)
    a5 = (4 / C2) * vdotv  # (N,N)
    a6 = (3 / (2 * C2)) * (rdotv**2 / r2)  # (N,N)
    a7 = jnp.sum(dx * a0[None, :, :], axis=-1) / (2 * C2)  # (N,N)

    # Combine all factors
    factor1 = (a1 + a2 + a3)[:, None] + jnp.where(mask, a4 + a5 + a6 + a7, 0.0)  # (N,N)

    # Calculate first part of a_const
    a_const = jnp.sum(
        G * m[None, :, None] * dx * (factor1[:, :, None]) / r3[:, :, None],
        axis=1,
        where=mask[:, :, None],
    )  # (N,3)

    # Second constant part
    factor2 = jnp.sum(dx * (4 * v[:, None, :] - 3 * v[None, :, :]), axis=-1)  # (N,N)

    # Add second part to a_const
    a_const += jnp.sum(
        (G * m[None, :, None] / C2)
        * (
            (factor2[:, :, None] * dv / r3[:, :, None])
            + (7 / 2 * a0[None, :, :] / r[:, :, None])
        ),
        axis=1,
        where=mask[:, :, None],
    )

    def iteration_step(a_curr):
        rdota = jnp.sum(dx * a_curr[None, :, :], axis=-1)  # (N,N)
        non_const = jnp.sum(
            (G * m[None, :, None] / (2 * C2))
            * (
                (dx * rdota[:, :, None] / r3[:, :, None])
                + (7 * a_curr[None, :, :] / r[:, :, None])
            ),
            axis=1,
            where=mask[:, :, None],
        )
        return a_const + non_const

    def do_nothing(carry):
        return carry

    def do_iteration(carry):
        a_prev, a_curr, _ = carry
        a_next = iteration_step(a_curr)
        ratio = jnp.max(jnp.abs((a_next - a_curr) / a_next))
        return (a_curr, a_next, ratio)

    def body_fn(carry, _):
        a_prev, a_curr, ratio = carry

        # Use cond to either continue iteration or return current state
        should_continue = ratio > jnp.finfo(jnp.float64).eps
        new_carry = jax.lax.cond(should_continue, do_iteration, do_nothing, carry)

        return new_carry, None

    # Initialize with constant terms
    init_a = jnp.zeros_like(a_const)
    init_carry = (init_a, a_const, 1.0)

    # Run fixed number of iterations using scan
    final_carry, _ = jax.lax.scan(body_fn, init_carry, None, length=max_iterations)

    # Extract final acceleration
    _, a_final, _ = final_carry

    return a_newt + a_final


import rebound

sim = rebound.Simulation()
sim.add(m=1.0, x=-1.0, y=1.01, z=-0.05, vx=0.01, vy=0.01, vz=0.1)
sim.add(m=1.0, x=1.0, y=-0.01, z=0.01, vx=-0.21, vy=-0.01, vz=-0.01)
sim.integrate(1e-300)

x0 = jnp.array([[p.x, p.y, p.z] for p in sim.particles])
v0 = jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles])
a0 = jnp.array([[p.ax, p.ay, p.az] for p in sim.particles])

float(
    gr_full2(
        x=x0,
        v=v0,
        a0=a0,
        m=jnp.array([1.0, 1.0]),
        C2=10**2,
        G=1.0,
        max_iterations=10,
        gravity_ignore_10=False,
    )[0][1]
)

-0.08660386274476145