In [None]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit.data.constants import IAS15_BV_DENOMS, IAS15_BX_DENOMS
from jorbit.utils.doubledouble import DoubleDouble


In [None]:
def _estimate_x_v_from_b(
    a0: DoubleDouble,
    v0: DoubleDouble,
    x0: DoubleDouble,
    dt: DoubleDouble,
    b_x_denoms: DoubleDouble,
    b_v_denoms: DoubleDouble,
    h: DoubleDouble,
    bp: DoubleDouble,
) -> tuple:
    """Given the b coefficients, estimate the new position and velocity.

    Position and velocity are each written as a polynomial in ``h``. Both are
    evaluated with Horner's scheme (``jax.lax.scan``) in DoubleDouble
    arithmetic:

        x(h) = x0 + v0*dt*h + (a0*dt**2/2)*h**2
               + sum_k bp[k] * dt**2 / b_x_denoms[k] * h**(k+3)
        v(h) = v0 + a0*dt*h
               + sum_k bp[k] * dt / b_v_denoms[k] * h**(k+2)

    The number of b coefficients is read from ``bp`` rather than hard-coded.
    This keeps the function shape-agnostic, so higher- or lower-order schemes
    can reuse it.

    Args:
        a0 (DoubleDouble):
            The initial acceleration.
        v0 (DoubleDouble):
            The initial velocity.
        x0 (DoubleDouble):
            The initial position.
        dt (DoubleDouble):
            The timestep.
        b_x_denoms (DoubleDouble):
            The denominators for the Taylor expansion of the position, one per
            entry along the first axis of ``bp``.
        b_v_denoms (DoubleDouble):
            The denominators for the Taylor expansion of the velocity, one per
            entry along the first axis of ``bp``.
        h (DoubleDouble):
            The point at which both polynomials are evaluated (an entry of the
            H array from REBOUND; presumably a single substep node, TODO
            confirm). It multiplies the per-particle Horner accumulator, so it
            must broadcast to shape ``bp.hi.shape[1:]`` without enlarging it.
        bp (DoubleDouble):
            The b coefficients. This is NOT an IAS15Helper. It is a plain
            DoubleDouble of shape (n_internal_points, n_particles, 3), stored
            lowest order first (``bp[0]`` multiplies ``h**3`` in x).

    Returns:
        tuple:
            ``(estimated_x, estimated_v)``, each a DoubleDouble of shape
            ``bp.hi.shape[1:]``.
    """
    n_b = bp.hi.shape[0]
    per_particle_shape = (bp.hi.shape[1], bp.hi.shape[2])

    def _horner(coeffs: DoubleDouble) -> DoubleDouble:
        # Evaluate sum_j coeffs[j] * h**j. ``coeffs`` is stored lowest order
        # first, so it is reversed to let the scan start at the highest order.
        acc0 = DoubleDouble(jnp.zeros(coeffs.hi.shape[1:]))
        value, _ = jax.lax.scan(
            lambda acc, c: (acc * h + c, None), acc0, coeffs[::-1]
        )
        return value

    # Position coefficients, lowest order first:
    # [x0, v0*dt, a0*dt^2/2, bp[k]*dt^2/b_x_denoms[k] ...]
    x_coeffs = DoubleDouble(jnp.zeros((n_b + 3,) + per_particle_shape))
    x_coeffs[0] = x0
    x_coeffs[1] = v0 * dt
    x_coeffs[2] = a0 * dt * dt / DoubleDouble(2.0)
    x_coeffs[3:] = bp * dt * dt / b_x_denoms[:, None, None]

    # Velocity coefficients, lowest order first:
    # [v0, a0*dt, bp[k]*dt/b_v_denoms[k] ...]
    v_coeffs = DoubleDouble(jnp.zeros((n_b + 2,) + per_particle_shape))
    v_coeffs[0] = v0
    v_coeffs[1] = a0 * dt
    v_coeffs[2:] = bp * dt / b_v_denoms[:, None, None]

    return _horner(x_coeffs), _horner(v_coeffs)


In [35]:
# Check that two Horner evaluations of the velocity series agree:
#   (1) scanning the full coefficient list [v0, a0*dt, bp*dt/denoms], and
#   (2) scanning only the bp terms, then adding the low-order terms by hand.
bp = jax.random.normal(jax.random.PRNGKey(0), (7, 2, 3))
a0 = jax.random.normal(jax.random.PRNGKey(1), (2, 3))
v0 = jax.random.normal(jax.random.PRNGKey(2), (2, 3))
x0 = jax.random.normal(jax.random.PRNGKey(3), (2, 3))
dt = 0.01
h = 0.123


def _horner_step(acc, coeff):
    """Single Horner step for ``jax.lax.scan``: ``acc -> acc * h + coeff``."""
    return acc * h + coeff, None


# (1) Coefficients lowest order first, reversed so the scan starts at the top.
full_terms = jnp.concatenate(
    [v0[None], (a0 * dt)[None], bp * dt / IAS15_BV_DENOMS[:, None, None]],
    axis=0,
)
estimated_v, _ = jax.lax.scan(_horner_step, jnp.zeros_like(x0), full_terms[::-1])

# (2) Only the b terms; shift by h**2 and add v0 + a0*dt*h afterwards.
vcoeffs = (bp * dt / IAS15_BV_DENOMS[:, None, None])[::-1]
b_part, _ = jax.lax.scan(_horner_step, jnp.zeros_like(x0), vcoeffs)
estimated_v2 = b_part * (h * h) + (v0 + (a0 * dt) * h)

jnp.allclose(estimated_v, estimated_v2)

Array(True, dtype=bool)

In [None]:
def _estimate_x_v_from_b(
    a0: jnp.ndarray,
    v0: jnp.ndarray,
    x0: jnp.ndarray,
    dt: jnp.ndarray,
    bp: jnp.ndarray, # remember to flip it!
) -> tuple[jnp.ndarray, jnp.ndarray]:
    xcoeffs = bp * dt * dt / IAS15_BX_DENOMS
    x, _ = jax.lax.scan(lambda y, _p: (y * h + _p, None), jnp.zeros_like(x0), xcoeffs)
    x *= h*h*h
    x += (v0 * dt) * h + (a0 * dt * dt / 2.0) * h * h + x0

    vcoeffs = bp * dt / IAS15_BV_DENOMS
    v, _ = jax.lax.scan(lambda y, _p: (y * h + _p, None), jnp.zeros_like(x0), vcoeffs)
    v *= h*h
    v += v0  + (a0 * dt ) * h
    return x, v

In [6]:
n_internal_points = 7
# Orders k = 1..n_internal_points: position denominators are (k+1)(k+2),
# velocity denominators are k+1.
orders = jnp.arange(1, n_internal_points + 1, dtype=jnp.float64)
b_x_denoms = (orders + 1.0) * (orders + 2.0)
b_v_denoms = orders + 1.0

In [7]:
# Display the position denominators; expected (k+1)(k+2) for k = 1..7.
b_x_denoms

Array([ 6., 12., 20., 30., 42., 56., 72.], dtype=float64)