# scratch work

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# from decimal import Decimal, getcontext
# import mpmath

# getcontext().prec = 50

from jorbit.utils.doubledouble import DoubleDouble
from jorbit.integrators.ias15_dd import step, initialize_ias15_helper, acceleration_func

In [2]:
from jorbit.utils.generate_coefficients import create_ias15_constants

h, r, c, d = create_ias15_constants(7)
h_arr = jnp.array([float(x) for x in h])
r_arr = jnp.array([float(x) for x in r])
c_arr = jnp.array([float(x) for x in c])
d_arr = jnp.array([float(x) for x in d])

100%|██████████| 998/998 [00:00<00:00, 1399.22it/s]


In [3]:
from jorbit.data.constants import IAS15_H, IAS15_RR, IAS15_C, IAS15_D

jnp.max(jnp.abs(h_arr - IAS15_H)), jnp.max(jnp.abs(r_arr - IAS15_RR)), jnp.max(
    jnp.abs(c_arr - IAS15_C)
), jnp.max(jnp.abs(d_arr - IAS15_D))

(Array(0., dtype=float64),
 Array(0., dtype=float64),
 Array(0., dtype=float64),
 Array(0., dtype=float64))

In [5]:
from decimal import Decimal, getcontext

getcontext().prec = 50
# https://archive.org/details/gaussianquadratu00stro/page/340/mode/2up
a = Decimal("-0.88747487892615570706")

# that's in the range of -1 to 1: rescale a to 0, 1:
a = (a + 1) / 2
a  # first h spacing in ias15

Decimal('0.05626256053692214647')

In [45]:
import mpmath

mpmath.mp.dps = 50

n = 7  # number of internal evaluations, match # https://archive.org/details/gaussianquadratu00stro/page/340/mode/2up

n += 1  # include the endpoint
f = lambda x: (mpmath.legendre(n - 1, x) + mpmath.legendre(n, x)) / (x + 1)

f(mpmath.mpf("-0.88747487892615570706"))

mpf('2.6129705657498202708866397685208784576275168384388967e-19')

In [71]:
n = 8
f = lambda x: (mpmath.legendre(n - 1, x) + mpmath.legendre(n, x)) / (x + 1)

q = mpmath.findroot(
    f,
    # (mpmath.mpf("-0.91"), mpmath.mpf("-0.9")),
    mpmath.mpf("-0.9"),
    solver="secant",
    verbose=True,
    tol=mpmath.mpf("1e-35"),
)
(q + 1) / 2, q, f(q)

x:     -0.69193902674528838467959497127447044414231087727768994856
error: 0.25
x:     -0.6403159002821685726486482447939496367171874398699295483
error: 0.041939026745288384679594971274470444142310877277689681282
x:     -0.63958535256197696366509592815812349517352632160701004949
error: 0.051623126463119812030946726480520807425123437407760400263
x:     -0.63951874711267790029297906718333714397659589903188541886
error: 0.00073054772019160898355231663582614154366111826291949880676
x:     -0.63951861654772304091575879249617545882602800579883895461
error: 0.000066605449299063372116860974786351196930422575124630633427
x:     -0.63951861652621527695960887312873264849592089718170265484
error: 0.00000013056495485937722027468716168515056789323304646424575843
x:     -0.63951861652621527002484011474990631465657036616309033583
error: 0.000000000021507763956149919367442810330107108617136299771636264405
x:     -0.63951861652621527002484011438163643091587876469481345522
error: 6.93476875837882633383935

(mpf('0.18024069173689236498757994280918178454206062080547029'),
 mpf('-0.63951861652621527002484011438163643091587875838905969'),
 mpf('0.0'))

In [91]:
import mpmath
from tqdm import tqdm


def generate_gauss_radau_spacings(internal_points: int):
    # for ias15 H spacings, internal_points = 7
    # matches https://archive.org/details/gaussianquadratu00stro/page/340/mode/2up
    mpmath.mp.dps = 60

    n = internal_points + 1  # include the endpoint
    f = lambda x: (mpmath.legendre(n - 1, x) + mpmath.legendre(n, x)) / (x + 1)

    slices = mpmath.linspace(-1, 1, 1000)

    sols = []
    for i in tqdm(range(1, len(slices) - 1)):
        try:
            s = mpmath.findroot(
                f,
                (slices[i], slices[i + 1]),
                solver="secant",
                tol=mpmath.mpf("1e-50"),
            )
            possibly_new = s not in sols
            if possibly_new:
                for sol in sols:
                    if mpmath.fabs(sol - s) < mpmath.mpf("1e-50"):
                        possibly_new = False
                        break
            if possibly_new:
                assert mpmath.fabs(f(s)) < mpmath.mpf("1e-50")
                sols.append(s)

            del s
        except:
            pass

    assert len(sols) == n - 1

    sols.sort()
    sols = [((s + 1) / 2) for s in sols]  # rescale to 0, 1 from the range -1, 1
    return sols


generate_gauss_radau_spacings(7)

100%|██████████| 998/998 [00:00<00:00, 1543.05it/s]


[mpf('0.0562625605369221464656521910323111757797655147446230835724427929'),
 mpf('0.18024069173689236498757994280918178454206062080547015532262555'),
 mpf('0.35262471711316963737390777017124120280802188305727478239920932'),
 mpf('0.547153626330555383001448557652348854640385927899147541308009332'),
 mpf('0.734210177215410531523210608306610002563003118594390105979320279'),
 mpf('0.885320946839095768090359762932485372922270175468028054611010279'),
 mpf('0.977520613561287501891174500429154940077826092764399610140715742')]

In [3]:
b = initialize_ias15_helper(1)
# print(b.p0.hi)

x0 = DoubleDouble(jnp.array([1.0, 0.0, 0.0]))
v0 = DoubleDouble(jnp.array([0.0, 1.0, 0.0]))

a0 = acceleration_func(x0)

step(x0, v0, a0, b)

# print(b.p0.hi)

5.649233359601495e-08, -5.388037824036623e-25
5.649320835508429e-08, -1.3876614877682261e-24
8.747592999253804e-13, -2.238183808514221e-30
2.213840311442581e-18, 1.8152355289655439e-34
3.2412436268176264e-24, 1.3856550845946262e-40
3.8274007215499996e-30, 8.068751029066524e-47
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0


(<jorbit.integrators.ias15_dd.IAS15Helper at 0x10d9510f0>,
 <jorbit.integrators.ias15_dd.IAS15Helper at 0x129285e10>,
 DoubleDouble(0.0, 0.0),
 DoubleDouble(0.0, 0.0))

In [8]:
a = DoubleDouble(3.0)
b = DoubleDouble(7.0)

c = b / a
c

DoubleDouble(2.3333333333333335, -1.4802973661668753e-16)

In [9]:
getcontext().prec = 100

a = Decimal(3.0)
b = Decimal(7.0)

q = b / a
q

Decimal('2.333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333')

In [10]:
print(q - Decimal(float(c.hi)))
print(q - (Decimal(float(c.hi)) + Decimal(float(c.lo))))

-1.48029736616687538723150889078776041666666666666666666666666666666666666666666666667E-16
-8.217301096052206306372172555029023225762567032385656299690405527751E-33


In [None]:
z = jnp.array([[1.0, 2.0, 1.0, 1.2], [1.0, 3.0, 1.0, 1.2]])
jnp.max(z, axis=0)

Array([1. , 3. , 1. , 1.2], dtype=float64)

In [None]:
a = DoubleDouble(
    hi=jnp.array([[1.0, 2.0, 1.0, 1.2], [1.0, 3.0, 1.0, 1.2]]),
    lo=jnp.array([[1.0, 2.0, 1.0, 1.2], [1.0, 3.0, 1.0, 1.2]]) * 1e-14,
)

a.dd_max(axis=0)

DoubleDouble([1.  3.  1.  1.2], [1.0e-14 3.0e-14 1.0e-14 1.2e-14])