In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

import functools
import math
from functools import partial
from typing import Callable, List, Optional, Tuple, Type

import diffrax
import equinox as eqx
import matplotlib.animation
import matplotlib.lines
import matplotlib.patches
import matplotlib.pyplot as plt
import optax
from jaxtyping import Array, ArrayLike, PRNGKeyArray, PyTree, Scalar
from tqdm.auto import trange

In [None]:
def cartpole_ode(t: Scalar, y: Array, u: Array, args: PyTree) -> PyTree:
    """Right-hand side of the cart-pole equations of motion.

    Args:
        t: Current time (unused; the dynamics are autonomous).
        y: State vector indexed as ``[x, theta, dot_x, dot_theta]``.
        u: Control vector; only ``u[0]`` is read and scaled by 10 to get the
            force applied to the cart.
        args: Extra arguments (unused).

    Returns:
        ``[dot_x, dot_theta, ddot_x, ddot_theta]`` stacked along the last axis.
    """
    # Physical constants of the system.
    g = 9.81
    mass_cart = 1.0
    mass_pole = 0.1
    length_pole = 0.5
    total_mass = mass_cart + mass_pole
    pole_moment = mass_pole * length_pole

    force = 10 * u[0]

    theta, dot_x, dot_theta = y[1], y[2], y[3]
    sin_theta = jnp.sin(theta)
    cos_theta = jnp.cos(theta)
    dot_theta_sq = jnp.square(dot_theta)

    # Angular acceleration of the pole.
    numerator = g * sin_theta + cos_theta * (
        (-force - pole_moment * dot_theta_sq * sin_theta) / total_mass
    )
    denominator = length_pole * (
        4 / 3 - (mass_pole * jnp.square(cos_theta) / total_mass)
    )
    ddot_theta = numerator / denominator

    # Linear acceleration of the cart, which depends on ddot_theta.
    ddot_x = (
        force + pole_moment * (dot_theta_sq * sin_theta - ddot_theta * cos_theta)
    ) / total_mass

    return jnp.stack([dot_x, dot_theta, ddot_x, ddot_theta], axis=-1)

# Control function decorators

In [None]:
def with_control(
    f: Callable[[Scalar, PyTree, PyTree, PyTree], PyTree]
) -> Callable[[Scalar, PyTree, PyTree], PyTree]:
    """Close the loop around ``f`` with a state-feedback controller.

    The returned vector field has the ``(t, y, args)`` signature, where
    ``args`` is a ``(control, f_args)`` pair. The control signal is computed as
    ``control(y)`` and forwarded to ``f(t, y, u, f_args)``.

    The returned callable also carries a ``_modify_initial_state`` attribute;
    for this decorator it returns the initial state unchanged.
    """

    def _identity_initial_state(control: PyTree, t0: Scalar, y0: Array) -> Array:
        # No controller state is appended to the system state.
        return y0

    @functools.wraps(f)
    def closed_loop(t: Scalar, y: PyTree, args: PyTree) -> PyTree:
        control, f_args = args
        return f(t, y, control(y), f_args)

    closed_loop._modify_initial_state = _identity_initial_state
    return closed_loop


def with_derivative_control(
    f: Callable[[Scalar, PyTree, PyTree, PyTree], PyTree], num_controls: int
) -> Callable[[Scalar, PyTree, PyTree], PyTree]:
    """Wrap ``f`` so the controller outputs the time derivative of the control.

    The augmented state is laid out as ``[controls, system_state]`` along the
    last axis: the first ``num_controls`` entries are the current control
    signal, the rest is the state of ``f``. ``args`` is a ``(control, f_args)``
    pair and ``control`` is called on the full augmented state.

    The returned callable carries a ``_modify_initial_state`` attribute that
    prepends ``control.encode_controls([t0, y0])`` to ``y0``.
    """

    def _prepend_initial_controls(control: PyTree, t0: Scalar, y0: Array) -> Array:
        # The initial control is produced by the controller's encoder from
        # (t0, y0); an all-zero start (jnp.zeros(num_controls)) was the
        # alternative considered in the original code.
        encoder_input = jnp.concatenate((t0, y0), axis=-1)
        initial_controls = control.encode_controls(encoder_input)
        return jnp.concatenate((initial_controls, y0), axis=-1)

    @functools.wraps(f)
    def augmented_field(t: Scalar, y: PyTree, args: PyTree) -> PyTree:
        control, f_args = args
        controls = y[..., :num_controls]
        system_state = y[..., num_controls:]

        system_derivative = f(t, system_state, controls, f_args)
        control_derivative = control(y)

        return jnp.concatenate((control_derivative, system_derivative), axis=-1)

    augmented_field._modify_initial_state = _prepend_initial_controls
    return augmented_field


# This needs stepping support, since ODE RNNs are discontinuous over state changes
# def with_ode_rnn_control(f: Callable[[Scalar, PyTree, PyTree, PyTree], PyTree], num_controls: int, num_memory: int):


def with_cde_rnn_control(
    f: Callable[[Scalar, PyTree, PyTree, PyTree], PyTree], num_latents: int
):
    """Wrap ``f`` with a latent controller driven by a controlled differential equation.

    The augmented state is laid out as ``[system_state, latents]`` along the
    last axis, with the last ``num_latents`` entries being the controller's
    latent state ``z``. ``args`` is a ``(control, f_args)`` pair where
    ``control`` provides:

    * ``decode_latents(z)``: the control signal passed to ``f``;
    * ``control(z)``: a matrix ``dz/dX`` that is multiplied with
      ``dX = [1, d(system_state)/dt]`` to obtain ``dz/dt``;
    * ``encode_latents([t0, y0])``: the initial latent state.

    The returned callable carries a ``_modify_initial_state`` attribute that
    appends the encoded initial latents to ``y0``.
    """

    @functools.wraps(f)
    def wrapper(t: Scalar, y: PyTree, args: PyTree) -> PyTree:
        control, f_args = args
        f_y, c_z = y[..., :-num_latents], y[..., -num_latents:]

        c_u = control.decode_latents(c_z)
        f_dy = f(t, f_y, c_u, f_args)

        # The driving path is X(t) = (t, system_state(t)), so dX/dt = (1, f_dy).
        dt = jnp.ones(1)
        dX = jnp.concatenate((dt, f_dy), axis=-1)

        c_dzdX = control(c_z)
        c_dz = c_dzdX @ dX

        dy = jnp.concatenate((f_dy, c_dz), axis=-1)
        return dy

    def modify_initial_state(control: PyTree, t0: Scalar, y0: Array) -> Array:
        X0 = jnp.concatenate((t0, y0), axis=-1)
        z0 = control.encode_latents(X0)

        # The system state must come first and the latents last, matching the
        # slicing in `wrapper`; the previous (z0, y0) order made the vector
        # field read latent entries as the system state and vice versa.
        y0 = jnp.concatenate((y0, z0), axis=-1)
        return y0

    wrapper_fn = wrapper
    wrapper_fn._modify_initial_state = modify_initial_state

    return wrapper_fn

# Active Control Test (with generic interface)

In [None]:
class RNN(eqx.Module):
    """Stack of identical recurrent cells between two linear projections.

    ``in_proj`` maps the input to ``rnn_width``, the stacked ``cells`` are
    applied one after another (the output of one cell is the input of the
    next), and ``out_proj`` (bias-free) maps the result to ``out_width``.
    """

    in_proj: eqx.nn.Linear
    out_proj: eqx.nn.Linear
    # All layers' cells stored as a single module whose parameters carry a
    # leading axis of size `rnn_layers` (built with jax.vmap).
    cells: eqx.Module

    def __init__(self, in_width: int, out_width: int, rnn_width: int, rnn_layers: int, cell_cls: Type[eqx.Module], key: PRNGKeyArray):
        """Build the projections and ``rnn_layers`` cells of type ``cell_cls``.

        ``cell_cls`` is called as
        ``cell_cls(input_size=rnn_width, hidden_size=rnn_width, key=...)``.
        """
        key, key1, key2 = jax.random.split(key, num=3)
        self.in_proj = eqx.nn.Linear(in_width, rnn_width, key=key1)
        self.out_proj = eqx.nn.Linear(rnn_width, out_width, use_bias=False, key=key2)

        keys = jax.random.split(key, num=rnn_layers)
        make_cells = lambda k: cell_cls(input_size=rnn_width, hidden_size=rnn_width, key=k)
        self.cells = jax.vmap(make_cells)(keys)

    def __call__(self, inputs: Array, states: PyTree) -> Tuple[Array, PyTree]:
        """Advance every layer by one step.

        Args:
            inputs: Input to the first layer (before ``in_proj``).
            states: Per-layer recurrent states, stacked along a leading layer
                axis so they can be scanned together with ``self.cells``.

        Returns:
            ``(output, next_states)`` where ``next_states`` has the same
            stacked structure as ``states``.
        """
        x = self.in_proj(inputs)

        # Scan over stack of RNN cells
        def f(carry: Array, x: Tuple[eqx.Module, PyTree]) -> Tuple[Array, PyTree]:
            cell, state = x

            next_state = cell(carry, state)

            # Cells whose state is a tuple expose its first element as the
            # layer output; otherwise the state itself is the output.
            # `isinstance` cannot be used with subscripted typing generics
            # such as `Tuple[Array, Array]` (it raises TypeError), so check
            # against the builtin `tuple` instead.
            if isinstance(next_state, tuple):
                output = next_state[0]
            else:
                output = next_state

            return output, next_state

        x, next_states = jax.lax.scan(f, init=x, xs=(self.cells, states))

        x = self.out_proj(x)
        return x, next_states

class ModularControl(eqx.Module):
    """Controller network with a mode-dependent architecture.

    Supported modes:

    * ``"cde-rnn"``: ``main`` maps the latent state ``z`` to a
      ``(num_latents, 1 + num_states)`` matrix ``dz/dX``; ``encoder`` maps
      ``[t0, y0]`` to the initial latents; ``decoder`` maps latents to the
      control signal (tanh-bounded).
    * ``"derivative"``: ``main`` maps the augmented state
      ``[controls, system_state]`` to the control derivative; ``encoder`` maps
      ``[t0, y0]`` to the initial control; there is no decoder.
    * ``"step-rnn"``: reserved, not implemented yet.
    """

    main: eqx.Module
    encoder: Optional[eqx.Module]
    decoder: Optional[eqx.Module]
    num_controls: int
    num_latents: Optional[int]
    num_states: Optional[int]
    mode: str

    def __init__(
        self,
        hidden_width: int,
        hidden_layers: int,
        num_controls: int,
        num_latents: Optional[int],
        num_states: Optional[int],
        rnn_cell_cls: Optional[Type[eqx.Module]] = None,
        mode: str = "cde-rnn",
        *,
        key: PRNGKeyArray
    ):
        """Build the sub-networks for the requested ``mode``.

        Args:
            hidden_width: Width of every hidden layer of the MLPs.
            hidden_layers: Depth of the MLPs.
            num_controls: Dimension of the control signal.
            num_latents: Dimension of the latent state (``"cde-rnn"`` only).
            num_states: Dimension of the controlled system's state.
            rnn_cell_cls: Recurrent cell class reserved for ``"step-rnn"``;
                unused by the implemented modes, hence optional.
            mode: One of ``"cde-rnn"``, ``"derivative"``, ``"step-rnn"``.
            key: PRNG key used to initialise the sub-networks.

        Raises:
            NotImplementedError: If ``mode == "step-rnn"``.
            ValueError: If ``mode`` is not a known mode.
        """
        self.num_controls = num_controls
        self.num_latents = num_latents
        self.num_states = num_states
        self.mode = mode

        if mode == "cde-rnn":
            keys = jax.random.split(key, num=3)
            self.main = eqx.nn.MLP(
                in_size=num_latents,
                out_size=(num_latents * (1 + num_states)),
                width_size=hidden_width,
                depth=hidden_layers,
                activation=jax.nn.silu,
                final_activation=jax.nn.tanh,
                use_final_bias=False,
                key=keys[0],
            )
            self.encoder = eqx.nn.MLP(
                in_size=(1 + num_states),
                out_size=num_latents,
                width_size=hidden_width,
                depth=hidden_layers,
                activation=jax.nn.silu,
                use_final_bias=False,
                key=keys[1],
            )
            self.decoder = eqx.nn.MLP(
                in_size=num_latents,
                out_size=num_controls,
                width_size=hidden_width,
                depth=hidden_layers,
                activation=jax.nn.silu,
                final_activation=jax.nn.tanh,
                use_final_bias=False,
                key=keys[2],
            )
        elif mode == "step-rnn":
            # The branch body was empty (a syntax error); fail loudly until the
            # stepping-based RNN controller exists.
            raise NotImplementedError("mode 'step-rnn' is not implemented yet")
        elif mode == "derivative":
            keys = jax.random.split(key, num=2)
            self.main = eqx.nn.MLP(
                # `with_derivative_control` calls the controller on the full
                # augmented state [controls, system_state].
                in_size=(num_controls + num_states),
                out_size=num_controls,
                width_size=hidden_width,
                depth=hidden_layers,
                activation=jax.nn.silu,
                use_final_bias=False,
                key=keys[0],
            )
            self.encoder = eqx.nn.MLP(
                in_size=(1 + num_states),
                out_size=num_controls,
                width_size=hidden_width,
                depth=hidden_layers,
                activation=jax.nn.silu,
                use_final_bias=False,
                key=keys[1],
            )
            # Every declared field must be assigned in __init__.
            self.decoder = None
        else:
            raise ValueError(f"unknown mode: {mode!r}")

    def __call__(self, inputs: Array) -> Array:
        """Evaluate ``main``; in ``"cde-rnn"`` mode reshape to ``(num_latents, 1 + num_states)``."""
        if self.mode == "cde-rnn":
            dzdX: Array = self.main(inputs)
            dzdX = dzdX.reshape(self.num_latents, 1 + self.num_states)

            return dzdX
        else:
            return self.main(inputs)

    def encode_controls(self, X0: Array) -> Array:
        """Map the encoder input ``X0`` to an initial control signal."""
        return self.encoder(X0)

    def encode_latents(self, z0: Array) -> Array:
        """Map the encoder input to an initial latent state.

        Despite its name, ``z0`` is the encoder *input* (callers in this file
        pass the concatenated ``[t0, y0]``), not a latent vector.
        """
        return self.encoder(z0)

    def decode_latents(self, z: Array) -> Array:
        """Map a latent state ``z`` to the control signal."""
        return self.decoder(z)

In [None]:
# Root PRNG key for this section; it is split before every random draw.
key = jax.random.PRNGKey(1234)

key, subkey = jax.random.split(key)
control = ModularControl(
    hidden_width=64,
    hidden_layers=2,
    num_controls=1,
    num_latents=16,
    num_states=4,
    # Only relevant for the (unimplemented) "step-rnn" mode, but the argument
    # must be supplied explicitly when the constructor declares no default.
    rnn_cell_cls=None,
    mode="cde-rnn",
    key=subkey,
)

# Only the array leaves of the controller are optimised.
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params=eqx.filter(control, eqx.is_array))

# Take the latent size from the controller so the two can never disagree.
env_ode = with_cde_rnn_control(cartpole_ode, num_latents=control.num_latents)


def reward_fn(ys: Array) -> float:
    """Return the negated sum of squared cart position and pole angle.

    ``ys[..., 0]`` is read as the cart position ``x`` and ``ys[..., 1]`` as the
    pole angle ``theta``; the squared values are summed over all entries.

    Masking of states past the first threshold violation
    (``|x| > 2.0`` or ``|theta| > 0.2``) is currently disabled, so every saved
    state contributes.
    """
    cart_position, pole_angle = ys[..., 0], ys[..., 1]
    squared_error = jnp.square(cart_position) + jnp.square(pole_angle)
    return -jnp.sum(squared_error)


@eqx.filter_jit
def eval_traj(control: eqx.Module, key: PRNGKeyArray, t1: float):
    """Simulate one closed-loop rollout of ``env_ode`` from a random start.

    Args:
        control: Controller passed to the vector field through ``args``.
        key: PRNG key used to sample the initial system state uniformly from
            ``[-0.05, 0.05)`` in each of the 4 state dimensions.
        t1: End time of the simulation (start time is 0).

    Returns:
        The diffrax solution, saved at 1024 evenly spaced times in ``[0, t1]``.
    """
    # NOTE: the key annotation uses jaxtyping's PRNGKeyArray because
    # `jax.random.KeyArray` no longer exists in current JAX releases.
    y0 = jax.random.uniform(
        key,
        shape=(4,),
        minval=-0.05,
        maxval=0.05,
    )

    sol = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(env_ode),
        solver=diffrax.Kvaerno5(),
        t0=0.0,
        t1=t1,
        dt0=0.01,
        # The wrapper augments the sampled state with the controller's
        # initial state, computed from (t0, y0).
        y0=env_ode._modify_initial_state(control, jnp.asarray([0.0]), y0),
        args=(control, None),
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, t1, 1024)),
        stepsize_controller=diffrax.PIDController(
            rtol=1e-5, atol=1e-5, pcoeff=0.3, icoeff=0.3
        ),
    )

    return sol


@eqx.filter_value_and_grad
def eval_reward(
    control: eqx.Module, key: PRNGKeyArray, batch_size: int, t1: float
) -> float:
    """Reward of a single rollout; the decorator also returns its gradient w.r.t. ``control``.

    Args:
        control: Controller being trained (the differentiated argument).
        key: PRNG key forwarded to ``eval_traj``.
        batch_size: Currently unused; kept for the batched (vmapped)
            evaluation that is disabled for now.
        t1: End time of the rollout.
    """
    # NOTE: annotated with jaxtyping's PRNGKeyArray because
    # `jax.random.KeyArray` no longer exists in current JAX releases.
    sol = eval_traj(control, key, t1)

    reward = reward_fn(sol.ys)
    return reward


# @eqx.filter_jit
def update_step(
    control: eqx.Module, opt_state: optax.OptState, key: PRNGKeyArray, t1: float
):
    """Perform one gradient-ascent step on the rollout reward.

    Args:
        control: Current controller.
        opt_state: Current optimizer state.
        key: PRNG key for the rollout's initial state.
        t1: End time of the rollout.

    Returns:
        ``(control, opt_state, reward)`` with the updated controller and
        optimizer state and the reward measured before the update.
    """
    reward, grads = eval_reward(control, key, 16, t1)
    # The optimizer minimises, so negate the gradient to maximise the reward.
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    grads = jax.tree_util.tree_map(lambda x: -x, grads)

    control_params, control_static = eqx.partition(control, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, params=control_params)
    control_params = optax.apply_updates(control_params, updates)

    control = eqx.combine(control_params, control_static)

    return control, opt_state, reward

In [None]:
# Training schedule: the rollout horizon grows linearly from
# max_t1 / num_steps up to max_t1 over the course of training.
num_steps = 1024
max_t1 = 100.0

pbar = trange(num_steps)
for step in pbar:
    t1 = jnp.float64(max_t1 * (step + 1) / num_steps)

    key, subkey = jax.random.split(key)
    control, opt_state, reward = update_step(control, opt_state, subkey, t1)

    # Refresh the progress-bar readout every 16th step only.
    if not step % 16:
        pbar.set_postfix(reward=reward.item(), t1=t1)

In [None]:
# Roll out the trained controller once over the full horizon and plot the
# first two channels of the saved trajectory.
key, subkey = jax.random.split(key)
sol = eval_traj(control, subkey, jnp.float64(max_t1))

fig, ax = plt.subplots()
ax.plot(sol.ts, sol.ys[:, :2])
plt.show()

# Disabled: plot of the controller output along the trajectory.
# fig, ax = plt.subplots()
# ax.plot(sol.ts, jax.vmap(control)(sol.ys))
# plt.show()

# Active Control Test
Train a controller that takes the current system state as its input.

In [None]:
# Fresh root key for this section, split once for the network initialisation.
key, subkey = jax.random.split(jax.random.PRNGKey(1234))

# State-feedback controller: 4 state inputs -> 1 control output bounded by tanh.
# (A tanh hidden activation was tried and is left disabled:
#  activation=jax.nn.tanh)
control = eqx.nn.MLP(
    in_size=4,
    out_size=1,
    width_size=64,
    depth=2,
    final_activation=jax.nn.tanh,
    use_final_bias=False,
    key=subkey,
)

# Only the array leaves of the controller are optimised.
optimizer = optax.adam(learning_rate=1e-3)
trainable_leaves = eqx.filter(control, eqx.is_array)
opt_state = optimizer.init(params=trainable_leaves)


def reward_fn(ys: Array) -> float:
    """Return the negated sum of squared cart position and pole angle.

    ``ys[..., 0]`` is read as the cart position ``x`` and ``ys[..., 1]`` as the
    pole angle ``theta``; the squared values are summed over all entries.

    Masking of states past the first threshold violation
    (``|x| > 2.0`` or ``|theta| > 0.2``) is currently disabled, so every saved
    state contributes.
    """
    cart_position, pole_angle = ys[..., 0], ys[..., 1]
    squared_error = jnp.square(cart_position) + jnp.square(pole_angle)
    return -jnp.sum(squared_error)


@eqx.filter_jit
def eval_traj(control: eqx.Module, key: PRNGKeyArray, t1: float):
    """Simulate one rollout of the cart-pole under state-feedback control.

    Args:
        control: Controller passed to the vector field through ``args``; it is
            called on the 4-dimensional system state.
        key: PRNG key used to sample the initial state uniformly from
            ``[-0.05, 0.05)`` in each of the 4 state dimensions.
        t1: End time of the simulation (start time is 0).

    Returns:
        The diffrax solution, saved at 1024 evenly spaced times in ``[0, t1]``.
    """
    # NOTE: the key annotation uses jaxtyping's PRNGKeyArray because
    # `jax.random.KeyArray` no longer exists in current JAX releases.
    y0 = jax.random.uniform(
        key,
        shape=(4,),
        minval=-0.05,
        maxval=0.05,
        # Wider per-dimension ranges that were tried:
        # minval=jnp.asarray([-0.5, -0.1, -5.0, -1.0]),
        # maxval=jnp.asarray([0.5, 0.1, 5.0, 1.0]),
    )

    sol = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(with_control(cartpole_ode)),
        solver=diffrax.Dopri5(),
        t0=0.0,
        t1=t1,
        dt0=0.01,
        y0=y0,
        args=(control, None),
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, t1, 1024)),
        stepsize_controller=diffrax.PIDController(
            rtol=1e-5, atol=1e-5, pcoeff=0.3, icoeff=0.3
        ),
    )

    return sol


@eqx.filter_value_and_grad
def eval_reward(
    control: eqx.Module, key: PRNGKeyArray, batch_size: int, t1: float
) -> float:
    """Reward of a single rollout; the decorator also returns its gradient w.r.t. ``control``.

    Args:
        control: Controller being trained (the differentiated argument).
        key: PRNG key forwarded to ``eval_traj``.
        batch_size: Currently unused; kept for the batched (vmapped)
            evaluation that is disabled for now.
        t1: End time of the rollout.
    """
    # NOTE: annotated with jaxtyping's PRNGKeyArray because
    # `jax.random.KeyArray` no longer exists in current JAX releases.
    sol = eval_traj(control, key, t1)

    reward = reward_fn(sol.ys)
    return reward


@eqx.filter_jit
def update_step(
    control: eqx.Module, opt_state: optax.OptState, key: PRNGKeyArray, t1: float
):
    """Perform one gradient-ascent step on the rollout reward.

    Args:
        control: Current controller.
        opt_state: Current optimizer state.
        key: PRNG key for the rollout's initial state.
        t1: End time of the rollout.

    Returns:
        ``(control, opt_state, reward)`` with the updated controller and
        optimizer state and the reward measured before the update.
    """
    reward, grads = eval_reward(control, key, 16, t1)
    # The optimizer minimises, so negate the gradient to maximise the reward.
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    grads = jax.tree_util.tree_map(lambda x: -x, grads)

    control_params, control_static = eqx.partition(control, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, params=control_params)
    control_params = optax.apply_updates(control_params, updates)

    control = eqx.combine(control_params, control_static)

    return control, opt_state, reward

In [None]:
# The rollout horizon grows linearly up to 10 time units over training.
total_steps = 1024 * 16

pbar = trange(total_steps)
for step in pbar:
    t1 = jnp.float64(10 * (step + 1) / total_steps)

    key, subkey = jax.random.split(key)
    control, opt_state, reward = update_step(control, opt_state, subkey, t1)

    # Refresh the progress-bar readout every 16th step only.
    if not step % 16:
        pbar.set_postfix(reward=reward.item(), t1=t1)

In [None]:
# Roll out the trained controller once and plot the first two state channels,
# followed by the controller output evaluated along the trajectory.
key, subkey = jax.random.split(key)
sol = eval_traj(control, subkey, jnp.float64(10.0))

fig, ax = plt.subplots()
ax.plot(sol.ts, sol.ys[:, :2])
plt.show()

control_along_traj = jax.vmap(control)(sol.ys)
fig, ax = plt.subplots()
ax.plot(sol.ts, control_along_traj)
plt.show()

# Active control parameterizing the derivative of the control signal

In [None]:
# Fresh root key for this section, split once for the network initialisation.
key, subkey = jax.random.split(jax.random.PRNGKey(1234))

# The controller sees the augmented state (1 control + 4 system states) and
# outputs the control derivative, so its output is left unbounded.
# (Disabled variants: activation=jax.nn.tanh, final_activation=jax.nn.tanh)
control = eqx.nn.MLP(
    in_size=5,
    out_size=1,
    width_size=64,
    depth=2,
    use_final_bias=False,
    key=subkey,
)

# Only the array leaves of the controller are optimised.
optimizer = optax.adam(learning_rate=1e-3)
trainable_leaves = eqx.filter(control, eqx.is_array)
opt_state = optimizer.init(params=trainable_leaves)


def reward_fn(ys: Array, num_controls: int = 1) -> float:
    """Return the negated sum of squared cart position and pole angle.

    In this section the trajectory comes from ``with_derivative_control``,
    whose augmented state is ``[controls, x, theta, dot_x, dot_theta]``. The
    cart position and pole angle therefore sit *after* the control channels;
    reading indices 0 and 1 would penalise the control signal and the cart
    position instead.

    Args:
        ys: Trajectory whose last axis is the augmented state.
        num_controls: Number of leading control channels to skip. Pass ``0``
            for trajectories that contain the plain system state only.

    Returns:
        ``-sum(x**2 + theta**2)`` over all entries of ``ys``.
    """
    x = ys[..., num_controls]
    theta = ys[..., num_controls + 1]

    # Masking of states past the first threshold violation
    # (|x| > 2.0 or |theta| > 0.2) is currently disabled, so every saved state
    # contributes to the reward.
    reward = jnp.square(x) + jnp.square(theta)
    reward = jnp.sum(reward)

    return -reward


@eqx.filter_jit
def eval_traj(control: eqx.Module, key: PRNGKeyArray, t1: float):
    """Simulate one rollout of the cart-pole under derivative control.

    The solver state is ``[u, x, theta, dot_x, dot_theta]``: the control
    signal ``u`` is part of the state and starts at zero.

    Args:
        control: Controller passed to the vector field through ``args``; it is
            called on the full 5-dimensional augmented state.
        key: PRNG key used to sample the initial system state uniformly from
            ``[-0.05, 0.05)`` in each of the 4 state dimensions.
        t1: End time of the simulation (start time is 0).

    Returns:
        The diffrax solution, saved at 1024 evenly spaced times in ``[0, t1]``.
    """
    # NOTE: the key annotation uses jaxtyping's PRNGKeyArray because
    # `jax.random.KeyArray` no longer exists in current JAX releases.
    y0 = jax.random.uniform(
        key,
        shape=(4,),
        minval=-0.05,
        maxval=0.05,
        # Wider per-dimension ranges that were tried:
        # minval=jnp.asarray([-0.5, -0.1, -5.0, -1.0]),
        # maxval=jnp.asarray([0.5, 0.1, 5.0, 1.0]),
    )

    sol = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(with_derivative_control(cartpole_ode, num_controls=1)),
        solver=diffrax.Dopri5(),
        t0=0.0,
        t1=t1,
        dt0=0.01,
        # Prepend the initial control value (zero) to the system state.
        y0=jnp.concatenate((jnp.zeros(1), y0)),
        args=(control, None),
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, t1, 1024)),
        stepsize_controller=diffrax.PIDController(
            rtol=1e-5, atol=1e-5, pcoeff=0.3, icoeff=0.3
        ),
    )

    return sol


@eqx.filter_value_and_grad
def eval_reward(
    control: eqx.Module, key: PRNGKeyArray, batch_size: int, t1: float
) -> float:
    """Reward of a single rollout; the decorator also returns its gradient w.r.t. ``control``.

    Args:
        control: Controller being trained (the differentiated argument).
        key: PRNG key forwarded to ``eval_traj``.
        batch_size: Currently unused; kept for the batched (vmapped)
            evaluation that is disabled for now.
        t1: End time of the rollout.
    """
    # NOTE: annotated with jaxtyping's PRNGKeyArray because
    # `jax.random.KeyArray` no longer exists in current JAX releases.
    sol = eval_traj(control, key, t1)

    reward = reward_fn(sol.ys)
    return reward


#@eqx.filter_jit
def update_step(
    control: eqx.Module, opt_state: optax.OptState, key: PRNGKeyArray, t1: float
):
    """Perform one gradient-ascent step on the rollout reward.

    Args:
        control: Current controller.
        opt_state: Current optimizer state.
        key: PRNG key for the rollout's initial state.
        t1: End time of the rollout.

    Returns:
        ``(control, opt_state, reward)`` with the updated controller and
        optimizer state and the reward measured before the update.
    """
    reward, grads = eval_reward(control, key, 16, t1)
    # The optimizer minimises, so negate the gradient to maximise the reward.
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    grads = jax.tree_util.tree_map(lambda x: -x, grads)

    control_params, control_static = eqx.partition(control, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, params=control_params)
    control_params = optax.apply_updates(control_params, updates)

    control = eqx.combine(control_params, control_static)

    return control, opt_state, reward

In [None]:
# The rollout horizon grows linearly up to 10 time units over training.
total_steps = 1024 * 16

pbar = trange(total_steps)
for step in pbar:
    t1 = jnp.float64(10 * (step + 1) / total_steps)

    key, subkey = jax.random.split(key)
    control, opt_state, reward = update_step(control, opt_state, subkey, t1)

    # Refresh the progress-bar readout every 16th step only.
    if not step % 16:
        pbar.set_postfix(reward=reward.item(), t1=t1)

In [None]:
# Roll out the trained derivative controller once.
key, subkey = jax.random.split(key)
sol = eval_traj(control, subkey, jnp.float64(100.0))

# The saved state is [u, x, theta, dot_x, dot_theta], so the cart position and
# pole angle are columns 1 and 2 (column 0 is the control signal).
plt.figure()
state_lines = plt.plot(sol.ts, sol.ys[:, 1:3])
plt.legend(state_lines, ["x", "theta"])
plt.xlabel("t")
plt.show()

# Controller output along the trajectory, i.e. the control derivative du/dt.
plt.figure()
plt.plot(sol.ts, jax.vmap(control)(sol.ys))
plt.xlabel("t")
plt.ylabel("du/dt")
plt.show()

# Animation Test

In [None]:
# `sol` comes from the derivative-control section, whose saved state is
# [u, x, theta, dot_x, dot_theta]: skip the leading control channel.
# Set STATE_OFFSET = 0 when animating a trajectory that starts with x.
STATE_OFFSET = 1
X_IDX = STATE_OFFSET
THETA_IDX = STATE_OFFSET + 1

fig, ax = plt.subplots()

# Ground line.
ax.axhline()

cart_width = 0.5
cart_height = 0.25


def get_cart_left(cart_x):
    """Left edge of the cart rectangle so that it is centred on ``cart_x``."""
    return float(cart_x) - cart_width / 2


cart = ax.add_patch(
    matplotlib.patches.Rectangle(
        [get_cart_left(sol.ys[0, X_IDX]), 0], cart_width, cart_height
    )
)

pole_width = 2.5  # line width used to draw the pole
# Same value as `length_pole` in `cartpole_ode`.
# NOTE(review): the 4/3 inertia factor there suggests this is the pole's
# half-length, so the pole is drawn with length 2 * pole_length (which is also
# the unit length used before) -- confirm.
pole_length = 0.5


def get_pole_data(cart_x, pole_angle):
    """Return ``([x0, x1], [y0, y1])`` for the pole line.

    The angle is measured from the upright position (theta = 0 is the
    unstable equilibrium of ``cartpole_ode``), so the horizontal offset uses
    ``sin`` and the vertical offset uses ``cos``.
    """
    drawn_length = 2 * pole_length
    pole_base_x = float(cart_x)
    pole_base_y = cart_height
    pole_end_x = pole_base_x + drawn_length * math.sin(pole_angle)
    pole_end_y = pole_base_y + drawn_length * math.cos(pole_angle)

    return [pole_base_x, pole_end_x], [pole_base_y, pole_end_y]


pole = ax.add_line(
    matplotlib.lines.Line2D(
        *get_pole_data(sol.ys[0, X_IDX], sol.ys[0, THETA_IDX]), linewidth=pole_width
    )
)


def init():
    """Fix the axis limits before the first frame is drawn."""
    ax.set_xlim([-10.0, 10.0])
    ax.set_ylim([-0.5, 2.0])

    return cart, pole


def update(frame):
    """Move the cart and pole to the state saved at index ``frame``."""
    y = sol.ys[frame]

    # Keep the cart centred on x, consistent with its initial placement.
    cart.set(x=get_cart_left(y[X_IDX]))
    pole.set_data(*get_pole_data(y[X_IDX], y[THETA_IDX]))

    return cart, pole

animation = matplotlib.animation.FuncAnimation(fig=fig, func=update, frames=range(len(sol.ts)), init_func=init)
plt.show()