In [85]:
import jax
import jax.numpy as jnp
from jax.experimental import checkify

from ggl import ggl, dereduce


def compute_qdot_from_q(qi_vec, r, dt):
    dij = dereduce(ggl(r), dt)[2]
    return jax.numpy.matmul(dij, qi_vec)


def discretise(
        r: int,
        dt: float,
        fn: callable,
) -> callable:
    """
    :param r: The order of the method.
    :param fn: The function to discretise.
    :return: A callable which takes a vector of q values and returns a vector of q dot values.
    """

    xs, ws, dij = dereduce(ggl(r), dt)

    def discretised_fn(qi_vec):
        # print(qi_vec.shape)
        qidot_vec = jax.numpy.matmul(dij, qi_vec)
        t0 = 0
        t_values = t0 + (1 + xs) * dt / 2

        return jnp.dot(ws, jax.vmap(
            fn
        )(
            qi_vec,
            qidot_vec,
            t_values,
        ))

    return discretised_fn


In [49]:
df = discretise(1, 1, lambda x, xdot, t: x ** 2 + xdot ** 2)

In [50]:
df(
    jnp.array([1, 2, 3])
)

(3,)


Array(8.33333333, dtype=float64)

In [51]:
discretise(
    r=3,
    dt=0.1,
    fn=lambda q_vec, q_dot_vec, t: jnp.dot(q_vec, q_vec) + jnp.dot(q_dot_vec, q_dot_vec)
)(
    jnp.array([
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
    ]).transpose()
)

(5, 3)


Array(548.61366091, dtype=float64)

In [65]:
data = jnp.array([
    [1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5],
])

In [56]:
from sympy import Symbol, Float
import original.slimplectic_GGL

expr, table = original.slimplectic_GGL.GGL_Gen_Ld(
    tsymbol=Symbol('t'),
    q_list=[Symbol('q1'), Symbol('q2'), Symbol('q3')],
    qprime_list=[Symbol('qdot1'), Symbol('qdot2'), Symbol('qdot3')],
    r=3,
    ddt=Float(0.1),
    L=(
            Symbol('q1') ** 2 + Symbol('q2') ** 2 + Symbol('q3') ** 2 +
            Symbol('qdot1') ** 2 + Symbol('qdot2') ** 2 + Symbol('qdot3') ** 2
    )
)

In [57]:
expr

0.027222222222222223733*{q1^{(1)}}**2 + 0.035555555555555557529*{q1^{(2)}}**2 + 0.027222222222222223733*{q1^{(3)}}**2 + 0.0050000000000000002776*{q1^{[n+1]}}**2 + 0.0050000000000000002776*{q1^{[n]}}**2 + 0.027222222222222223733*{q2^{(1)}}**2 + 0.035555555555555557529*{q2^{(2)}}**2 + 0.027222222222222223733*{q2^{(3)}}**2 + 0.0050000000000000002776*{q2^{[n+1]}}**2 + 0.0050000000000000002776*{q2^{[n]}}**2 + 0.027222222222222223733*{q3^{(1)}}**2 + 0.035555555555555557529*{q3^{(2)}}**2 + 0.027222222222222223733*{q3^{(3)}}**2 + 0.0050000000000000002776*{q3^{[n+1]}}**2 + 0.0050000000000000002776*{q3^{[n]}}**2 + 25.407407407407408818*(-{q1^{(1)}} + {q1^{(3)}} - 0.28056585887484734734*{q1^{[n+1]}} + 0.28056585887484734734*{q1^{[n]}})**2 + 33.185185185185187027*(0.4375*{q1^{(1)}} - {q1^{(2)}} + 0.71086647140211000062*{q1^{[n+1]}} - 0.14836647140211000062*{q1^{[n]}})**2 + 33.185185185185187027*({q1^{(2)}} - 0.4375*{q1^{(3)}} + 0.14836647140211000062*{q1^{[n+1]}} - 0.71086647140211000062*{q1^{[n]}

In [76]:
subs_list = []
for i in range(len(table)):
    for j in range(len(table[0])):
        subs_list.append([table[i][j], data.tolist()[i][j]])
subs_list

[[{q1^{[n]}}, 1],
 [{q1^{(1)}}, 2],
 [{q1^{(2)}}, 3],
 [{q1^{(3)}}, 4],
 [{q1^{[n+1]}}, 5],
 [{q2^{[n]}}, 1],
 [{q2^{(1)}}, 2],
 [{q2^{(2)}}, 3],
 [{q2^{(3)}}, 4],
 [{q2^{[n+1]}}, 5],
 [{q3^{[n]}}, 1],
 [{q3^{(1)}}, 2],
 [{q3^{(2)}}, 3],
 [{q3^{(3)}}, 4],
 [{q3^{[n+1]}}, 5]]

In [77]:
expr.subs(subs_list)

548.61366091064408121

In [82]:
from jax import Array
from typing import Callable
from sympy import Expr


def perform_sympy_calc(
        r: int,
        dt: float,
        dof: int,
        expr_builder: Callable[[list[Symbol], list[Symbol], Symbol], Expr],
):
    q_list = [Symbol(f'q{i}') for i in range(dof)]
    qprime_list = [Symbol(f'qdot{i}') for i in range(dof)]
    t_symbol = Symbol('t')

    expr, table = original.slimplectic_GGL.GGL_Gen_Ld(
        tsymbol=t_symbol,
        q_list=q_list,
        qprime_list=qprime_list,
        r=r,
        ddt=Float(dt),
        L=expr_builder(q_list, qprime_list, t_symbol)
    )

    def fn(q_vec: Array, t: float = 0):
        assert q_vec.shape == (dof, r + 2)

        subs_list = []
        for i in range(len(table)):
            for j in range(len(table[0])):
                subs_list.append([table[i][j], q_vec.tolist()[i][j]])

        return float(expr.subs(subs_list))

    return fn

In [87]:
perform_sympy_calc(
    r=3,
    dt=0.1,
    dof=3,
    expr_builder=lambda q_list, qprime_list, t_symbol: (
            q_list[0] ** 2 + q_list[1] ** 2 + q_list[2] ** 2 +
            qprime_list[0] ** 2 + qprime_list[1] ** 2 + qprime_list[2] ** 2
    )
)(
    jnp.array([
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
    ])
)

548.6136609106441

In [86]:
discretise(
    r=3,
    dt=0.1,
    fn=lambda q_vec, q_dot_vec, t: jnp.dot(q_vec, q_vec) + jnp.dot(q_dot_vec, q_dot_vec)
)(
    jnp.array([
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
    ]).transpose()
)

Array(548.61366091, dtype=float64)

In [88]:
548.6136609106441 - 548.61366091

6.440359356929548e-10