In [1]:
import sympy
from jax import Array, jit, grad

import jax.lax
import jax.numpy as jnp
import scipy.special as sps

In [2]:
from ggl import ggl

In [40]:
ggl(1)

(Array([-1.,  0.,  1.], dtype=float64),
 Array([0.33333333, 1.33333333, 0.33333333], dtype=float64),
 Array([[-1.5,  2. , -0.5],
        [-0.5,  0. ,  0.5],
        [ 0.5, -2. ,  1.5]], dtype=float64))

In [97]:
from functools import partial


@partial(jit, static_argnums=(0,))
def discretise(
        f: callable,
        qs: Array,
        ws: Array,
):
    """
    Computes the approximated integral of f using the discretisation (xs, ws).

    :param f: The function to be integrated using the discretisation. 
    :param qs: The points at which we will evaluate f.
    :param ws: The weights for each point in xs.
    :return: The value of the discretised integral.
    """
    return jnp.sum(jax.vmap(f)(qs) * ws, axis=0)

In [23]:
l = lambda x: x ** 2
discretise(lambda x: x ** 2, jnp.array([1, 2, 3]), jnp.array([1, 1, 1]))

Array(14, dtype=int64)

In [31]:
discretise(l, jnp.array([1, 2, 3]), jnp.array([1, 1, 1]))

Array(14, dtype=int64)

In [101]:
r = 3
dt = 0.1
reduction_factor = 2 / dt

xs, ws_reduced, dij_reduced = ggl(r)

In [102]:
ws = (1 / reduction_factor) * ws_reduced
dij = reduction_factor * dij_reduced

In [80]:
discretise(l, xs * 3, ws)

Array(0.3, dtype=float64)

In [83]:
import original.slimplectic_GGL
from sympy import Symbol, lambdify

In [168]:
ddt = Symbol('ddt')
q1 = Symbol('q1')
q2 = Symbol('q2')

expr, table = original.slimplectic_GGL.GGL_Gen_Ld(
    tsymbol=Symbol('t'),
    q_list=[q1, q2],
    qprime_list=[Symbol('qdot1'), Symbol('qdot2')],
    L=Symbol('t') + q1 + q2,
    r=r,
    ddt=ddt,
)

In [106]:
float((q1 + q2).evalf(subs={q1: 1, q2: 2}))

3.0

In [169]:
expr

0.05*ddt*(t + {q1^{[n]}} + {q2^{[n]}}) + 0.27222222222222222222*ddt*(0.3453463292920228562*ddt + t + {q1^{(1)}} + {q2^{(1)}}) + 0.35555555555555555556*ddt*(1.0*ddt + t + {q1^{(2)}} + {q2^{(2)}}) + 0.27222222222222222222*ddt*(1.6546536707079771438*ddt + t + {q1^{(3)}} + {q2^{(3)}}) + 0.05*ddt*(2.0*ddt + t + {q1^{[n+1]}} + {q2^{[n+1]}})

In [108]:
table

[[{q1^{[n]}}, {q1^{(1)}}, {q1^{(2)}}, {q1^{(3)}}, {q1^{[n+1]}}],
 [{q2^{[n]}}, {q2^{(1)}}, {q2^{(2)}}, {q2^{(3)}}, {q2^{[n+1]}}]]

In [109]:
q1_values = [1, 2, 3, 4, 5]
q2_values = [2, 4, 2, 1, 3]

In [111]:
lambdify([ddt, *table], expr)(dt, q1_values, q2_values)

0.5322222222222223

In [113]:
discretise(
    f=lambda qvec: qvec[0] + qvec[1],
    qs=jnp.array([q1_values, q2_values]).transpose(),
    ws=ws,
)

Array(0.53222222, dtype=float64)

In [88]:
lambdify([Symbol('x')], Symbol('x') ** 4)(2)

16

In [162]:
qvv = jnp.array([xs, xs]).transpose()

In [163]:
qvv

Array([[-1.00000000e+00, -1.00000000e+00],
       [-6.54653671e-01, -6.54653671e-01],
       [ 3.23815049e-17,  3.23815049e-17],
       [ 6.54653671e-01,  6.54653671e-01],
       [ 1.00000000e+00,  1.00000000e+00]], dtype=float64)

In [165]:
[*qvv.transpose()]

[Array([-1.00000000e+00, -6.54653671e-01,  3.23815049e-17,  6.54653671e-01,
         1.00000000e+00], dtype=float64),
 Array([-1.00000000e+00, -6.54653671e-01,  3.23815049e-17,  6.54653671e-01,
         1.00000000e+00], dtype=float64)]

In [134]:
from dataclasses import dataclass, field
from sympy import Expr


@dataclass
class TestCase:
    dt: float
    r: int

    """
    The JAX function which we will be discretising. This should take a single argument representing where we are in phase space.
    
    Signature: (qvec: [float; dof], qdotvec: [float; dof]: , t: float) -> float
    """
    func: callable

    """
    The 
    """
    expr: Expr

    variables: list[str] = field(default_factory=lambda: ['q1', 'q2'])
    variable_derivatives: list[str] = field(default_factory=lambda: ['q1_dot', 'q2_dot'])

    def validate(self):
        """
        Validates the coherence of the test case.
        """
        assert len(self.variables) == len(self.variable_derivatives)

        lambdify_args = [Symbol('ddt'), *self.variables, *self.variable_derivatives]
        eval_expr = lambdify(lambdify_args, self.expr)

    def compute_new(self, q_values: Array):
        """
        Computes the value of F_d under the GGL discretisation using the new code.
        
        :param q_values: Currently we are only testing the form of the expression so we pass in the values of q at the
                         different points manually. This should be a JAX array of shape (r + 2, func_arg).
        :return: 
        """
        # TODO: Currently we are onl
        _, ws_reduced, dij_reduced = ggl(self.r)
        ws = (1 / reduction_factor) * ws_reduced
        dij = reduction_factor * dij_reduced

        discretise(
            f=self.func,
            qs=jnp.array([q1_values, q2_values]).transpose(),
            ws=ws,
        )

In [136]:
TestCase(
    r=1,
    dt=2,
    expr=Symbol('q1') + Symbol('q2')
)

TestCase(dt=2, r=1, expr=q1 + q2, variables=['q1', 'q2'], variable_derivatives=['q1_dot', 'q2_dot'])

In [355]:
from typing import Callable


class System:
    def eval_at(self, qvec: Array, qdotvec: Array, t: float):
        raise NotImplementedError()

    def discretise(self, r: int, dt: float, qvec_values: Array, t0: float):
        raise NotImplementedError()


class SympySystem(System):
    def __init__(self, dof: int, builder: Callable[[list[Symbol], list[Symbol], Symbol], Expr]):
        self.dof = dof
        self.r = r
        self.dt = dt
        self.q_list = [Symbol(f'q{i}') for i in range(dof)]
        self.qdot_list = [Symbol(f'qdot{i}') for i in range(dof)]
        self.t_symbol = Symbol('t')
        self.dt_symbol = Symbol('ddt')
        self.expr = builder(self.q_list, self.qdot_list, self.t_symbol)
        self._eval = lambdify([*self.q_list, *self.qdot_list, self.t_symbol], self.expr)

    def eval_at(self, qvec: Array, qdotvec: Array, t: float):
        return self._eval(*qvec, *qdotvec, t)

    def discretise(self, r: int, dt: float, qvec_values: Array, t0: float):
        """
        Computes the values of F_d under the GGL discretisation with parameters r and dt.
        
        - qvec_values should be a JAX array of shape (r + 2, dof) representing the values of q at the different points.
        - t0 is the initial time at the start of the integration.
        """
        expr_d, table = original.slimplectic_GGL.GGL_Gen_Ld(
            tsymbol=self.t_symbol,
            q_list=self.q_list,
            qprime_list=self.qdot_list,
            L=self.expr,
            r=r,
            ddt=self.dt_symbol,
        )
        
        self.expr_d = expr_d
        self.d_table = table

        return lambdify([self.t_symbol, self.dt_symbol, *table], expr_d)(t0, self.dt, *qvec_values.transpose())


class JAXSystem(System):
    def __init__(self, function):
        self.function = function

    def eval_at(self, qvec: Array, qdotvec: Array, t: float):
        return self.function(qvec, qdotvec, t)

In [356]:
ss = (SympySystem(
    dof=2,
    builder=lambda qs, qdots, t: qs[0] + qs[1]
))

ss.discretise(
    r=3,
    dt=0.1,
    qvec_values=qvv,
    t0=0,
)

Array(-3.46944695e-17, dtype=float64)

In [357]:
ss.expr_d

0.27222222222222222222*ddt*({q0^{(1)}} + {q1^{(1)}}) + 0.35555555555555555556*ddt*({q0^{(2)}} + {q1^{(2)}}) + 0.27222222222222222222*ddt*({q0^{(3)}} + {q1^{(3)}}) + 0.05*ddt*({q0^{[n+1]}} + {q1^{[n+1]}}) + 0.05*ddt*({q0^{[n]}} + {q1^{[n]}})

In [361]:
ss.d_table[0][1]

{q0^{(1)}}

In [365]:
import sympy

In [366]:
ss.expr_d.subs(ss.d_table[0][1], sympy.Float(1))

0.35555555555555555556*ddt*({q0^{(2)}} + {q1^{(2)}}) + 0.27222222222222222222*ddt*({q0^{(3)}} + {q1^{(3)}}) + 0.05*ddt*({q0^{[n+1]}} + {q1^{[n+1]}}) + 0.05*ddt*({q0^{[n]}} + {q1^{[n]}}) + 0.27222222222222222222*ddt*({q1^{(1)}} + 1.0)

In [170]:
JAXSystem(
    function=lambda qs, qdots, t: qs[0] + qs[1]
).eval(
    q_values=jnp.array([1, 2]),
    qdot_values=jnp.array([3, 4]),
    t=5,
)

AttributeError: 'JAXSystem' object has no attribute 'eval'

In [183]:
dof = 3
print(f"dof = {dof}\nr = {r}")

dof = 3
r = 3


In [205]:
q_vec_values = jnp.array(
    [
        [3, 2, 3],  # q0
        [2, 3, 6],  # q_1
        [4, 5, 5],  # q_2
        [3, 3, 4],  # q_3
        [4, 1, 6],  # q_{r + 1}
    ]
)

In [191]:
assert q_vec_values.shape == (r + 2, dof)

In [207]:
jax.debug.print("{}", jnp.matmul(dij, q_vec_values.transpose()[1]))

[  13.33333333   84.28914006    7.5        -114.28914006 -103.33333333]


In [197]:
# Make the debug line wrap limit larger
jax.config.update("jax_log_compiles", True)

<jax._src.config.Config at 0x11012c190>

In [206]:
jnp.matmul(
    dij,
    q_vec_values
)

Array([[-198.46338311,   13.33333333,  296.92676622],
       [  40.09505738,   84.28914006,   70.09505738],
       [  19.23169155,    7.5       ,  -75.96338311],
       [ -25.37030969, -114.28914006,   50.45544726],
       [ 181.53661689, -103.33333333,  186.92676622]], dtype=float64)

In [209]:
lambdify([Symbol('x')], [Symbol('x') ** 4, Symbol('x') ** 4])(2)

[16, 16]

In [238]:
r = 5
a_qvec = jnp.arange((r + 2) * dof).reshape((r + 2, dof))
a_derivative_matrix = ggl(r)[2] * (2 / dt)

In [371]:
jnp.matmul(a_derivative_matrix, a_qvec)

Array([[471.03684468, 471.03684468, 471.03684468],
       [258.57785266, 258.57785266, 258.57785266],
       [122.63280804, 122.63280804, 122.63280804],
       [135.52563577, 135.52563577, 135.52563577],
       [122.63280804, 122.63280804, 122.63280804],
       [258.57785266, 258.57785266, 258.57785266],
       [471.03684468, 471.03684468, 471.03684468]], dtype=float64)

In [372]:
with jnp.printoptions(precision=2):
    print(a_derivative_matrix)

[[-210.    284.03 -113.38   64.    -41.     26.35  -10.  ]
 [ -48.86    0.     69.12  -31.97   19.23  -12.04    4.53]
 [  12.51  -44.32    0.     45.33  -21.33   12.33   -4.52]
 [  -6.25   18.15  -40.14    0.     40.14  -18.15    6.25]
 [   4.52  -12.33   21.33  -45.33    0.     44.32  -12.51]
 [  -4.53   12.04  -19.23   31.97  -69.12    0.     48.86]
 [  10.    -26.35   41.    -64.    113.38 -284.03  210.  ]]


In [383]:
from original.slimplectic_GGL import GGLdefs, GGL_q_Collocation_Table, DM_Sum

DM = GGLdefs(r)[2]
q_list = [Symbol(f'q_{i}') for i in range(r + 2)]
q_table = GGL_q_Collocation_Table(q_list, r + 2)
ddt_symbol = Symbol('ddt')

dphidt_Table = []
for qs in a_qvec:
    dphidt_Table.append([DM_Sum(DMvec, qs) * 2 / dt for DMvec in DM])

In [384]:
dphidt_Table

[[135.96846794160366911,
  97.717040569771579083,
  19.305819042785740631,
  -5.6508894253764180624,
  3.2838406726486043200,
  -2.9804687768979001214,
  6.3474687140486894262],
 [135.96846794160366909,
  97.717040569771579084,
  19.305819042785740633,
  -5.6508894253764180624,
  3.2838406726486043182,
  -2.9804687768979001222,
  6.3474687140486894430],
 [135.96846794160366908,
  97.717040569771579084,
  19.305819042785740634,
  -5.6508894253764180623,
  3.2838406726486043167,
  -2.9804687768979001225,
  6.3474687140486894560],
 [135.96846794160366906,
  97.717040569771579085,
  19.305819042785740636,
  -5.6508894253764180623,
  3.2838406726486043152,
  -2.9804687768979001236,
  6.3474687140486894712],
 [135.96846794160366905,
  97.717040569771579086,
  19.305819042785740637,
  -5.6508894253764180623,
  3.2838406726486043143,
  -2.9804687768979001236,
  6.3474687140486894842],
 [135.96846794160366903,
  97.717040569771579086,
  19.305819042785740639,
  -5.6508894253764180623,
  3.28384

In [373]:
from ggl.test import floatify2

with jnp.printoptions(precision=2):
    print(jnp.array(floatify2(dphidt_Table)))

[[135.97  97.72  19.31  -5.65   3.28  -2.98   6.35]
 [135.97  97.72  19.31  -5.65   3.28  -2.98   6.35]
 [135.97  97.72  19.31  -5.65   3.28  -2.98   6.35]
 [135.97  97.72  19.31  -5.65   3.28  -2.98   6.35]
 [135.97  97.72  19.31  -5.65   3.28  -2.98   6.35]
 [135.97  97.72  19.31  -5.65   3.28  -2.98   6.35]
 [135.97  97.72  19.31  -5.65   3.28  -2.98   6.35]]


In [367]:
subs = []
for d in range(dof):
    for i in range(r + 2):
        subs.append((q_table[i][d], sympy.Float(float(a_qvec[i, d]))))

In [368]:
subs

[({q_0^{[n]}}, 0.0),
 ({q_1^{[n]}}, 3.00000000000000),
 ({q_2^{[n]}}, 6.00000000000000),
 ({q_3^{[n]}}, 9.00000000000000),
 ({q_4^{[n]}}, 12.0000000000000),
 ({q_5^{[n]}}, 15.0000000000000),
 ({q_6^{[n]}}, 18.0000000000000),
 ({q_0^{(1)}}, 1.00000000000000),
 ({q_1^{(1)}}, 4.00000000000000),
 ({q_2^{(1)}}, 7.00000000000000),
 ({q_3^{(1)}}, 10.0000000000000),
 ({q_4^{(1)}}, 13.0000000000000),
 ({q_5^{(1)}}, 16.0000000000000),
 ({q_6^{(1)}}, 19.0000000000000),
 ({q_0^{(2)}}, 2.00000000000000),
 ({q_1^{(2)}}, 5.00000000000000),
 ({q_2^{(2)}}, 8.00000000000000),
 ({q_3^{(2)}}, 11.0000000000000),
 ({q_4^{(2)}}, 14.0000000000000),
 ({q_5^{(2)}}, 17.0000000000000),
 ({q_6^{(2)}}, 20.0000000000000)]

In [381]:
with jnp.printoptions(precision=2):
    print(jnp.array(floatify2([[dphidt_Table[i][d].subs(subs) for i in range(r + 2)] for d in range(dof)])))

[[135.97 135.97 135.97 135.97 135.97 135.97 135.97]
 [ 97.72  97.72  97.72  97.72  97.72  97.72  97.72]
 [ 19.31  19.31  19.31  19.31  19.31  19.31  19.31]]


In [382]:
dphidt_Table

[[135.96846794160366911,
  97.717040569771579083,
  19.305819042785740631,
  -5.6508894253764180624,
  3.2838406726486043200,
  -2.9804687768979001214,
  6.3474687140486894262],
 [135.96846794160366909,
  97.717040569771579084,
  19.305819042785740633,
  -5.6508894253764180624,
  3.2838406726486043182,
  -2.9804687768979001222,
  6.3474687140486894430],
 [135.96846794160366908,
  97.717040569771579084,
  19.305819042785740634,
  -5.6508894253764180623,
  3.2838406726486043167,
  -2.9804687768979001225,
  6.3474687140486894560],
 [135.96846794160366906,
  97.717040569771579085,
  19.305819042785740636,
  -5.6508894253764180623,
  3.2838406726486043152,
  -2.9804687768979001236,
  6.3474687140486894712],
 [135.96846794160366905,
  97.717040569771579086,
  19.305819042785740637,
  -5.6508894253764180623,
  3.2838406726486043143,
  -2.9804687768979001236,
  6.3474687140486894842],
 [135.96846794160366903,
  97.717040569771579086,
  19.305819042785740639,
  -5.6508894253764180623,
  3.28384

In [311]:
a = dphidt_Table[1][1].subs([(Symbol('q_1^{(3)}'), 1), (ddt_symbol, dt)])

In [370]:
dphidt_Table[1][1].subs([(list(a.atoms())[5], 1), (ddt_symbol, dt)])

97.717040569771579084

In [325]:
list(a.atoms())[6]

{q_1^{(5)}}

In [340]:
original.slimplectic_GGL.GGL_Gen_Ld(
    tsymbol=Symbol('t'),
    q_list=[q1],
    qprime_list=[Symbol('qdot1')],
    L=Symbol('qdot1'),
    r=r,
    ddt=ddt,
)

(-2.0117032497289633053e-21*{q1^{(1)}} - 1.1011428314305904408e-20*{q1^{(2)}} + 4.2351647362715016953e-22*{q1^{(3)}} + 1.1011428314305904408e-20*{q1^{(4)}} + 2.5410988417629010172e-21*{q1^{(5)}} + 0.99999999999999999999*{q1^{[n+1]}} - 0.99999999999999999999*{q1^{[n]}},
 [[{q1^{[n]}},
   {q1^{(1)}},
   {q1^{(2)}},
   {q1^{(3)}},
   {q1^{(4)}},
   {q1^{(5)}},
   {q1^{[n+1]}}]])

In [341]:
dphidt_Table

[[135.96846794160366911,
  97.717040569771579083,
  19.305819042785740631,
  -5.6508894253764180624,
  3.2838406726486043200,
  -2.9804687768979001214,
  6.3474687140486894262],
 [135.96846794160366909,
  97.717040569771579084,
  19.305819042785740633,
  -5.6508894253764180624,
  3.2838406726486043182,
  -2.9804687768979001222,
  6.3474687140486894430],
 [135.96846794160366908,
  97.717040569771579084,
  19.305819042785740634,
  -5.6508894253764180623,
  3.2838406726486043167,
  -2.9804687768979001225,
  6.3474687140486894560],
 [135.96846794160366906,
  97.717040569771579085,
  19.305819042785740636,
  -5.6508894253764180623,
  3.2838406726486043152,
  -2.9804687768979001236,
  6.3474687140486894712],
 [135.96846794160366905,
  97.717040569771579086,
  19.305819042785740637,
  -5.6508894253764180623,
  3.2838406726486043143,
  -2.9804687768979001236,
  6.3474687140486894842],
 [135.96846794160366903,
  97.717040569771579086,
  19.305819042785740639,
  -5.6508894253764180623,
  3.28384