# Optimal control of pendulum

## Imports

In [1]:
from typing import Any, Union, NamedTuple, Callable
import jax
from jax import Array
import jax.random as jr
import jax.numpy as jnp
from jax.experimental.ode import odeint
import matplotlib.pyplot as plt

from diffilqrax.utils import keygen
from diffilqrax.ilqr import ilqr_solver
from diffilqrax.typs import (
    iLQRParams,
    System,
    ModelDims,
    PendulumParams,
    Theta
)

jax.config.update('jax_enable_x64', True)

## The Problem

We have a pendulum with masses $m_1=1$ and with lengths $L_1$ and $L_2$ which is released from an arbitrary point and has to be driven to the target position.

The cartesian coordinates of the center of mass of each point are defined as
$$
\begin{align}
x_1 =& L_1 \sin(\theta_1) \\
y_1 =& -L_1 \cos(\theta_1) \\
\end{align}
$$

In [2]:
# define coordination function
x1_fn = lambda l_1, the_1: l_1 * jnp.sin(the_1)
y1_fn = lambda l_1, the_1: - l_1 * jnp.cos(the_1)

The dynamics of the double pendulum in polar coordinates are given by the following equations:

$$
% \begin{equation}
\ddot{\theta} = \frac{mgl}{J}\sin\theta - \frac{mgl}{J} u\cos\theta
% \end{equation}
$$

Setting $\frac{mgl}{J}=1$ and initializing system with $\theta=\pi + \epsilon$ and $\dot{\theta}=0$, arrives with simplified dynamics,

$$
% \begin{equation}
\ddot{\theta} = \sin\theta - u\cos\theta
% \end{equation}
$$

Outlined in matrix form, where $\mathbf{x}=(x_1,x_2)=(\dot{\theta}, \sin\theta)$, 

$$
\frac{d}{dt}\left( \begin{matrix} \dot{\theta} \\ \sin\theta \end{matrix} \right) = 
\left( \begin{matrix} 0 & 1 \\ \cos\theta & 0 \end{matrix} \right) \left( \begin{matrix} \dot{\theta} \\ \sin\theta \end{matrix} \right) + 
\left( \begin{matrix} -\cos\theta \\ 0 \end{matrix} \right) u
$$

In [3]:
def euler(dynamics:Callable, dt:float)->Callable:
    return lambda t, x, u: x + dt*dynamics(t*dt, x, u)

def rk4(dynamics:Callable, dt:float)->Callable:
    def integrator(t, x, u):
        dt2 = dt / 2.0
        k1 = dynamics(t*dt, x, u)
        k2 = dynamics(t*dt, x + dt2 * k1, u)
        k3 = dynamics(t*dt, x + dt2 * k2, u)
        k4 = dynamics(t*dt, x + dt *k3, u)
        return

In [19]:
class PendulumParams(NamedTuple):
    """Pendulum parameters"""

    m: float
    l: float
    g: float

def pendulum_dynamics(state:Array, u:Array, theta: PendulumParams)->Array:
    """simulate the dynamics of a pendulum. x0 is sin(theta), x1 is cos(theta), x2 is theta_dot.
    u is the torque applied to the pendulum.

    Args:
        t (int): timepoint
        state (Array): state params
        u (Array): external input
        theta (Theta): parameters
    """
    
    def dyn(x:Array)->Array:
        d_theta, theta = x
        sin_theta = jnp.sin(theta)
        dS_dt = jnp.array([[0, 1], [jnp.cos(theta), 0]]) @ jnp.array([[d_theta], [sin_theta]]) + jnp.array([[-jnp.cos(theta)], [0.]]) * u
        return (dS_dt[0][0], jnp.arcsin(dS_dt[1][0]))
    
    dt=0.01
    tspec = jnp.linspace(0., 2., int(2/dt)+1)
    states = odeint(dyn, state, tspec)
    
    return states


In [20]:
odeint?

[0;31mSignature:[0m
[0modeint[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mfunc[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0my0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mt[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m*[0m[0margs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mrtol[0m[0;34m=[0m[0;36m1.4e-08[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0matol[0m[0;34m=[0m[0;36m1.4e-08[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmxstep[0m[0;34m=[0m[0minf[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mhmax[0m[0;34m=[0m[0minf[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.

Args:
  func: function to evaluate the time derivative of the solution `y` at time
    `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
  y0: array or pytree of arrays representing the initial value for the state.
  t: array of float times for evaluat

In [21]:
pparams = PendulumParams(1.0, 1.0, 9.8)
n_tps = jnp.linspace(0., 2., int(2/0.01)+1).size
state_ = jnp.array([0., jnp.pi/2])

pendulum_dynamics(state=state_, u=jnp.zeros((n_tps,1)), theta=pparams)


TypeError: pendulum_dynamics.<locals>.dyn() takes 1 positional argument but 2 were given

In [None]:

def pendulum_model():
    """define pendulum model with cost, dynamics and cost function"""
    def cost(t: int, x: Array, u: Array, theta: Any):
        return jnp.sum(x[0]**2) + jnp.sum((x[1]-jnp.pi)**2) + jnp.sum(u**2)

    def costf(x: Array, theta: Any):
        return jnp.sum(x[0]**2 + (x[1]-jnp.pi)**2)

    def dynamics(t: int, x: Array, u: Array, theta: Union[Theta, PendulumParams]):
        return pendulum_dynamics(t, x, u, theta)

    return System(cost, costf, dynamics, ModelDims(horizon=100, n=3, m=1, dt=0.1))

# Define iLQR solver

In [None]:
key = jr.PRNGKey(seed=234)
key, skeys = keygen(key, 5)

ls_kwargs = {
    "beta": 0.8,
    "max_iter_linesearch": 16,
    "tol": 1e0,
    "alpha_min": 0.0001,
}

theta = PendulumParams(m=1, l=2, g=9.81)
params = iLQRParams(x0=jr.normal(next(skeys), (3,)), theta=theta)
model = pendulum_model()

Us_init = jnp.zeros((model.dims.horizon, 1))

## Solve

In [None]:
# test ilqr solver
(Xs_stars, Us_stars, Lambs_stars), converged_cost, cost_log = ilqr_solver(
    model,
    params,
    Us_init,
    max_iter=40,
    convergence_thresh=1e-8,
    alpha_init=1.0,
    verbose=True,
    use_linesearch=True,
    **ls_kwargs,
)

## Visualise