A comparison of the initial behaviour of the original (SymPy-based) implementation and the new (JAX-based) implementation.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import jax

import slimpletic as st

# Fail fast unless JAX 64-bit mode is enabled, so that the JAX results are
# computed in double precision for the comparison below.
# NOTE(review): nothing in this cell enables it, so presumably importing
# slimpletic turns it on — confirm.
assert jax.config.read('jax_enable_x64')

Here we set up the parameters for the system.

In [None]:
# System parameters, used in both methods
m = 1.0  # mass
k = 100.0  # spring constant
# ll is $\lambda$ in the paper; only used by the damped (DHO) nonconservative
# potential in create_original.
ll = 1e-4 * np.sqrt(m * k)

# Simulation and Method parameters
dt = 0.1 * np.sqrt(m / k)
t_sample_count = 1
t0 = 0
# Time grid: t0, t0 + dt, ..., t0 + t_sample_count * dt (t_sample_count steps).
t = t0 + dt * np.arange(0, t_sample_count + 1)
# Final time of the grid above (equal to t[-1]); derived from dt so it stays
# consistent with t when dt or t_sample_count change.
tmax = t0 + t_sample_count * dt
r = 0

# Initial conditions
q0 = [1.]
pi0 = [0.25 * dt * k]

Now we define the system dynamics via the Lagrangian.

In [None]:
def lagrangian_f(q, qdot, t):
    """Harmonic-oscillator Lagrangian L = T - V for the JAX implementation.

    T = 0.5 * m * (qdot . qdot) and V = 0.5 * k * (q . q), where ``m`` and
    ``k`` are the module-level system parameters. ``t`` is accepted to keep
    the (q, qdot, t) signature but is not used.
    """
    kinetic = 0.5 * m * jnp.dot(qdot, qdot)
    potential = 0.5 * k * jnp.dot(q, q)
    return kinetic - potential

In [None]:
def create_original(damped=False):
    """Build and discretize the original (SymPy-based) slimplectic integrator.

    Creates a ``GalerkinGaussLobatto`` object with time variable ``t``,
    coordinate ``q`` and velocity ``v``. It uses the harmonic-oscillator
    Lagrangian built from the module-level ``m`` and ``k`` and discretizes it
    at order ``r`` with the implicit method.

    Args:
        damped: If True, use the DHO nonconservative potential
            ``K = -ll * (vp . qm)``. If False (the default, and the only
            behaviour previously reachable), use a constant symbol ``a``. That
            symbol depends on none of the dynamical variables, so it adds no
            damping.

    Returns:
        The discretized ``GalerkinGaussLobatto`` instance.
    """
    from sympy import Symbol
    from original import slimplectic

    dho = slimplectic.GalerkinGaussLobatto('t', ['q'], ['v'])
    L = 0.5 * m * np.dot(dho.v, dho.v) - 0.5 * k * np.dot(dho.q, dho.q)
    if damped:
        # DHO:
        K = -ll * np.dot(dho.vp, dho.qm)
    else:
        # No damping:
        K = Symbol('a')
    dho.discretize(L, K, r, method='implicit', verbose=False)
    return dho


dho = create_original()

In [None]:
# We need to normalise the format of the results of both implementations so
# we can compare them and plot them nicely. Both use the same layout, so they
# share one helper.


def _stack_results(results):
    """Stack ``results[0]`` and ``results[1]`` into one array, one row per time.

    Each entry is converted with ``np.array``; the two are stacked vertically
    and transposed. For 1-D inputs of length N, or 2-D inputs of shape (1, N),
    the result has shape (N, 2): column 0 holds ``results[0]`` (q) and
    column 1 holds ``results[1]`` (pi).
    """
    return np.vstack([np.array(results[0]), np.array(results[1])]).T


def format_original(original_results):
    """Normalise the original (SymPy) integrator output; see ``_stack_results``."""
    return _stack_results(original_results)


def format_jax(jax_results):
    """Normalise the JAX integrator output; see ``_stack_results``."""
    return _stack_results(jax_results)

## Original Implementation

In [None]:
# Integrate with the original implementation over the time grid t. The last
# value returned by dho.integrate is kept separately as debug info. The
# remaining values (presumably q and pi — confirm) are normalised and shown.
*original_results, integrate_debug_escape_info = dho.integrate(q0, pi0, t)
original_results_fmt = format_original(original_results)
original_results_fmt

## JAX Implementation

In [None]:
from slimpletic import Solver

# Build the class-based JAX solver with the same order, step size and Lagrangian.
solver = Solver(r=r, dt=dt, lagrangian=lagrangian_f)

# NOTE(review): the result is not stored, only displayed as the cell output.
# q0/pi0 are passed as plain lists here, whereas st.iterate below receives
# jnp.array values.
solver.integrate(q0, pi0, t0, t_sample_count)


In [None]:
# Run the functional JAX implementation. With debug=True the last return value
# is a mapping of debug callables (used in the cells below); the rest are the
# results. Use the shared t0 parameter so this run stays aligned with `t`.
*jax_results, jax_debug_info = st.iterate(
    lagrangian=lagrangian_f,
    q0=jnp.array(q0),
    pi0=jnp.array(pi0),
    dt=dt,
    t0=t0,
    t_sample_count=t_sample_count,
    r=r,
    debug=True,
)

jax_results_fmt = format_jax(jax_results)
jax_results_fmt

In [None]:
# Overlay both implementations: lines are the original (SymPy) results and
# crosses are the JAX results; q is drawn in C0 and pi in C1.
plt.plot(t, original_results_fmt[:, 0], label='sympy q', color='C0')
plt.scatter(t, jax_results_fmt[:, 0], label='jax q', color='C0', marker='x')
# Raw strings: '\p' in a normal string literal is an invalid escape sequence.
plt.plot(t, original_results_fmt[:, 1], label=r'sympy $\pi$', color='C1')
plt.scatter(t, jax_results_fmt[:, 1], label=r'jax $\pi$', color='C1', marker='x')
plt.xlabel('t')
plt.legend()

In [None]:
# Debug: re-evaluate the original implementation's _pi_np1_map (presumably the
# pi_{n+1} update) using the argument set recorded at index [1] of
# '_pi_np1_map_args' in the integrate debug info.
dho._pi_np1_map(*integrate_debug_escape_info['_pi_np1_map_args'][1])

In [None]:
# Debug: inspect the debug info stored on the discretized original integrator.
dho.debug_escape_info

In [None]:
# Debug: inspect the debug info returned by dho.integrate in the
# "Original Implementation" cell.
integrate_debug_escape_info

# These values should be equal, but they differ

In [None]:
# Debug: evaluate the JAX implementation's q_i values for the initial
# conditions (q0, pi0) at t0.
# NOTE(review): presumably the intermediate q values of the first step —
# confirm against the slimpletic debug helpers.
jax_debug_info['compute_qi_values'](
    jnp.array(q0),
    jnp.array(pi0),
    t0
)

In [None]:
# NOTE(review): hand-copied pair of values, presumably q values observed from
# one of the implementations for side-by-side comparison — confirm their source.
1.0, 0.9975

In [None]:
# NOTE(review): hand-copied pair of values, presumably q values observed from
# the other implementation for side-by-side comparison — confirm their source.
0.9975, 0.9850250000000002

In [None]:
# Debug: evaluate the JAX residue function at the q_i values computed from the
# initial conditions (the same compute_qi_values call as the earlier cell),
# together with t0 and pi0.
jax_debug_info['residue'](
    jax_debug_info['compute_qi_values'](
        jnp.array(q0),
        jnp.array(pi0),
        t0
    ), t0, jnp.array(pi0))