In [2]:
import jax.random

import original.slimplectic_GGL

In [132]:
import jax
import jax.numpy as jnp
from ggl import ggl, dereduce

In [9]:
from sympy import Symbol, Float

t_sym = Symbol('t')
ddt = Float(1)
q1 = Symbol('q')
qdot_1 = Symbol('qdot_1')

In [210]:
def example_sympy_computation():
    expr, table = original.slimplectic_GGL.GGL_Gen_Ld(
        tsymbol=t_sym,
        q_list=[q1],
        qprime_list=[qdot_1],
        ddt=ddt,
        r=2,
        L=qdot_1
    )
    
    print("Expression\n\t", expr)
    print("Variables\n\t", *table[0])

In [211]:
example_sympy_computation()

Expression
	 8.4703294725430033907e-22*{q^{(1)}} - 8.4703294725430033907e-22*{q^{(2)}} + 1.0*{q^{[n+1]}} - 1.0*{q^{[n]}}
Variables
	 {q^{[n]}} {q^{(1)}} {q^{(2)}} {q^{[n+1]}}


In [163]:
def compute_qdot_from_q(qi_vec, r, dt):
    dij = dereduce(ggl(r), dt)[2]
    return jax.numpy.matmul(dij, qi_vec)

In [217]:
import itertools
list(itertools.product([1, 2, 3], [1, 2, 3]))

[(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)]

In [218]:
def test_co():
    for (i, r) in itertools.product(range(20), range(10)):
        random_vec = jax.random.normal(
            key=jax.random.PRNGKey(i),
            shape=(r + 2,)
        )

        total_value = jnp.dot(
            dereduce(ggl(r), dt=1)[1],
            compute_qdot_from_q(random_vec, r=r, dt=1)
        )

        expr, table = original.slimplectic_GGL.GGL_Gen_Ld(
            tsymbol=t_sym,
            q_list=[q1],
            qprime_list=[qdot_1],
            ddt=ddt,
            r=2,
            L=qdot_1
        )

        sympy_value = float(expr.subs([
            (table[0][i], random_vec[i])
            for i in range(len(table[0]))
        ]))

        print(r, total_value, sympy_value, total_value - sympy_value)

        # assert jnp.allclose(
        #     total_value,
        #     sympy_value
        # )

In [219]:
test_co()

0 -2.5709715171383483 -2.5709715171383483 0.0
1 0.46157418893564867 0.4615741889356487 -5.551115123125783e-17
2 1.5881063522255654 1.588106352225556 9.325873406851315e-15
3 -0.06819276935801118 -0.3647089804222698 0.29651621106425863
4 -0.49754381136327147 -0.5916938645500794 0.09415005318680797
5 -0.2895608955670682 -0.05926896957880379 -0.23029192598826442
6 0.9586618369435016 1.0361223303401836 -0.07746049339668204
7 -0.9654472676740801 1.0973920109493778 -2.062839278623458
8 0.7131875811421223 -1.0864018842672198 1.7995894654093423
9 0.9064512674142023 2.3361572939633675 -1.429706026549165
0 -0.6335320747954701 -0.6335320747954701 0.0
1 0.04975925237359332 0.04975925237359313 1.8735013540549517e-16
2 2.5542496518352107 2.554249651835224 -1.3322676295501878e-14
3 -0.8867894110539497 -0.5697204975878822 -0.31706891346606747
4 -1.5138672609633106 -0.21902099882049034 -1.2948462621428203
5 -0.9789654621050793 -0.027622217791519743 -0.9513432443135595
6 -0.9892310137718912 0.60497409935

In [189]:
compute_qdot_from_q(jax.numpy.array([1, 2, 3, 1, 1]), r=3, dt=1)

Array([ 2.84633831,  6.98297249, -2.67316916, -5.45544726,  7.84633831],      dtype=float64)

In [194]:
_, ws, dij = dereduce(ggl(r=0), dt=1)

jax.numpy.matmul(
    dij,
    ws
)

Array([0., 0.], dtype=float64)

In [195]:
original.slimplectic_GGL.GGL_Gen_Ld(
    tsymbol=t_sym,
    q_list=[q1],
    qprime_list=[qdot_1],
    ddt=Float(1.0),
    r=0,
    L=qdot_1
)

(1.0*{q^{[n+1]}} - 1.0*{q^{[n]}}, [[{q^{[n]}}, {q^{[n+1]}}]])

In [122]:
compute_qdot_from_q(1, jax.numpy.array([1, 2, 3]))

Array([1., 1., 1.], dtype=float64)