# Double pendulum

https://en.wikipedia.org/wiki/Double_pendulum

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from sympy import symbols, Function
import sympy
import polars as pl
import gif
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import math
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController
import jax.numpy as jnp

## Maths

### Lagrangian

$$ L = T - V $$

$$ T =  \frac{1}{2} \sum_i m_i \left( \dot{x}_i^2 + \dot{y}_i^2 \right)  $$

The uniform gravitational potential is

$$ V = \sum_i m_i g y_i $$

pendulum #1:

$$ x_1 = l_1 \sin q_1 $$

$$ y_1 = -l_1\cos q_1 $$

pendulum #2:

$$ x_2 = l_1 \sin q_1 + l_2 \sin q_2 $$

$$ y_2 = -l_1 \cos q_1 - l_2 \cos q_2 $$

The derivation by hand becomes lengthy; see https://en.wikipedia.org/wiki/Double_pendulum#Lagrangian for the resulting Lagrangian.

Time to let SymPy do the algebra.

In [None]:
# Physical constants (gravity, masses, rod lengths) and time.
g, m1, m2, l1, l2, t = symbols("g m_1 m_2 l_1 l_2 t")
# Undefined functions of time: generalized coordinates and momenta.
_q1, _q2, _p1, _p2 = symbols("qt_1 qt_2 pt_1 pt_2", cls=Function)
# Plain symbols that later stand in for q_i(t), p_i(t) so we can lambdify.
q1, q2, p1, p2 = symbols("q_1 q_2 p_1 p_2")

# Apply each undefined function to t: q_1(t), q_2(t), p_1(t), p_2(t).
q1t, q2t, p1t, p2t = (fn(t) for fn in (_q1, _q2, _p1, _p2))

In [None]:
# Bob 1, horizontal position x_1 = l_1 sin q_1 and its time derivative.
x1 = l1 * sympy.sin(q1t)
x1_dot = sympy.diff(x1, t)
x1_dot

In [None]:
# Bob 1, vertical position y_1 = -l_1 cos q_1 (negative below the pivot)
# and its time derivative.
y1 = -l1 * sympy.cos(q1t)
y1_dot = sympy.diff(y1, t)
y1_dot

In [None]:
# Bob 2 hangs from bob 1: x_2 = l_1 sin q_1 + l_2 sin q_2.
x2 = l1 * sympy.sin(q1t) + l2 * sympy.sin(q2t)
x2_dot = sympy.diff(x2, t)
x2_dot

In [None]:
# Bob 2, vertical position y_2 = -l_1 cos q_1 - l_2 cos q_2.
y2 = -l1 * sympy.cos(q1t) - l2 * sympy.cos(q2t)
y2_dot = sympy.diff(y2, t)
y2_dot

Kinetic energy

In [None]:
# Kinetic energy of the two point masses, T = 1/2 * sum_i m_i (xdot_i^2 + ydot_i^2).
# sympy.Rational keeps the 1/2 prefactor exact. The Python expression
# 1 / 2.0 would be the Float 0.5, which then propagates into L, H and
# every equation of motion derived below.
T = sympy.Rational(1, 2) * (
    m1 * (x1_dot**2 + y1_dot**2) + m2 * (x2_dot**2 + y2_dot**2)
)
T = T.simplify()
T

Potential energy

In [None]:
# Uniform gravitational potential, V = sum_i m_i g y_i.
# (Assign V = 0 instead to explore the pendulum without gravity.)
V = g * m1 * y1 + g * m2 * y2
V

Lagrangian

In [None]:
# Lagrangian L = T - V, simplified.
L = sympy.simplify(T - V)
L

The canonical momenta are defined by

$$ p = \partial_{\dot{q}} L $$

In [None]:
# Canonical momentum p_1 = dL/d(q_1 dot), with q_1 dot = d q_1(t) / dt.
q1t_dot = sympy.diff(q1t, t)
p1_eq = sympy.diff(L, q1t_dot)
p1_eq

In [None]:
# Canonical momentum p_2 = dL/d(q_2 dot), with q_2 dot = d q_2(t) / dt.
q2t_dot = sympy.diff(q2t, t)
p2_eq = sympy.diff(L, q2t_dot)
p2_eq

Express $\dot{q}$ in terms of $p$ and $q$ by solving the two momentum equations.

In [None]:
# Each equation reads p_i(t) - dL/d(q_i dot) = 0. Solving both jointly
# gives the velocities in terms of the momenta and coordinates.
eq1, eq2 = p1t - p1_eq, p2t - p2_eq

q_dot_qp = sympy.solve((eq1, eq2), (q1t_dot, q2t_dot))

In [None]:
# Solution of the momentum equations: maps each velocity d q_i(t)/dt to an
# expression in the momenta, coordinates and parameters.
q_dot_qp

The Hamiltonian is the Legendre transform of $L$:

$$ H = \sum_i p_i \dot{q}_i - L $$

In [None]:
# Legendre transform H = sum_i p_i q_i dot - L, still written with the velocities.
H = -L + p1t * q1t_dot + p2t * q2t_dot
H

Replace $\dot{q}$ with its expression in $p$ and $q$ to obtain $H(q, p)$.

In [None]:
# Eliminate the velocities (substituted in order q_1 dot, then q_2 dot) to get H(q, p).
velocity_substitutions = [
    (q1t_dot, q_dot_qp[q1t_dot]),
    (q2t_dot, q_dot_qp[q2t_dot]),
]
H_pq = sympy.simplify(H.subs(velocity_substitutions))
H_pq

Hamilton's equations of motion:

$$ \dot{q}_i = \partial_{p_i} H $$

$$ \dot{p}_i = -\partial_{q_i} H $$

In [None]:
# Hamilton's equation: dq_1/dt = dH/dp_1.
dq1_eq = sympy.diff(H_pq, p1t)
dq1_eq

In [None]:
# Hamilton's equation: dq_2/dt = dH/dp_2.
dq2_eq = sympy.diff(H_pq, p2t)
dq2_eq

In [None]:
# Hamilton's equation: dp_1/dt = -dH/dq_1 (simplified before negation).
dp1_eq = -sympy.simplify(sympy.diff(H_pq, q1t))
dp1_eq

In [None]:
# Hamilton's equation: dp_2/dt = -dH/dq_2 (simplified before negation).
dp2_eq = -sympy.simplify(sympy.diff(H_pq, q2t))
dp2_eq

Replace the functions $q_i(t)$ and $p_i(t)$ (and, where they appear, the derivatives $\dot{q}_i$) with plain symbols so the expressions can be lambdified.

In [None]:
# Plain symbols standing in for the velocities when lambdifying the momenta.
q1_dot, q2_dot = sympy.symbols("dq_1 dq_2")

In [None]:
# dq_1/dt before substitution, still in terms of q_i(t) and p_i(t).
dq1_eq

In [None]:
# Swap the time-dependent functions for plain symbols (applied in list order).
dq1_eq_subs = dq1_eq.subs([(p1t, p1), (p2t, p2), (q1t, q1), (q2t, q2)])
dq1_eq_subs

In [None]:
# dq_2/dt before substitution, still in terms of q_i(t) and p_i(t).
dq2_eq

In [None]:
# Swap the time-dependent functions for plain symbols (applied in list order).
dq2_eq_subs = dq2_eq.subs([(p1t, p1), (p2t, p2), (q1t, q1), (q2t, q2)])
dq2_eq_subs

In [None]:
# dp_1/dt before substitution, still in terms of q_i(t) and p_i(t).
dp1_eq

In [None]:
# Swap the time-dependent functions q_i(t), p_i(t) for plain symbols.
# Each function is substituted exactly once, in the same order as for
# the dq equations.
dp1_eq_subs = dp1_eq.subs([(p1t, p1), (p2t, p2), (q1t, q1), (q2t, q2)])
dp1_eq_subs

In [None]:
# dp_2/dt before substitution, still in terms of q_i(t) and p_i(t).
dp2_eq

In [None]:
# Swap the time-dependent functions q_i(t), p_i(t) for plain symbols.
# Each function is substituted exactly once, in the same order as for
# the dq equations.
dp2_eq_subs = dp2_eq.subs([(p1t, p1), (p2t, p2), (q1t, q1), (q2t, q2)])
dp2_eq_subs

Create JAX functions from the equations of motion for $p$ and $q$.

In [None]:
# Turn the four right-hand sides of Hamilton's equations into functions
# backed by the "jax" module. All share the argument order
# (g, l1, l2, m1, m2, q1, q2, p1, p2).
dp1_fun, dp2_fun, dq1_fun, dq2_fun = (
    sympy.lambdify((g, l1, l2, m1, m2, q1, q2, p1, p2), expr, "jax")
    for expr in (dp1_eq_subs, dp2_eq_subs, dq1_eq_subs, dq2_eq_subs)
)

In [None]:
# Replace the velocities first and the coordinates second. Order matters
# because q_i(t) also occurs inside the derivatives d q_i(t)/dt.
p1_eq_subs = p1_eq.subs(
    [(q1t_dot, q1_dot), (q2t_dot, q2_dot), (q1t, q1), (q2t, q2)]
)
p1_eq_subs

In [None]:
# Replace the velocities first and the coordinates second. Order matters
# because q_i(t) also occurs inside the derivatives d q_i(t)/dt.
p2_eq_subs = p2_eq.subs(
    [(q1t_dot, q1_dot), (q2t_dot, q2_dot), (q1t, q1), (q2t, q2)]
)
p2_eq_subs

In [None]:
# Plain-Python ("math" module) functions for the canonical momenta. They are
# used to convert initial angular velocities into initial momenta.
p1_fun, p2_fun = (
    sympy.lambdify((g, l1, l2, m1, m2, q1, q2, q1_dot, q2_dot), expr, "math")
    for expr in (p1_eq_subs, p2_eq_subs)
)

In [None]:
def ode_to_sole(t, z, args):
    """Right-hand side of Hamilton's equations, in diffrax's ``(t, y, args)`` form.

    Args:
        t: Current time. Unused, because the lambdified equations take no time argument.
        z: State ``(q_1, q_2, p_1, p_2)``.
        args: Physical parameters ``(g, l1, l2, m1, m2)``.

    Returns:
        Tuple ``(dq_1/dt, dq_2/dt, dp_1/dt, dp_2/dt)``, laid out like ``z``.
    """
    g, l1, l2, m1, m2 = args
    q_1, q_2, p_1, p_2 = z

    # All four lambdified functions share this argument order.
    fun_args = (g, l1, l2, m1, m2, q_1, q_2, p_1, p_2)
    return (
        dq1_fun(*fun_args),
        dq2_fun(*fun_args),
        dp1_fun(*fun_args),
        dp2_fun(*fun_args),
    )


m1_val = 1.0
m2_val = 1.0
l1_val = 1.0
l2_val = 1.0
g_val = 9.81

# Initial angles (radians).
q_1_0_val = 1.0
q_2_0_val = 0.0

# Initial angular velocities.
q_1_dot_0_val = 0.0
q_2_dot_0_val = 0.0

# Convert the initial angular velocities into canonical momenta.
# Pass the numeric g_val: the bare name g is the SymPy symbol.
momentum_args = (
    g_val,
    l1_val,
    l2_val,
    m1_val,
    m2_val,
    q_1_0_val,
    q_2_0_val,
    q_1_dot_0_val,
    q_2_dot_0_val,
)
p_1_0_val = p1_fun(*momentum_args)
p_2_0_val = p2_fun(*momentum_args)

params = (g_val, l1_val, l2_val, m1_val, m2_val)
initial_values = (q_1_0_val, q_2_0_val, p_1_0_val, p_2_0_val)
params, initial_values

In [None]:
# Integration window and the 101 evenly spaced times at which the solution is saved.
ti, tf = 0, 10
t_span = [ti, tf]  # kept for reference; diffeqsolve takes t0/t1 directly
t_eval = jnp.linspace(ti, tf, 101)
# The adaptive stepper starts with the sampling interval as its first step.
dt0 = t_eval[1].item() - t_eval[0].item()

# Dormand-Prince solver with PID step-size control.
term = ODETerm(ode_to_sole)
solver = Dopri5()
saveat = SaveAt(ts=t_eval)
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)


sol = diffeqsolve(
    term,
    solver,
    args=params,
    y0=initial_values,
    t0=ti,
    t1=tf,
    dt0=dt0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

In [None]:
assert sol.ys is not None
assert sol.ts is not None
t_sol = sol.ts

# Wrap the angles into [-pi, pi). The rest position is q = 0, since
# y = -l cos q is lowest there. Wrapping into [0, 2*pi) would make small
# oscillations about it jump between ~0 and ~2*pi in the plot.
# sin/cos, and hence the Cartesian positions, are unaffected by the wrap.
q_1_sol = (sol.ys[0] + math.pi) % (2.0 * math.pi) - math.pi
q_2_sol = (sol.ys[1] + math.pi) % (2.0 * math.pi) - math.pi
p_1_sol = sol.ys[2]
p_2_sol = sol.ys[3]

fig, axs = plt.subplots(nrows=2, sharex=True)
ax = axs[0]
ax.plot(t_sol, q_1_sol, label=r"$q_1$")
ax.plot(t_sol, p_1_sol, label=r"$p_1$")
ax.legend()

ax = axs[1]
ax.plot(t_sol, q_2_sol, label=r"$q_2$")
ax.plot(t_sol, p_2_sol, label=r"$p_2$")
ax.set(xlabel="t")
ax.legend()
plt.tight_layout()

Convert the solution back to Cartesian coordinates.

pendulum #1:

$$ x_1 = l_1 \sin q_1 $$

$$ y_1 = -l_1\cos q_1 $$

pendulum #2:

$$ x_2 = l_1 \sin q_1 + l_2 \sin q_2 $$

$$ y_2 = -l_1 \cos q_1 - l_2 \cos q_2 $$

In [None]:
# Cartesian positions of both bobs from the angles. Bob 2 hangs from bob 1,
# so its position is bob 1's position plus the second rod.
bob1_x = l1_val * jnp.sin(q_1_sol)
bob1_y = -l1_val * jnp.cos(q_1_sol)
bob2_x = bob1_x + l2_val * jnp.sin(q_2_sol)
bob2_y = bob1_y - l2_val * jnp.cos(q_2_sol)

trajectory = pl.DataFrame(
    {
        "t": t_sol.tolist(),
        "x1": bob1_x.tolist(),
        "y1": bob1_y.tolist(),
        "x2": bob2_x.tolist(),
        "y2": bob2_y.tolist(),
    }
)

fig, ax = plt.subplots()
ax.scatter(data=trajectory, x="x1", y="y1", label="1")
ax.scatter(data=trajectory, x="x2", y="y2", label="2")
# Equal scaling so the fixed-length rods trace true circles, not ellipses.
ax.set(xlabel="x", ylabel="y", aspect="equal")
ax.legend(title="component")
plt.tight_layout()

In [None]:
def animate_trajectory(trajectory: pl.DataFrame, title: str, filename: str):
    """Render the bob trajectories as an animated GIF.

    Each frame shows both bobs at one time sample, together with the path
    each has traced so far.

    Args:
        trajectory: Table with columns "x1", "y1", "x2", "y2", one row per
            time sample.
        title: Axes title used on every frame.
        filename: Output path passed to ``gif.save``.

    Raises:
        ValueError: If ``trajectory`` has no rows.
    """
    x1_vals = trajectory["x1"].to_list()
    y1_vals = trajectory["y1"].to_list()
    x2_vals = trajectory["x2"].to_list()
    y2_vals = trajectory["y2"].to_list()

    if not x1_vals:
        raise ValueError("cannot animate an empty trajectory")

    # Fixed limits (data range plus a margin of 1) so the view does not move
    # between frames.
    x_min = min(x1_vals + x2_vals) - 1
    x_max = max(x1_vals + x2_vals) + 1

    y_min = min(y1_vals + y2_vals) - 1
    y_max = max(y1_vals + y2_vals) + 1

    @gif.frame
    def plot_frame(i: int) -> Figure:
        fig, ax = plt.subplots()

        # Path traced so far, then the current bob position.
        ax.plot(x1_vals[: i + 1], y1_vals[: i + 1], color="lightblue", linewidth=2)
        ax.plot(x1_vals[i], y1_vals[i], "bo", markersize=8, label="1")

        ax.plot(x2_vals[: i + 1], y2_vals[: i + 1], color="lightgreen", linewidth=2)
        ax.plot(x2_vals[i], y2_vals[i], "go", markersize=8, label="2")

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        # Equal scaling keeps the fixed rod lengths visually correct.
        ax.set_aspect("equal")

        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.legend(title="component")

        ax.set_title(title)

        return fig

    frames = [plot_frame(i) for i in range(len(x1_vals))]

    gif.save(frames, filename, duration=300)


animate_trajectory(
    trajectory,
    "Double pendulum with gravity",
    "trajectory-double-pendulum-with-gravity.gif",
)