# scratch work

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from functools import partial


from decimal import Decimal, getcontext

getcontext().prec = 50


import mpmath

mpmath.mp.dps = 75

from jorbit.utils.doubledouble import DoubleDouble, dd_sum
from jorbit.utils.generate_coefficients import create_iasnn_constants
from jorbit.data.constants import IAS15_D, IAS15_H, IAS15_RR
from jorbit.integrators.ias15_dd import setup_iasnn_integrator, step

In [2]:
bx_denoms, bv_denoms, h, r, d = setup_iasnn_integrator(n_internal_points=7)

100%|██████████| 998/998 [00:00<00:00, 1334.77it/s]


In [None]:
@partial(jax.jit, static_argnums=(0,))
def refine_intermediate_g(substep_num, g, r, at, a0):
    # substep_num starts at 1, 1->h1, etc
    substep_num -= 1

    def scan_body(carry, idx):
        result, start_pos = carry
        result = (result - g[idx]) * r[start_pos + idx + 1]
        return (result, start_pos), result

    start_pos = (substep_num * (substep_num + 1)) // 2
    initial_result = (at - a0) * r[start_pos]
    indices = jnp.arange(substep_num)
    (final_result, _), _ = jax.lax.scan(scan_body, (initial_result, start_pos), indices)
    return final_result


refine_intermediate_g(
    substep_num=3,
    g=DoubleDouble(jnp.repeat(jnp.arange(7, dtype=jnp.float64), 12).reshape(7, 4, 3)),
    r=r,
    at=DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.4),
    a0=DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.2),
)

DoubleDouble([[5.30090278 5.30090278 5.30090278]
 [5.30090278 5.30090278 5.30090278]
 [5.30090278 5.30090278 5.30090278]
 [5.30090278 5.30090278 5.30090278]], [[3.54290424e-16 3.54290424e-16 3.54290424e-16]
 [3.54290424e-16 3.54290424e-16 3.54290424e-16]
 [3.54290424e-16 3.54290424e-16 3.54290424e-16]
 [3.54290424e-16 3.54290424e-16 3.54290424e-16]])

In [25]:
at = 0.4
a0 = 0.2
g = jnp.arange(7)
(((at - a0) / IAS15_RR[3] - g[0]) / IAS15_RR[4] - g[1]) / IAS15_RR[5]

Array(5.30090278, dtype=float64)

In [29]:
# def refine_b_and_g(b, g, at, a0, substep_num, r, d):
#     old_g = g
#     new_g = refine_intermediate_g(substep_num=substep_num, g=g, r=r, at=at, a0=a0)
#     g_diff = new_g - old_g


# refine_b_and_g(
#     b=DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.4),
#     g=DoubleDouble(jnp.repeat(jnp.arange(7, dtype=jnp.float64), 12).reshape(7, 4, 3)),
#     at=DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.4),
#     a0=DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.2),
#     substep_num=3,
#     r=r,
#     d=d,
# )

DoubleDouble([[5.30090278 5.30090278 5.30090278]
 [5.30090278 5.30090278 5.30090278]
 [5.30090278 5.30090278 5.30090278]
 [5.30090278 5.30090278 5.30090278]], [[3.54290424e-16 3.54290424e-16 3.54290424e-16]
 [3.54290424e-16 3.54290424e-16 3.54290424e-16]
 [3.54290424e-16 3.54290424e-16 3.54290424e-16]
 [3.54290424e-16 3.54290424e-16 3.54290424e-16]])

6

In [2]:
precompued_setup = setup_iasnn_integrator(n_internal_evals=7)

x0 = DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.123)
v0 = DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.2)
a0 = DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.1)
b = DoubleDouble(jnp.zeros((7, 4, 3), dtype=jnp.float64))

b[0] += DoubleDouble(1.0)
b[1] += DoubleDouble(2.0)
b[2] += DoubleDouble(3.0)
b[3] += DoubleDouble(4.0)
b[4] += DoubleDouble(5.0)
b[5] += DoubleDouble(6.0)
b[6] += DoubleDouble(7.0)


g_dd, x_est, v_est = step(x0, v0, a0, b, precompued_setup)

100%|██████████| 998/998 [00:00<00:00, 1309.80it/s]


In [3]:
from jorbit.integrators import initialize_ias15_helper

b = initialize_ias15_helper(4)
g = initialize_ias15_helper(4)

b.p0 += 1.0
b.p1 += 2.0
b.p2 += 3.0
b.p3 += 4.0
b.p4 += 5.0
b.p5 += 6.0
b.p6 += 7.0


g.p0 = (
    b.p6 * IAS15_D[15]
    + b.p5 * IAS15_D[10]
    + b.p4 * IAS15_D[6]
    + b.p3 * IAS15_D[3]
    + b.p2 * IAS15_D[1]
    + b.p1 * IAS15_D[0]
    + b.p0
)
g.p1 = (
    b.p6 * IAS15_D[16]
    + b.p5 * IAS15_D[11]
    + b.p4 * IAS15_D[7]
    + b.p3 * IAS15_D[4]
    + b.p2 * IAS15_D[2]
    + b.p1
)
g.p2 = (
    b.p6 * IAS15_D[17]
    + b.p5 * IAS15_D[12]
    + b.p4 * IAS15_D[8]
    + b.p3 * IAS15_D[5]
    + b.p2
)
g.p3 = b.p6 * IAS15_D[18] + b.p5 * IAS15_D[13] + b.p4 * IAS15_D[9] + b.p3
g.p4 = b.p6 * IAS15_D[19] + b.p5 * IAS15_D[14] + b.p4
g.p5 = b.p6 * IAS15_D[20] + b.p5
g.p6 = b.p6

x0 = jnp.ones((4, 3), dtype=jnp.float64) * 0.123
v0 = jnp.ones((4, 3), dtype=jnp.float64) * 0.2
a0 = jnp.ones((4, 3), dtype=jnp.float64) * 0.1
dt = 0.01

# fmt: off
n=-1
x = x0 + ((((((((b.p6*7.*IAS15_H[n]/9. + b.p5)*3.*IAS15_H[n]/4. + b.p4)*5.*IAS15_H[n]/7. + b.p3)*2.*IAS15_H[n]/3. + b.p2)*3.*IAS15_H[n]/5. + b.p1)*IAS15_H[n]/2. + b.p0)*IAS15_H[n]/3. + a0)*dt*IAS15_H[n]/2. + v0)*dt*IAS15_H[n]
v = v0 + (((((((b.p6*7.*IAS15_H[n]/8. + b.p5)*6.*IAS15_H[n]/7. + b.p4)*5.*IAS15_H[n]/6. + b.p3)*4.*IAS15_H[n]/5. + b.p2)*3.*IAS15_H[n]/4. + b.p1)*2.*IAS15_H[n]/3. + b.p0)*IAS15_H[n]/2. + a0)*dt*IAS15_H[n]
# fmt: on

In [4]:
g_dd[0].hi - g.p0, g_dd[1].hi - g.p1, g_dd[6].hi - g.p6

(Array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float64),
 Array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float64),
 Array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float64))

In [5]:
x - x_est.hi, v - v_est.hi

(Array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float64),
 Array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float64))