# Imports

In [2]:
import jax

# Import the config object from the top-level `jax` package: the old
# `jax.config` submodule path was deprecated and later removed, so
# `from jax.config import config` raises ImportError on newer JAX releases.
from jax import config
# Enable 64-bit floats/ints.
config.update('jax_enable_x64', True)

import jax_dataclasses as jdc
from functools import partial
from typing import Any, Callable, Sequence

import jax.numpy as jnp
import jax.scipy as jsp
import matplotlib.pylab as plt
import numpy as np
import optax
import scipy as sp
import tqdm
import time
import chex

# Short aliases used in the type annotations throughout this file.
Array = jax.Array
PRNGKey = chex.PRNGKey
# `PyTree` is only a readable label: it is `Any`, so nothing is type-checked.
PyTree = Any
Scalar = chex.Scalar

In [3]:
# f(t, x, u): maps a time `t`, a state pytree `x` and a control array `u` to a
# pytree (the next state in `rollout`, the time derivative in the integrators).
Dynamics = Callable[[Scalar, PyTree, Array], PyTree]

# c(t, x, u): cost incurred at time `t` in state `x` when applying control `u`.
StageCost = Callable[[Scalar, PyTree, Array], float]
# c(t, x): cost of ending in state `x` at time `t` (no control argument).
TerminalCost = Callable[[Scalar, PyTree], float]


@jdc.pytree_dataclass
class Cost:
    """Pair of cost callables that together score a trajectory.

    `trajectory_cost` sums `stage_cost` over every time step and adds
    `terminal_cost` evaluated at the final state.
    """
    # NOTE(review): both fields hold plain callables; if a `Cost` is ever passed
    # as a traced argument (e.g. into `jax.jit`) they presumably need to be
    # marked static -- confirm against the jax_dataclasses documentation.

    # Per-step cost, called as stage_cost(t, x, u).
    stage_cost: StageCost
    # Final-state cost, called as terminal_cost(T, x_T).
    terminal_cost: TerminalCost


def rollout(dynamics: Dynamics, U: Array, x0: PyTree) -> PyTree:
    """Simulate `dynamics` forward from `x0` under the control sequence `U`.

    Applies `X[t+1] = dynamics(t, X[t], U[t])` for `t = 0 .. len(U) - 1`,
    starting from `X[0] = x0`.

    Returns:
        The visited states `X[1], ..., X[len(U)]` stacked along a new leading
        axis. The initial state `x0` is not part of the result.
    """
    step_indices = jnp.arange(len(U))

    def advance(state, step_inputs):
        step_index, control = step_inputs
        next_state = dynamics(step_index, state, control)
        # `scan` expects (new carry, per-step output); both are the next state.
        return next_state, next_state

    _final_state, states = jax.lax.scan(advance, x0, (step_indices, U))
    return states


def trajectory_cost(cost: Cost, U: Array, x0: PyTree, X: PyTree) -> float:
    """Return the total cost of a trajectory.

    Args:
        cost: provides `stage_cost(t, x, u)` and `terminal_cost(t, x)`.
        U: control sequence; its length `T` is the number of time steps.
        x0: initial state (a pytree; leaves may be scalars or arrays).
        X: states `X[1], ..., X[T]` stacked along the leading axis of each
            leaf, i.e. the output of `rollout`.

    Returns:
        `sum_t stage_cost(t, X[t], U[t])` for `t = 0 .. T-1`, plus
        `terminal_cost(T, X[T])`.
    """
    T = len(U)
    time_steps = jnp.arange(T)
    # Prepend x0 so every leaf holds X[0], ..., X[T]. `[None]` adds the leading
    # time axis for a leaf of any rank, including 0-d (scalar) states, which
    # the previous `a[None, :]` indexing rejected.
    X = jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate((jnp.asarray(a)[None], b), axis=0), x0, X)
    # Stage costs pair X[0..T-1] with U[0..T-1]; the last state is excluded.
    stage_cost = jnp.sum(jax.vmap(cost.stage_cost)(
        time_steps, jax.tree_util.tree_map(lambda leaf: leaf[:-1], X), U))
    # The terminal cost is evaluated on the final state X[T] only.
    terminal_cost = cost.terminal_cost(
        T, jax.tree_util.tree_map(lambda leaf: leaf[-1], X))
    return stage_cost + terminal_cost


def objective(
    dynamics: Dynamics, cost: Cost, U: Array, x0: PyTree
) -> float:
    """Return the cost of applying the controls `U` from initial state `x0`.

    The states are obtained by rolling `dynamics` forward with `rollout`, and
    the resulting trajectory is scored with `trajectory_cost`.
    """
    return trajectory_cost(cost, U, x0, rollout(dynamics, U, x0))


def pytree_block_until_ready(tree: PyTree) -> PyTree:
    """Call `block_until_ready()` on every leaf of `tree`.

    Returns a pytree with the same structure whose leaves are the values
    returned by each leaf's `block_until_ready()` call.
    """

    def wait_for(leaf):
        return leaf.block_until_ready()

    return jax.tree_util.tree_map(wait_for, tree)


def print_jit_and_eval_times(f, *, args=(), kwargs=None, name='', num_steps=5):
    """Time `f(*args, **kwargs)` and print estimated JIT and evaluation times.

    The first call is timed on its own because, for a jitted `f`, it includes
    tracing/compilation. The following `num_steps` calls are averaged to get
    the steady-state evaluation time, and the JIT time is reported as the
    first-call time minus that average.

    Args:
        f: callable returning a pytree whose leaves support
            `block_until_ready()`.
        args: positional arguments forwarded to `f`.
        kwargs: keyword arguments forwarded to `f`; `None` means no keywords.
        name: label printed in front of the timings.
        num_steps: number of timed repetitions after the first call.

    Raises:
        ValueError: if `num_steps` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    # Avoid a shared mutable default for the keyword arguments.
    call_kwargs = {} if kwargs is None else kwargs

    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    start_time = time.perf_counter()
    pytree_block_until_ready(f(*args, **call_kwargs))
    jit_plus_eval_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    for _ in range(num_steps):
        # Block so that asynchronously dispatched work is included in the timing.
        pytree_block_until_ready(f(*args, **call_kwargs))
    eval_time = (time.perf_counter() - start_time) / num_steps
    print(
        f"{name}, jit_time={jit_plus_eval_time - eval_time:.3f} (s), eval_time={eval_time:.3f} (s)")


In [None]:
# Integrators
def euler(dynamics: Dynamics, dt: float) -> Dynamics:
    """Wrap `dynamics` in a forward-Euler step of size `dt`.

    The returned function takes `(t, x, u)`, evaluates `dynamics` at time
    `t * dt`, and returns `x + dt * dynamics(t * dt, x, u)`.
    """

    def integrator(t, x, u):
        derivative = dynamics(t * dt, x, u)
        return x + dt * derivative

    return integrator

def rk4(dynamics: Dynamics, dt: float) -> Dynamics:
    """Wrap `dynamics` in a classical fourth-order Runge-Kutta step of size `dt`.

    Args:
        dynamics: time derivative, called as `dynamics(time, x, u)`.
        dt: integration step size.

    Returns:
        A function `(t, x, u) -> x_next`. As in `euler`, `t` is multiplied by
        `dt` to obtain the time passed to `dynamics`.
    """
    half_dt = dt / 2.0

    def integrator(t, x, u):
        t0 = t * dt
        # The same control `u` is used for all four stage evaluations.
        k1 = dynamics(t0, x, u)
        # The two midpoint stages are evaluated half a step ahead in time.
        k2 = dynamics(t0 + half_dt, x + half_dt * k1, u)
        k3 = dynamics(t0 + half_dt, x + half_dt * k2, u)
        # The final stage is evaluated at the end of the step.
        k4 = dynamics(t0 + dt, x + dt * k3, u)
        return x + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    return integrator

# Define dynamics

In [None]:
# Model constants for the dynamics defined in this section.
# NOTE(review): presumably pendulum parameters (gravitational acceleration,
# mass, length, torque limit) in SI units -- confirm against the dynamics code.
g=9.81
m=1.0
l=1.0
max_torque=2.0


# Generate data

# Set up cost, constraints and dynamics

# Define iLQR solver

# Solve