In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import time
from math import ceil

from functools import partial

from jax import config

# Enable 64-bit floats. JAX expects this flag to be set at startup, before
# arrays are created, which is why it lives in the first cell.
# NOTE(review): the name `config` is rebound to the quadcopter parameter dict
# in a later cell; use `jax.config` if the JAX config object is needed after
# that point.
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp

import collimator

from collimator import logging

# Only report collimator log messages at ERROR level.
logging.set_log_level(logging.ERROR)

from collimator.logging import logger

from collimator.framework import LeafSystem
from collimator.library import (
    FeedthroughBlock,
    Clock,
    Constant,
    Multiplexer,
    LTISystem,
    LookupTable1d,
    HermiteSimpsonNMPC,
)
# NOTE(review): LookupTable1d is imported twice (here and above). This name and
# partial, logger, LeafSystem, Constant, Multiplexer, LTISystem, linearize and
# plot_wc_sol are not referenced in the cells shown here.
from collimator.library import linearize, LookupTable1d
from collimator.simulation import SimulatorOptions

# Presumably a local module: quadcopter model factory plus plotting/animation helpers.
from dynamics import make_quadcopter, create_animation, plot_wc_sol

# Re-import edited modules (e.g. dynamics) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Quadcopter model parameters, consumed by make_quadcopter and by the hover
# control reference (which reads "m", "g" and "k").
# NOTE(review): meanings are inferred from the names. Ixx/Iyy/Izz are presumably
# body-axis moments of inertia, k and b thrust/drag coefficients, l the arm
# length, m the mass and g gravitational acceleration; confirm against
# dynamics.make_quadcopter. Binding `config` here also shadows the
# `jax.config` object imported in the first cell.
config = dict(
    Ixx=1.0,
    Iyy=1.0,
    Izz=2.0,
    k=1.0,
    b=0.5,
    l=1.0 / 3,
    m=2.0,
    g=9.81,
)

In [None]:
# Full horizon of the maneuver, and the part of it spent performing the flip.
duration = 1.2
duration_flip = 0.9


def get_linear_and_rot(t):
    """Reference pose at time ``t`` for a single roll flip.

    The roll angle ``phi`` ramps linearly from 0 to 2*pi over
    ``[0, duration_flip]``. It is held at 0 for negative ``t`` and at 2*pi
    afterwards. Every other coordinate stays at zero. Reads the module-level
    ``duration_flip``.

    Args:
        t: Scalar time.

    Returns:
        Array ``[x, y, z, phi, theta, psi]`` of shape ``(6,)``.
    """
    # Clamping freezes the reference outside the flip window.
    # NOTE(review): jnp.clip is built from max/min, and JAX splits their
    # derivative evenly at ties. So jax.jacobian of this function reports half
    # the nominal roll rate (pi / duration_flip) at exactly t == 0 and
    # t == duration_flip. The initial state x0 is taken at t == 0; confirm that
    # this initial roll rate is intended.
    t = jnp.clip(t, 0, duration_flip)
    phi = 2 * jnp.pi * t / duration_flip
    # Position, yaw (psi) and pitch (theta) are held at zero.
    x, y, z, psi, theta = [0.0] * 5

    return jnp.array([x, y, z, phi, theta, psi])


def get_state_and_control(t, config):
    """Reference state and control at time ``t``.

    Args:
        t: Scalar time. Integer values are accepted and promoted to the default
            float dtype, because ``jax.jacobian`` only differentiates with
            respect to floating-point inputs.
        config: Parameter dict; reads ``"m"``, ``"g"`` and ``"k"``.

    Returns:
        Array of shape ``(16,)`` with three parts:
        - the pose from ``get_linear_and_rot`` (6 entries),
        - its time derivative (6 entries),
        - a constant control value repeated for the 4 inputs (4 entries).
    """
    # jax.jacobian raises on integer inputs (e.g. t=0); promote to float first.
    t = jnp.asarray(t, dtype=float)
    # Equal value m*g/(4*k) for every input. Presumably this is the hover
    # command, assuming each input yields thrust k*u; confirm against the model.
    u_hover = config["m"] * config["g"] / (4.0 * config["k"])
    u_ref = jnp.array([u_hover] * 4)
    x_ref = get_linear_and_rot(t)
    # Time derivative of the pose, obtained by differentiating w.r.t. t.
    dot_x_ref = jax.jacobian(get_linear_and_rot)(t)
    return jnp.hstack([x_ref, dot_x_ref, u_ref])


# Batched over a 1-D array of times (axis 0); config is shared, not mapped.
get_state_and_control_vec = jax.vmap(get_state_and_control, (0, None))

In [None]:
# Densely sample the reference (100 points over the full horizon) for previewing.
tvec = jnp.linspace(start=0.0, stop=duration, num=100)
# Keep only the pose columns: x, y, z, phi, theta, psi.
traj = get_state_and_control_vec(tvec, config)[:, 0:6]
# Uncomment to preview the reference maneuver:
# create_animation(traj, traj)

In [None]:
# Problem dimensions: 12 states (6 pose entries + their rates) and 4 inputs.
nx = 12
nu = 4

# Initial condition: the reference state and control at the start time.
t0 = 0.0
x_and_u_0 = get_state_and_control(t0, config)
x0 = x_and_u_0[:nx]
u0 = x_and_u_0[nx:]

# Horizon discretization: N intervals spanning the whole maneuver.
N = 20
Tf = duration
dt = Tf / N
print(f"{dt=}")

# Diagonal state-cost weights, one row per group of three state components.
weights = [
    2.0, 2.0, 0.5,  # x, y, z
    0.5, 2.0, 2.0,  # phi, theta, psi
    1.0, 1.0, 0.5,  # dot_x, dot_y, dot_z
    1.0, 0.1, 0.1,  # dot_phi, dot_theta, dot_psi
]
Q = jnp.diag(jnp.array(weights))
QN = 2 * jnp.eye(nx)  # presumably the terminal-state weight
R = 0.001 * jnp.eye(nu)

# Bounds: None presumably means "unconstrained". Controls have a zero lower
# bound and no upper bound.
lb_x, ub_x = None, None
lb_u = jnp.zeros(nu)
ub_u = None

# Solver initial guess: repeat the initial state/control at all N + 1 knots.
x_optvars_0 = jnp.tile(x0, (N + 1, 1))
u_optvars_0 = jnp.tile(u0, (N + 1, 1))


def get_reference_trajectory(t0, N, dt, config):
    """Sample the reference at the N + 1 knot times t0, t0 + dt, ..., t0 + N*dt.

    Returns:
        Array with one row per knot time. Each row is ``get_state_and_control``
        evaluated at that time.
    """
    knot_times = t0 + jnp.arange(N + 1) * dt
    return get_state_and_control_vec(knot_times, config)


# Reference at the N + 1 knots, split into its leading nx state columns and
# trailing nu control columns.
x_and_u_ref = get_reference_trajectory(t0, N, dt, config)
x_ref, u_ref = x_and_u_ref[:, :nx], x_and_u_ref[:, -nu:]

In [None]:
# Trajectory optimizer over the same horizon and knot spacing as the reference.
# The plant model inside it is a separate quadcopter instance built from the
# same config and initial state as the simulated plant.
hs = HermiteSimpsonNMPC(
    make_quadcopter(config=config, initial_state=x0, name="quadcopter_mpc"),
    Q,
    QN,
    R,
    N,
    dt,  # equals Tf / N; reuse the knot spacing defined with the horizon
    lb_x=lb_x,
    ub_x=ub_x,
    lb_u=lb_u,
    ub_u=ub_u,
)


tic = time.perf_counter()
x_and_u_optvars = hs.solve_trajectory_optimzation(
    t0, x0, x_ref, u_ref, x_optvars_0, u_optvars_0
)
# JAX dispatches work asynchronously; wait for the result so the timing covers
# the whole solve. This is harmless if the solver already returns a host array.
x_and_u_optvars = jax.block_until_ready(x_and_u_optvars)
toc = time.perf_counter()
print(f"Trajectory optimization problem took: {toc-tic}s")

# Unpack the flat decision vector.
# NOTE(review): the slicing assumes all N + 1 control knots come first,
# followed by all N + 1 state knots. Confirm against HermiteSimpsonNMPC's
# packing.
u_opt = x_and_u_optvars[: (N + 1) * nu].reshape(N + 1, nu)
x_opt = x_and_u_optvars[(N + 1) * nu :].reshape(N + 1, nx)
# Knot times: the same grid used by get_reference_trajectory.
t_vec = t0 + dt * jnp.arange(N + 1)

In [None]:
# One panel per control input, plotted at the optimized knot times.
fig_control, axs_control = plt.subplots(2, 2, figsize=(11, 4), sharex=True)
for input_index, ax in enumerate(axs_control.flat):
    ax.plot(
        t_vec,
        u_opt[:, input_index],
        "-r",
        label=rf"$u_{input_index}$",
        alpha=0.5,
    )
    ax.legend(loc="best")
# The x axis is shared, so only the bottom row needs a label.
for ax in axs_control[-1]:
    ax.set_xlabel("t")
fig_control.suptitle("Optimized control inputs")
fig_control.tight_layout()

In [None]:
class InterpArray(FeedthroughBlock):
    """Feedthrough block that linearly interpolates tabulated samples.

    The input is the query time (presumably a scalar; it is wired to a Clock).
    The output has one value per column of ``x_arr``, each computed with
    ``jnp.interp`` over ``t_vec``. Queries outside ``[t_vec[0], t_vec[-1]]``
    return the end values, because ``jnp.interp`` clamps.

    Args:
        t_vec: 1-D sample times; must be increasing for ``jnp.interp``.
        x_arr: 2-D samples with ``len(t_vec)`` rows, one column per output channel.
        *args, **kwargs: Forwarded to ``FeedthroughBlock``.
    """

    def __init__(self, t_vec, x_arr, *args, **kwargs):
        self.t_vec = t_vec
        self.x_arr = x_arr
        # Map jnp.interp over the columns (axis 1) of x_arr. The query time and
        # the time grid are shared across all columns.
        self.interp_fun = jax.vmap(jnp.interp, in_axes=(None, None, 1))

        def _interpolate(t):
            return self.interp_fun(t, self.t_vec, self.x_arr)

        super().__init__(_interpolate, *args, **kwargs)


# Open-loop replay: drive a fresh quadcopter model with the optimized control
# sequence. It uses the same config and initial state as the optimizer's model.
builder = collimator.DiagramBuilder()

quadcopter = builder.add(
    make_quadcopter(config=config, initial_state=x0, name="quadcopter")
)
# Linearly interpolates u_opt over the knot times t_vec.
control = builder.add(InterpArray(t_vec, u_opt, name="control"))
# Time source feeding the interpolator.
clock = builder.add(Clock(name="clock"))

# Wiring: clock -> control -> quadcopter.
builder.connect(clock.output_ports[0], control.input_ports[0])
builder.connect(control.output_ports[0], quadcopter.input_ports[0])

diagram = builder.build()
diagram.pprint()

In [None]:
context = diagram.create_context()

# Signals to log during simulation; results are read back as sol.outputs[key].
recorded_signals = {
    "state": quadcopter.output_ports[0],
    "control": control.output_ports[0],
}

# Simulate over the reference horizon, reusing the knot spacing dt from the
# MPC setup. The step length is capped at dt; presumably this keeps the
# interpolated control from being stepped over.
Tsolve = duration

nseg = ceil(Tsolve / dt)
options = SimulatorOptions(
    max_major_steps=10 * nseg,  # generous head-room over the nominal segment count
    max_major_step_length=dt,
)

sol = collimator.simulate(
    diagram,
    context,
    (0.0, Tsolve),
    options=options,
    recorded_signals=recorded_signals,
)

In [None]:
# Animate the first six recorded state components against the reference pose
# at the knots. These are the pose, if the plant's state ordering matches x0.
# A RuntimeError from the animation helper is reported instead of stopping
# the notebook.
try:
    create_animation(sol.outputs["state"][:, 0:6], x_ref[:, 0:6])
except RuntimeError as err:
    print(err)