In [1]:
import mpmath as mpm
from mpmath import matrix, mp, mpf
import numpy as np
from tqdm import tqdm

from decimal import Decimal, getcontext

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit.utils.doubledouble import DoubleDouble, dd_sum
from jorbit.integrators.mpmath_testing_only import *

mp.dps = 50
getcontext().prec = 50

In [3]:
steps_per_orbit = 100
n_internal_points = 10

pre = precompute(n_internal_points)
b_x_denoms, b_v_denoms, h, r, c, d = pre
x0 = matrix([["1.0", "0", "0"]])
v0 = matrix([["0", "1.0", "0"]])
a0 = matrix([["-1.0", "0", "0"]])
b = mpm.ones(n_internal_points, 3)

dt = mpf("2") * mp.pi / mpf(str(steps_per_orbit))

for i in tqdm(range(steps_per_orbit)):
    x0, v0, b = step(x0, v0, b, dt, pre)
x0 - matrix([["1.0", "0", "0"]])

100%|██████████| 998/998 [00:00<00:00, 1296.93it/s]
100%|██████████| 100/100 [00:15<00:00,  6.50it/s]


matrix(
[['3.05228087094950562875866512798342466351328683336766371909410451852848762318e-52', '-1.43161343005138817052923417059569757654205758111772883166232088315346684839e-51', '8.77084724060278674654809331496683344853601703968548205691223663551652587233e-79']])

In [6]:
from jorbit.integrators.ias15_dd import setup_iasnn_integrator
from jorbit.integrators.ias15_dd import _estimate_x_v_from_b as exvfd
from jorbit.integrators.ias15_dd import acceleration_func as af
from jorbit.integrators.ias15_dd import _refine_b_and_g as rbg

n_internal_points = 7


pre = precompute(7)
b_x_denoms, b_v_denoms, h, r, c, d = pre
dt = mpf("0.01")
x0 = matrix([["1.0", "0", "0"]])
v0 = matrix([["0", "1.0", "0"]])
a0 = matrix([["-1.0", "0", "0"]])
b = mpm.ones(7, 3)  # missing the dimension for nparticles to work w/ mpmath
for i in range(7):
    b[i, :] = b[i, :] + mpf(str(i)) * mpf("1.23")
g = d * b


pre_dd = setup_iasnn_integrator(n_internal_points=7)
b_x_denoms_dd, b_v_denoms_dd, h_dd, r_dd, c_dd, d_matrix_dd = pre_dd
x0_dd = DoubleDouble(jnp.array([[1.0, 0, 0]]))
v0_dd = DoubleDouble(jnp.array([[0, 1.0, 0]]))
a0_dd = -x0
b_dd = DoubleDouble(jnp.ones((7, 1, 3), dtype=jnp.float64))
for i in range(7):
    b_dd[i] += DoubleDouble.from_string(str(i)) * DoubleDouble.from_string("1.23")
dt_dd = DoubleDouble.from_string("0.01")
a0_dd = -x0_dd
g_dd = dd_sum((b_dd[None, :, :, :] * d_matrix_dd[:, :, None, None]), axis=1)


for n in range(1, n_internal_points):
    print(f"n = {n}")
    x, v = estimate_x_v_from_b(
        a0=a0,
        v0=v0,
        x0=x0,
        dt=dt,
        b_x_denoms=b_x_denoms,
        b_v_denoms=b_v_denoms,
        h=h[n],
        bp=b,
    )
    at = acceleration_func(x)
    b, g = refine_b_and_g(
        r=r, c=c, b=b, g=g, at=at, a0=a0, substep_num=n, return_g_diff=False
    )

    x_dd, v_dd = exvfd(
        a0=a0_dd,
        v0=v0_dd,
        x0=x0_dd,
        dt=dt_dd,
        b_x_denoms=b_x_denoms_dd,
        b_v_denoms=b_v_denoms_dd,
        h=h_dd[n],
        bp=b_dd,
    )
    at_dd = af(x_dd)
    b_dd, g_dd = rbg(
        r=r_dd,
        c=c_dd,
        b=b_dd,
        g=g_dd,
        at=at_dd,
        a0=a0_dd,
        substep_num=n,
        return_g_diff=False,
    )

    print("x")
    print(
        max(matrix(np.array(x_dd[0, :].hi)).T + matrix(np.array(x_dd[0, :].lo)).T - x)
    )

    print("v")
    print(
        max(matrix(np.array(v_dd[0, :].hi)).T + matrix(np.array(v_dd[0, :].lo)).T - v)
    )

    print("at")
    print(
        max(
            matrix(np.array(at_dd[0, :].hi)).T + matrix(np.array(at_dd[0, :].lo)).T - at
        )
    )

    print("b")
    print(
        max(matrix(np.array(b_dd[:, 0, :].hi)) + matrix(np.array(b_dd[:, 0, :].lo)) - b)
    )

    print("g")
    print(
        max(matrix(np.array(g_dd[:, 0, :].hi)) + matrix(np.array(g_dd[:, 0, :].lo)) - g)
    )
    print()

100%|██████████| 998/998 [00:00<00:00, 1389.99it/s]
100%|██████████| 998/998 [00:00<00:00, 1390.74it/s]


n = 1
x
2.34834079054852962992229412549877108400203436582642810933900799181089159574e-37
v
6.43946449320378036351349483547747779559555910489798612989405215550563655163e-34
at
4.22464011940065259829994259097632697732071042047544159456220274629025175774e-33
b
9.46633086265214166494074278339343475607847728321233226016032624273812015255e-32
g
1.002371623724490829224019161589800837667660992212613316943517311616157087e-31

n = 2
x
1.36343857246866705512245958550755103580006617332678224934054424232341879422e-33
v
4.31395374854456632074690910673281384845611208705211899516163809745290310642e-34
at
-3.98267429555774104781078277820365140664846681658093939456764578473117734983e-41
b
1.48770685827927497170115121926392049619988318640032044616275679596905930887e-31
g
1.002371623724490829224019161589800837667660992212613316943517311616157087e-31

n = 3
x
2.97523050199598992300976908010781886125126840281061488278409969697321640697e-39
v
5.3312679999572943525979819422994637608829609036234900372622135755