In [112]:
import jax
import jax.numpy as jnp
from tqdm import tqdm
import timeit

In [113]:
def henon_heiles(t, coords):
    """Time derivative of the state for the Hénon–Heiles system.

    Args:
        t: Time argument. Accepted so the function has a ``(t, state)``
            signature, but it is not read by the body.
        coords: Indexable state; entries 0..3 are read and treated here as
            ``(x, y, p_x, p_y)``.

    Returns:
        A ``jnp`` array ``[p_x, p_y, -x * (1 + 2y), -(y + x**2 - y**2)]``.
    """
    x, y = coords[0], coords[1]
    p_x, p_y = coords[2], coords[3]

    # Momentum derivatives.
    dp_x = -x * (1 + 2 * y)
    dp_y = -(y + x ** 2 - y ** 2)

    return jnp.array([p_x, p_y, dp_x, dp_y])

def evolve_rk4(coords, _, dt=0.01):
    """Advance the state by one classical fourth-order Runge–Kutta step.

    The ``(carry, x) -> (carry, y)`` signature lets this function be passed
    directly to ``jax.lax.scan``, which calls it with two positional
    arguments; ``dt`` then takes its default.

    Args:
        coords: Current state, forwarded to ``henon_heiles``.
        _: Per-step input slot required by the scan signature; ignored.
        dt: Integration step size. Defaults to 0.01, the value that was
            previously hard-coded in every stage.

    Returns:
        A pair ``(new_coords, new_coords)``: the updated state as the new
        carry and the same state again as the per-step output.
    """
    # henon_heiles does not read its time argument, so 0 is passed for
    # every stage.
    k_1 = dt * henon_heiles(0, coords)
    k_2 = dt * henon_heiles(0, coords + (1 / 2) * k_1)
    k_3 = dt * henon_heiles(0, coords + (1 / 2) * k_2)
    k_4 = dt * henon_heiles(0, coords + k_3)

    # Weighted average of the four stage increments (1, 2, 2, 1) / 6.
    new_coords = coords + (1 / 6) * (k_1 + 2 * k_2 + 2 * k_3 + k_4)

    return new_coords, new_coords


In [114]:
def calculate_intial_conditions(H, x_0, y_0, py_0):
    """Build the initial state ``[x_0, y_0, px_0, py_0]`` for a given ``H``.

    ``px_0`` is the non-negative square root of

        2H - 2 x_0^2 y_0 + (2/3) y_0^3 - x_0^2 - y_0^2 - py_0^2

    No check is made that this quantity is non-negative.

    Args:
        H: Value of H entering the expression above.
        x_0: Initial x.
        y_0: Initial y.
        py_0: Initial y-momentum.

    Returns:
        A ``jnp`` array ``[x_0, y_0, px_0, py_0]``.
    """
    px_0_squared = (
        2 * H
        - 2 * (x_0**2) * y_0
        + 2 * (y_0**3) / 3
        - (x_0**2)
        - (y_0**2)
        - (py_0**2)
    )
    px_0 = jnp.sqrt(px_0_squared)

    return jnp.array([x_0, y_0, px_0, py_0])

In [115]:
# Value of H handed to calculate_intial_conditions, which uses it to solve
# for the initial x-momentum.
H = 0.125
# Integration step size; same value as the 0.01 step used by evolve_rk4.
dt = 0.01

# Initial x, y and y-momentum.
x_0, y_0, py_0 = 0, -0.25, 0

# Full initial state [x_0, y_0, px_0, py_0].
init_coords = calculate_intial_conditions(H=H, x_0=x_0, y_0=y_0, py_0=py_0)

In [116]:
# Number of integration steps per rollout: passed as the scan length and
# read by rollout_loop as its iteration count.
length = 5_000

In [117]:
%timeit jax.lax.scan(evolve_rk4, init_coords, length=length)

488 μs ± 1.02 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [118]:
def rollout_loop(init_coords, n_steps=None):
    """Integrate with a plain Python loop instead of ``jax.lax.scan``.

    Args:
        init_coords: Starting state, passed to ``evolve_rk4`` on the first
            step.
        n_steps: Number of ``evolve_rk4`` steps to take. When ``None``
            (the default) the module-level ``length`` is used.

    Returns:
        The state after the final step (``init_coords`` itself when zero
        steps are taken).
    """
    if n_steps is None:
        n_steps = length

    coords = init_coords
    for _ in range(n_steps):
        # evolve_rk4 ignores its second argument; pass None rather than
        # rebinding the loop variable. The per-step output duplicates the
        # carry, so it is discarded.
        coords, _unused = evolve_rk4(coords, None)
    return coords


In [119]:
%timeit rollout_loop(init_coords)

KeyboardInterrupt: 