In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

import math
from functools import partial
from typing import Callable, Tuple, Union

import diffrax
import equinox as eqx
import matplotlib.animation
import matplotlib.lines
import matplotlib.patches
import matplotlib.pyplot as plt
import optax
from jaxtyping import Array, ArrayLike
from tqdm.auto import trange

In [None]:
def cartpole_ode(
    t: ArrayLike, y: Array, args: Union[Callable[[Array], Array], Array]
) -> Array:
    """Time derivative of the cart-pole state, in ``f(t, y, args)`` ODE form.

    Args:
        t: Current time. Unused by this function; present only because the
            ODE-term signature requires it.
        y: State ``(x, theta, dot_x, dot_theta)``. With zero force, gravity
            accelerates ``theta`` away from 0, i.e. ``theta = 0`` is the
            upright (unstable) pole position.
        args: The control. Either a callable policy mapping the state ``y`` to
            an array whose first element is the control signal, or such an
            array directly (a constant, state-independent control signal).

    Returns:
        ``(dot_x, dot_theta, ddot_x, ddot_theta)`` stacked along the last axis.
    """
    # A callable is a feedback policy evaluated on the current state; anything
    # else is taken to be the control signal itself, so a plain array such as
    # `jnp.zeros(1)` gives an uncontrolled system instead of a TypeError.
    signal = args(y) if callable(args) else jnp.asarray(args)
    # Only the first component is used; it is scaled by 10 to obtain the
    # horizontal force applied to the cart.
    force = 10 * signal[0]

    g = 9.81
    mass_cart = 1.0
    mass_pole = 0.1
    length_pole = 0.5
    total_mass = mass_cart + mass_pole

    x = y[0]  # unpacked for readability; the dynamics do not depend on x
    theta = y[1]
    dot_x = y[2]
    dot_theta = y[3]

    sin_theta = jnp.sin(theta)
    cos_theta = jnp.cos(theta)

    # Angular acceleration of the pole.
    ddot_theta = (
        g * sin_theta
        + cos_theta
        * (
            (-force - mass_pole * length_pole * jnp.square(dot_theta) * sin_theta)
            / total_mass
        )
    ) / (length_pole * (4 / 3 - (mass_pole * jnp.square(cos_theta) / total_mass)))

    # Linear acceleration of the cart; depends on ddot_theta computed above.
    ddot_x = (
        force
        + mass_pole
        * length_pole
        * (jnp.square(dot_theta) * sin_theta - ddot_theta * cos_theta)
    ) / total_mass

    dy = jnp.stack((dot_x, dot_theta, ddot_x, ddot_theta), axis=-1)
    return dy

In [None]:
# Uncontrolled baseline: integrate the cart-pole from a small initial pole
# angle with no force applied.
y0 = jnp.asarray([0.0, 0.1, 0.0, 0.0])
sol = diffrax.diffeqsolve(
    terms=diffrax.ODETerm(cartpole_ode),
    solver=diffrax.Dopri5(),
    t0=0.0,
    t1=10.0,
    dt0=0.01,
    y0=y0,
    # `cartpole_ode` invokes `args` as a control policy on the state, so the
    # "no control" case has to be a callable returning a zero control signal.
    args=lambda y: jnp.zeros(1),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 1024)),
    stepsize_controller=diffrax.PIDController(
        rtol=1e-5, atol=1e-5, pcoeff=0.3, icoeff=0.3
    ),
)

# Active Control Test

In [None]:
key = jax.random.PRNGKey(1234)

# Controller: maps the 4-dimensional cart-pole state to one control signal.
key, subkey = jax.random.split(key)
control = eqx.nn.MLP(
    in_size=4,
    out_size=1,
    width_size=64,
    depth=2,
    # tanh keeps the signal in (-1, 1), so the scaled force can push the cart
    # in either direction. A sigmoid output lies in (0, 1) and would only ever
    # push one way, which cannot recover a pole leaning to the other side.
    final_activation=jnp.tanh,
    key=subkey,
)

# Only the array leaves of the MLP are trainable parameters.
optimizer = optax.adam(learning_rate=1e-4)
opt_state = optimizer.init(params=eqx.filter(control, eqx.is_array))


def reward_fn(ys: Array) -> float:
    """Score a trajectory by how close component 1 of the state stays to zero.

    Args:
        ys: Saved states; the last axis indexes the state components and
            component 1 is the one being penalised.

    Returns:
        The negated mean absolute value of component 1 over all other axes,
        so 0 is the best possible reward.
    """
    pole_angles = ys[..., 1]
    mean_abs_angle = jnp.abs(pole_angles).mean()
    return -mean_abs_angle


@eqx.filter_jit
def eval_traj(control: eqx.Module, key: Array):
    """Simulate one cart-pole trajectory under `control` from a random start.

    Args:
        control: Callable policy handed to `cartpole_ode` through `args`.
        key: PRNG key used to draw the initial state. Annotated as a plain
            `Array` because `jax.random.KeyArray` no longer exists in recent
            JAX releases and the annotation is evaluated at definition time.

    Returns:
        The diffrax solution, saved at 1024 evenly spaced times in [0, 10].
    """
    # Every state component starts uniformly in [-0.05, 0.05].
    # Wider, per-component ranges kept for reference:
    #   minval=jnp.asarray([-0.5, -0.1, -5.0, -1.0]),
    #   maxval=jnp.asarray([0.5, 0.1, 5.0, 1.0]),
    y0 = jax.random.uniform(
        key,
        shape=(4,),
        minval=-0.05,
        maxval=0.05,
    )

    sol = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(cartpole_ode),
        solver=diffrax.Dopri5(),
        t0=0.0,
        t1=10.0,
        dt0=0.01,
        y0=y0,
        args=control,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 1024)),
        stepsize_controller=diffrax.PIDController(
            rtol=1e-5, atol=1e-5, pcoeff=0.3, icoeff=0.3
        ),
    )

    return sol


@eqx.filter_value_and_grad
def eval_reward(control: eqx.Module, key: Array) -> Array:
    """Reward of a single simulated trajectory under `control`.

    Because of the `eqx.filter_value_and_grad` decorator, calling this returns
    `(reward, grads)`, with `grads` taken with respect to `control`.

    Args:
        control: Policy to evaluate (the differentiated argument).
        key: PRNG key forwarded to `eval_traj`. Annotated as `Array` because
            `jax.random.KeyArray` is not available in recent JAX releases.
    """
    sol = eval_traj(control, key)

    reward = reward_fn(sol.ys)
    return reward


@eqx.filter_jit
def update_step(control: eqx.Module, opt_state: optax.OptState, key: Array):
    """Run one optimisation step that increases the trajectory reward.

    Args:
        control: Current policy.
        opt_state: Optimiser state matching the array leaves of `control`.
        key: PRNG key for this step's rollout. Annotated as `Array` because
            `jax.random.KeyArray` is not available in recent JAX releases.

    Returns:
        `(control, opt_state, reward)`: the updated policy, the updated
        optimiser state, and the reward measured before the update.
    """
    reward, grads = eval_reward(control, key)
    # The optimiser update descends along the gradients it is given, so negate
    # them to ascend the reward. `jax.tree_util.tree_map` replaces the
    # deprecated top-level `jax.tree_map` alias.
    grads = jax.tree_util.tree_map(lambda g: -g, grads)

    # Update only the array leaves, then reassemble the module.
    control_params, control_static = eqx.partition(control, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, params=control_params)
    control_params = optax.apply_updates(control_params, updates)

    control = eqx.combine(control_params, control_static)

    return control, opt_state, reward

In [None]:
# Train the controller: one randomly initialised rollout per step.
pbar = trange(1024 * 16)
for i in pbar:
    # Split off a fresh subkey so every step sees a different initial state.
    key, subkey = jax.random.split(key)
    control, opt_state, reward = update_step(control, opt_state, subkey)

    # Only refresh the progress-bar readout every 128 steps.
    if i % 128 != 0:
        continue
    pbar.set_postfix({"reward": reward.item()})

In [None]:
# Roll out the trained controller once from a fresh random initial state.
key, subkey = jax.random.split(key)
sol = eval_traj(control, subkey)

# Component 1 of the state is the one the reward drives towards zero.
plt.figure()
plt.plot(sol.ts, sol.ys[:, 1])
plt.xlabel("t")
plt.ylabel("theta")
plt.title("Pole angle along a controlled trajectory")
plt.show()

# Animation Test

In [None]:
fig, ax = plt.subplots()

# Horizontal reference line at y = 0 (the axhline default).
ax.axhline()

# Cart geometry.
cart_width = 0.5
cart_height = 0.25

# A Rectangle is anchored at its lower-left corner, so shift it left by half
# its width to centre it on the first saved cart position.
cart = matplotlib.patches.Rectangle(
    xy=(sol.ys[0, 0] - cart_width / 2, 0),
    width=cart_width,
    height=cart_height,
)
ax.add_patch(cart)

# Pole appearance: line width of the drawn pole and its nominal length.
pole_width = 2.5
pole_length = 0.5


def get_pole_data(cart_x, pole_angle, base_y=None, length=None):
    """Endpoints of the pole segment for a given cart position and pole angle.

    `pole_angle` is measured from the upward vertical, matching the dynamics
    where an angle of 0 is the upright pole, so the tip is offset by
    `length * sin(angle)` horizontally and `length * cos(angle)` vertically.

    Args:
        cart_x: Horizontal position of the cart, which is also the pole base.
        pole_angle: Pole angle in radians, measured from the vertical.
        base_y: Height of the pole base; defaults to the module-level
            `cart_height` (the top of the cart).
        length: Drawn pole length; defaults to the module-level `pole_length`.

    Returns:
        `([x_base, x_tip], [y_base, y_tip])`, ready to unpack into
        `Line2D(...)` or `Line2D.set_data(...)`.
    """
    if base_y is None:
        base_y = cart_height
    if length is None:
        length = pole_length

    pole_base_x = cart_x
    pole_base_y = base_y
    pole_end_x = pole_base_x + length * math.sin(pole_angle)
    pole_end_y = pole_base_y + length * math.cos(pole_angle)

    return [pole_base_x, pole_end_x], [pole_base_y, pole_end_y]


# Pole artist, initialised from the first saved state of the trajectory.
pole = matplotlib.lines.Line2D(
    *get_pole_data(sol.ys[0, 0], sol.ys[0, 1]),
    linewidth=pole_width,
)
ax.add_line(pole)


def init():
    """Animation init callback: fix the axis limits and return the artists."""
    ax.set_xlim(-10.0, 10.0)
    ax.set_ylim(-0.5, 2.0)

    return cart, pole


def update(frame):
    """Animation frame callback: move the cart and pole to saved state `frame`.

    Args:
        frame: Index into the saved trajectory `sol.ys`.

    Returns:
        The `(cart, pole)` artists that were modified.
    """
    y = sol.ys[frame]
    cart_x = float(y[0])
    pole_angle = float(y[1])

    # Rectangle.x is the left edge of the patch; offset by half the width so
    # the cart stays centred on the pole base, as in its initial placement.
    cart.set(x=cart_x - cart_width / 2)
    pole.set_data(*get_pole_data(cart_x, pole_angle))

    return cart, pole

# One frame per saved time point; an integer `frames` is equivalent to
# iterating over range(frames).
animation = matplotlib.animation.FuncAnimation(
    fig=fig,
    func=update,
    init_func=init,
    frames=len(sol.ts),
)
plt.show()