In [None]:
import shutil

from cycler import cycler
import cvxpy as cp
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

from dimp.robots import (
    OmniState, OmniInput, OmniRobot, RobotMPCData
)

### Create the Data for the MPC

In [None]:
# Problem dimensions for the planar omnidirectional robot.
ns = 2      # state size: position (x, y)
ni = 2      # input size: velocity (vx, vy)

nc = 100    # horizon length, in control intervals

# The initial state is a Parameter (not a Variable) so the same problem can be
# re-solved from a new starting point just by assigning `s0.value`.
s0 = cp.Parameter(ns, name="s0")

# Horizon states: entry 0 is pinned to the parameter s0, entries s1..s{nc} are
# decision variables.
horizon_states = [OmniState(s0)]
for k in range(1, nc + 1):
    horizon_states.append(OmniState(cp.Variable(ns, name=f"s{k}")))

# "bar" containers: one unnamed Parameter per step (entry 0 reuses s0).
# NOTE(review): presumably a reference/linearization trajectory consumed by
# RobotMPCData / OmniRobot — confirm against dimp.robots.
horizon_states_bar = [OmniState(s0)]
horizon_states_bar += [OmniState(cp.Parameter(ns)) for _ in range(nc)]

# Inputs i0..i{nc-1}, one per control interval, plus their "bar" counterparts.
horizon_inputs = [OmniInput(cp.Variable(ni, name=f"i{k}")) for k in range(nc)]
horizon_inputs_bar = [OmniInput(cp.Parameter(ni)) for _ in range(nc)]

mpc_data = RobotMPCData(
    nc=nc,
    states_list=horizon_states,
    statesbar_list=horizon_states_bar,
    inputs_list=horizon_inputs,
    inputsbar_list=horizon_inputs_bar,
)

# Time step [s] and the robot model bound to the MPC containers above.
dt = 0.1
robot = OmniRobot(dt=dt, mpc_data=mpc_data)

# Task parameters: goal position and the bound on the input norm.
p_goal = np.array([10.0, 5.0])
v_max = 1.0

### Create the MPC Problem

In [None]:
# Nonnegative weights on the goal-tracking and input-effort cost terms. Declaring
# them as a Parameter lets the same problem be re-solved with new weights (DPP).
weights = cp.Parameter(2, name="weights", nonneg=True)

def create_qcqp(goal=None, speed_limit=None):
    """Build the MPC problem that drives the omni robot toward a goal position.

    Cost:
        0.5 * w0 * sum_k ||s_{k+1} - goal||_2  +  0.5 * w1 * sum_k ||i_k||_2
    subject to the robot's discrete-time dynamics constraints and
    ||i_k||_2 <= speed_limit for every control interval k = 0..nc-1.

    NOTE(review): both cost terms are *unsquared* 2-norms, so despite its name
    this is a second-order cone program, not a QCQP. If quadratic costs were
    intended (the 0.5 factors suggest it), `cp.sum_squares` is the drop-in
    replacement — confirm before changing, it alters the optimal trajectory.

    Args:
        goal: Target position of length ``ns``. ``None`` (default) uses the
            notebook-level ``p_goal``, read at call time.
        speed_limit: Upper bound on each input's 2-norm. ``None`` (default)
            uses the notebook-level ``v_max``, read at call time.

    Returns:
        cp.Problem over the ``nc`` control intervals held in ``mpc_data``.
    """
    goal = p_goal if goal is None else goal
    speed_limit = v_max if speed_limit is None else speed_limit

    tracking_cost = cp.sum([cp.norm(mpc_data.statei[k + 1] - goal, 2) for k in range(nc)])
    effort_cost = cp.sum([cp.norm(mpc_data.inputi[k], 2) for k in range(nc)])
    objective = cp.Minimize(0.5 * weights[0] * tracking_cost + 0.5 * weights[1] * effort_cost)

    # Dynamics constraints supplied by the robot model (presumably linking
    # s_{k+1} to s_k and i_k — see OmniRobot.dt_dynamics_constraint).
    dynamics_constraints = robot.dt_dynamics_constraint()

    # Speed limit written directly as a second-order cone constraint.
    input_constraints = [cp.norm(mpc_data.inputi[k], 2) <= speed_limit for k in range(nc)]

    return cp.Problem(objective, dynamics_constraints + input_constraints)

qcqp_problem = create_qcqp()
# DPP compliance means re-solving after changing s0/weights reuses the compiled problem.
assert qcqp_problem.is_dpp()

### Plot the Trajectory

In [None]:
def plot_trajectory(states, html_path="omni_robot_mpc.html"):
    """Show an animated Plotly view of a planar trajectory and optionally save it.

    The static dotted line is the whole path. Pressing "Play" moves a marker
    along it, one frame per row of ``states``.

    Args:
        states: Array-like of shape (steps, >=2) with steps >= 1; columns 0
            and 1 are plotted as x and y.
        html_path: File to write the figure to, as an embeddable HTML fragment
            that loads plotly.js from the CDN. Pass ``None`` to skip writing.
            Defaults to the previously hard-coded ``"omni_robot_mpc.html"``.

    Raises:
        ValueError: if ``states`` is not 2-D with at least one row and two columns.
    """
    states = np.asarray(states, dtype=float)
    if states.ndim != 2 or states.shape[0] == 0 or states.shape[1] < 2:
        raise ValueError(
            f"states must have shape (steps, >=2) with steps >= 1, got {states.shape}"
        )
    steps = states.shape[0]

    # Pad the axis limits by 1 unit on each side so the marker never sits on the border.
    xm, xM = states[:, 0].min() - 1, states[:, 0].max() + 1
    ym, yM = states[:, 1].min() - 1, states[:, 1].max() + 1

    fig = go.Figure(
        data=[
            # Trace 0: the full path, drawn once as a static dotted line.
            go.Scatter(x=states[:, 0], y=states[:, 1],
                       mode="lines", name="Trajectory",
                       line=dict(width=2, color="rgba(0, 0, 255, 0.5)", dash='dot')),
            # Trace 1: the robot marker; the only trace the animation frames update.
            go.Scatter(x=[states[0, 0]], y=[states[0, 1]],
                       mode="markers", name="Robot",
                       marker=dict(color="blue", size=10)),
        ])

    play_button = dict(
        args=[None, {"frame": {"duration": 10, "redraw": False},
                     "fromcurrent": True, "transition": {"duration": 10}, "mode": "immediate"}],
        label="Play",
        method="animate",
    )

    fig.update_layout(width=600, height=450,
        # scaleanchor ties the x scale to y so distances are not visually distorted.
        xaxis=dict(range=[xm, xM], autorange=False, zeroline=False, scaleanchor="y"),
        yaxis=dict(range=[ym, yM], autorange=False, zeroline=False),
        title_text="Trajectory", title_x=0.5,
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    )

    # One frame per step; traces=[1] applies each frame's data to the marker only.
    fig.update(frames=[
        go.Frame(data=[go.Scatter(x=[states[k, 0]], y=[states[k, 1]])], traces=[1])
        for k in range(steps)
    ])

    fig.show()

    if html_path is not None:
        fig.write_html(html_path, include_plotlyjs="cdn", full_html=False)

In [None]:
# Start the robot at the origin and weight goal tracking 10x more than input effort.
s0.value = np.zeros(ns)
weights.value = np.array([1.0, 0.1])

optimal_cost = qcqp_problem.solve()
# Fail loudly here: after an unsuccessful solve the variables' `.value` is None,
# and the plotting cells below would raise far less obvious errors.
assert qcqp_problem.status in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE), (
    f"MPC solve failed with status {qcqp_problem.status!r}"
)
optimal_cost

In [None]:
# Property cycle for the matplotlib figures: the ColorBrewer Set1 colors, paired
# with alternating solid/dashed lines. Both lists must have equal length for `+`.
default_cycler = (
    cycler(color=['#E41A1C', '#377EB8', '#4DAF4A', '#984EA3', '#FF7F00', '#FFFF33', '#A65628', '#F781BF', '#999999'])
    + cycler('linestyle', ['-', '--', '-', '--', '-', '--', '-', '--', '-'])
)

textsize = 12
labelsize = 12

# usetex needs a LaTeX install (plus dvipng for raster/inline output); without
# them matplotlib raises at draw time. Fall back to mathtext in that case.
use_latex = all(shutil.which(tool) for tool in ("latex", "dvipng"))

plt.rc('font', family='serif', serif='Times')
plt.rc('text', usetex=use_latex)
plt.rc('xtick', labelsize=textsize)
plt.rc('ytick', labelsize=textsize)
plt.rc('axes', labelsize=labelsize, prop_cycle=default_cycler, grid=True, xmargin=0)
plt.rc('legend', fontsize=textsize)
plt.rc('grid', linestyle='dotted', linewidth=0.25)

plt.rcParams['figure.constrained_layout.use'] = True

### Visualize the State and Input Evolution

In [None]:
# Optimal state and input time series. States are defined at t_k = k*dt for
# k = 0..nc; inputs are defined for the nc control intervals, k = 0..nc-1.
state_traj = np.vstack([mpc_data.statei[k].value for k in range(nc + 1)])
input_traj = np.vstack([mpc_data.inputi[k].value for k in range(nc)])

fig, ax = plt.subplots(1, 2)

ax[0].plot(dt * np.arange(nc + 1), state_traj, label=[r"$x$", r"$y$"])
ax[0].set_title("States")
ax[0].set_xlabel("Time [s]")
ax[0].set_ylabel("Position [m]")
ax[0].legend()

ax[1].plot(dt * np.arange(nc), input_traj, label=[r"$v_x$", r"$v_y$"])
ax[1].set_title("Inputs")
ax[1].set_xlabel("Time [s]")
# Inputs are velocities: position [m] per time [s].
ax[1].set_ylabel("Velocity [m/s]")
ax[1].legend()

plt.show()