In [None]:
import copy
from cycler import cycler
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch

from dimp.robots import (
    OmniState, OmniInput, OmniRobot, RobotMPCData
)

### Create the Data for the MPC

In [None]:
# Problem dimensions.
ns = 2  # state dimension: (x, y)
ni = 2  # input dimension: (vx, vy)
nc = 2  # number of control intervals in the horizon

# The initial state is a cvxpy Parameter (not a Variable); it is shared by the
# first entry of both the state list and the "bar" state list below.
s0 = cp.Parameter(ns, name="s0")

mpc_data = RobotMPCData(
    nc=nc,
    # s0 followed by one decision variable per predicted state s1..s_nc.
    states_list=[
        OmniState(s0),
        *(OmniState(cp.Variable(ns, name=f"s{k+1}")) for k in range(nc)),
    ],
    # s0 followed by one unnamed parameter per predicted state.
    statesbar_list=[
        OmniState(s0),
        *(OmniState(cp.Parameter(ns)) for _ in range(nc)),
    ],
    # One decision variable per control interval: i0..i_{nc-1}.
    inputs_list=[
        OmniInput(cp.Variable(ni, name=f"i{k}")) for k in range(nc)
    ],
    # One unnamed parameter per control interval.
    inputsbar_list=[
        OmniInput(cp.Parameter(ni)) for _ in range(nc)
    ],
)

# Robot model built with a sampling time of 0.1 on top of the MPC data.
dt = 0.1
robot = OmniRobot(dt=dt, mpc_data=mpc_data)

# Task constants: goal position and bound on the input norm.
p_goal = np.array([10.0, 5.0])
v_max = 1.0

### Create the MPC Problem

In [None]:
# Objective weights as a length-2 cvxpy Parameter declared nonnegative:
# weights[0] scales the goal-distance term and weights[1] scales the
# input-norm term of the objective built in create_qcqp().
weights = cp.Parameter(2, name="weights", nonneg=True)

def create_qcqp():
    """Build the parametric MPC problem over the ``nc``-interval horizon.

    Objective (minimized)::

        0.5 * weights[0] * sum_i ||state_{i+1} - p_goal||_2
      + 0.5 * weights[1] * sum_i ||input_i||_2          for i in 0..nc-1

    Constraints are whatever ``robot.dt_dynamics_constraint()`` returns,
    plus one bound ``||input_i||_2 <= v_max`` for every input of the horizon.

    Reads the module-level ``mpc_data``, ``robot``, ``weights``, ``p_goal``,
    ``v_max`` and ``nc``.

    Returns:
        cp.Problem: the assembled minimization problem.
    """
    goal_cost = cp.sum(
        [cp.pnorm(mpc_data.statei[i + 1] - p_goal, p=2) for i in range(nc)]
    )
    input_cost = cp.sum(
        [cp.pnorm(mpc_data.inputi[i], p=2) for i in range(nc)]
    )
    objective = cp.Minimize(
        0.5 * weights[0] * goal_cost + 0.5 * weights[1] * input_cost
    )

    dynamics_constraints = robot.dt_dynamics_constraint()

    # One norm bound per control interval, generated from nc so the constraint
    # set always covers the same inputs the objective sums over.
    input_constraints = [
        cp.norm(mpc_data.inputi[i], p=2) - v_max <= 0 for i in range(nc)
    ]

    # dynamics_constraints is concatenated as a list, as in the original code.
    constraints = dynamics_constraints + input_constraints

    return cp.Problem(objective, constraints)

# Build the problem once; the same object is later solved repeatedly in
# simulate() and wrapped in a CvxpyLayer.
qcqp_problem = create_qcqp()
# Fail early if the problem does not satisfy cvxpy's DPP
# (disciplined parametrized programming) check.
assert qcqp_problem.is_dpp()

### Simulate the Trajectory

In [None]:
def simulate(problem, mpc_data, steps=200):
    """Run the MPC in closed loop from the origin and log the trajectory.

    Each iteration solves ``problem``, moves ``mpc_data.statei[0]`` to the
    first predicted state ``mpc_data.statei[1]``, calls
    ``mpc_data.update_bar()``, and records that state together with the
    first input of the plan.

    Args:
        problem: object exposing ``solve()`` (here a ``cp.Problem``).
        mpc_data: MPC data container exposing ``statei``, ``inputi`` and
            ``update_bar()``.
        steps: number of closed-loop iterations (default 200).

    Returns:
        tuple: ``(states, inputs)`` NumPy arrays of shape ``(steps, ns)`` and
        ``(steps, ni)``, using the module-level ``ns`` and ``ni``.

    Raises:
        RuntimeError: if a solve leaves the predicted state or input unset.
    """
    states = np.zeros((steps, ns))
    inputs = np.zeros((steps, ni))

    # Start at the origin, sized by ns rather than a hard-coded length.
    mpc_data.statei[0].value = np.zeros(ns)

    for i in range(steps):

        problem.solve()

        # Read the plan right after the solve, before anything else touches it.
        next_state = mpc_data.statei[1].value
        # Log the first input of the plan, as rollout() does.
        applied_input = mpc_data.inputi[0].value

        if next_state is None or applied_input is None:
            status = getattr(problem, "status", "unknown")
            raise RuntimeError(
                f"MPC solve failed at step {i} (status: {status})"
            )

        # Receding horizon: the first predicted state becomes the new
        # initial state for the next solve.
        mpc_data.statei[0].value = next_state

        mpc_data.update_bar()

        states[i, :] = next_state
        inputs[i, :] = applied_input

    return states, inputs

### Plot the Trajectory

In [None]:
def plot_trajectory(states):
    """Show an animated plotly figure of a 2-D trajectory and export it.

    The figure holds two traces: the full path as a dotted line and a single
    marker for the robot. One animation frame per row of ``states`` moves the
    marker (trace index 1). The figure is displayed and also written to
    ``omni_robot_mpc.html``.

    Args:
        states: 2-D array whose columns 0 and 1 are the x and y coordinates.
    """
    xs, ys = states[:, 0], states[:, 1]

    # Axis limits: bounding box of the data padded by one unit on each side.
    x_limits = [xs.min() - 1, xs.max() + 1]
    y_limits = [ys.min() - 1, ys.max() + 1]

    path_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="lines",
        name="Trajectory",
        line={"width": 2, "color": "rgba(0, 0, 255, 0.5)", "dash": "dot"},
    )
    robot_trace = go.Scatter(
        x=[xs[0]],
        y=[ys[0]],
        mode="markers",
        name="Robot",
        marker={"color": "blue", "size": 10},
    )
    fig = go.Figure(data=[path_trace, robot_trace])

    # "Play" button that starts the frame animation.
    animation_options = {
        "frame": {"duration": 10, "redraw": False},
        "fromcurrent": True,
        "transition": {"duration": 10},
        "mode": "immediate",
    }
    play_button = {
        "args": [None, animation_options],
        "label": "Play",
        "method": "animate",
    }

    fig.update_layout(
        width=600,
        height=450,
        # The x axis is scale-anchored to the y axis.
        xaxis={"range": x_limits, "autorange": False, "zeroline": False, "scaleanchor": "y"},
        yaxis={"range": y_limits, "autorange": False, "zeroline": False},
        title_text="Trajectory",
        title_x=0.5,
        updatemenus=[{"type": "buttons", "buttons": [play_button]}],
    )

    # Each frame only redraws trace 1 (the robot marker).
    marker_frames = [
        go.Frame(data=[go.Scatter(x=[px], y=[py])], traces=[1])
        for px, py in zip(xs, ys)
    ]
    fig.update(frames=marker_frames)

    fig.show()

    fig.write_html("omni_robot_mpc.html", include_plotlyjs="cdn", full_html=False)

In [None]:
def simulate_and_plot_qcqp():
    """Simulate the closed loop with weights (1.0, 0.1) and plot the states.

    Sets the value of the module-level ``weights`` parameter, runs
    ``simulate`` on ``qcqp_problem`` / ``mpc_data`` and passes the resulting
    state trajectory to ``plot_trajectory``. The logged inputs are not used.
    """
    weights.value = np.array([1.0, 0.1])

    trajectory, _ = simulate(qcqp_problem, mpc_data)
    plot_trajectory(trajectory)

# Run the closed-loop simulation and display the animated trajectory.
simulate_and_plot_qcqp()

In [None]:
# Wrap the MPC problem in a CvxpyLayer taking (weights, s0) as inputs.
# Returned variables are ordered: predicted states s1..s_nc, then inputs
# i0..i_{nc-1}.
cvxpylayer = CvxpyLayer(
    qcqp_problem,
    parameters=[weights, s0],
    variables=[
        *(mpc_data.statei[k] for k in range(1, nc + 1)),
        *(mpc_data.inputi[k] for k in range(nc)),
    ],
)

# Torch counterparts of the two layer parameters; only w_th is optimized.
w_th = torch.tensor([1.0, 0.1], requires_grad=True)
s0_th = torch.tensor([0.0, 0.0], requires_grad=True)
optim = torch.optim.Adam(params=[w_th], lr=0.1)

# Smoke test: solve once and backpropagate the sum of the first returned
# variable to check that a gradient reaches the weights.
solution = cvxpylayer(w_th, s0_th)
torch.sum(solution[0]).backward()

print(f"w.grad: {w_th.grad}")

In [None]:
def rollout(initial_state, W, steps=200):
    """Roll the differentiable MPC layer forward in closed loop.

    At every step the module-level ``cvxpylayer`` is called with the weights
    ``W`` and the current state; the first predicted state and the first
    planned input are recorded, and that state becomes the next initial state.

    Args:
        initial_state: tensor passed as the layer's ``s0`` on the first step.
        W: tensor passed as the layer's ``weights`` on every step.
        steps: number of closed-loop iterations (default 200).

    Returns:
        tuple: ``(all_states, all_inputs)``, two lists of ``steps`` tensors.
    """
    s_t = initial_state
    all_states, all_inputs = [], []

    for _ in range(steps):
        sol = cvxpylayer(W, s_t)

        # The layer returns nc states followed by nc inputs, so the first
        # state is sol[0] and the first input is sol[nc]. Indexing instead of
        # unpacking four names keeps this valid for any horizon length.
        s_next, u_first = sol[0], sol[nc]

        all_states.append(s_next)
        all_inputs.append(u_first)

        s_t = s_next

    return all_states, all_inputs

In [None]:
def task_loss(states, inputs, goal=None, state_weight=10.0, input_weight=1.0):
    """Weighted sum of goal distances and input norms.

    Computes ``0.5 * state_weight * sum_s ||s[0:2] - goal||_2
    + 0.5 * input_weight * sum_u ||u||_2``.

    Args:
        states: non-empty sequence of tensors; the first two entries of each
            are compared against ``goal``.
        inputs: non-empty sequence of input tensors.
        goal: array-like goal position; defaults to the module-level
            ``p_goal``.
        state_weight: weight of the goal-distance term (default 10.0).
        input_weight: weight of the input-norm term (default 1.0).

    Returns:
        torch.Tensor: scalar loss.
    """
    if goal is None:
        goal = p_goal

    # The goal is a constant, so it needs no gradient tracking.
    goal_th = torch.as_tensor(goal)

    st_cost = torch.stack([torch.norm(s[0:2] - goal_th, p=2) for s in states]).sum()
    in_cost = torch.stack([torch.norm(u, p=2) for u in inputs]).sum()
    return 0.5 * state_weight * st_cost + 0.5 * input_weight * in_cost

In [None]:
# Learn the MPC objective weights by gradient descent on task_loss, evaluated
# on a single MPC solve from the origin per epoch.
n_epochs = 50
history  = []
for epoch in range(n_epochs):
    optim.zero_grad()

    # Reset the initial state in place without recording the write in
    # autograd, and drop its gradient: s0_th is not managed by the optimizer,
    # so optim.zero_grad() never clears it.
    with torch.no_grad():
        s0_th.zero_()
    s0_th.grad = None

    # Layer output order: nc predicted states, then nc inputs.
    sol = cvxpylayer(w_th, s0_th)
    states = [sol[i] for i in range(nc)]
    inputs = [sol[nc + i] for i in range(nc)]

    loss = task_loss(states, inputs)
    loss.backward()

    optim.step()

    # Project back onto w >= 0: the cvxpy parameter is declared nonneg=True
    # and an unconstrained Adam step can otherwise drive a weight negative.
    with torch.no_grad():
        w_th.clamp_(min=0.0)

    # .numpy() shares memory with the tensor, so store copies. 'loss' and 'dw'
    # belong to the weights before the step; 'w' is the updated value.
    history.append({
        'loss': loss.item(),
        'w': w_th.detach().cpu().numpy().copy(),
        'dw': w_th.grad.detach().cpu().numpy().copy(),
    })
    print(f"Epoch {epoch:2d} | loss = {loss.item():.4f} | w = {w_th.detach().cpu().numpy()}")


In [None]:
# Property cycle: nine colours paired element-wise with alternating
# solid/dashed line styles.
# Alternative palette kept for reference:
#   '#0072BD', '#D95319', '#EDB120', '#7E2F8E', '#77AC30', '#4DBEEE', '#A2142F'
default_cycler = cycler(
    color=[
        '#E41A1C', '#377EB8', '#4DAF4A',
        '#984EA3', '#FF7F00', '#FFFF33',
        '#A65628', '#F781BF', '#999999',
    ]
) + cycler(linestyle=['-', '--'] * 4 + ['-'])

# Plain list of the palette, used to colour axis labels and ticks.
colors = [style['color'] for style in default_cycler]

textsize = labelsize = 12

# Global matplotlib style: serif font rendered through LaTeX, dotted grid,
# no x margin, constrained layout for every figure.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': 'Times',
    'text.usetex': True,
    'xtick.labelsize': textsize,
    'ytick.labelsize': textsize,
    'axes.labelsize': labelsize,
    'axes.prop_cycle': default_cycler,
    'legend.fontsize': textsize,
    'axes.grid': True,
    'axes.xmargin': 0,
    'grid.linestyle': 'dotted',
    'grid.linewidth': 0.25,
    'figure.constrained_layout.use': True,
})

In [None]:
# Two panels: training loss on the left, learned weights on the right.
fig, ax = plt.subplots(1, 2)

# Left panel: loss per epoch.
ax[0].plot([entry['loss'] for entry in history])
ax[0].set(xlabel="Epoch", ylabel="Loss", title="Loss Evolution")

# Right panel: w_0 on the left y axis, w_1 on a twin y axis, each axis
# coloured like its curve.
ax11 = ax[1].twinx()

ax[1].plot([entry['w'][0] for entry in history])
ax[1].set(title="Weights Evolution", xlabel="Epoch")
ax[1].set_ylabel(r"$w_0$", color=colors[0])
ax[1].tick_params(axis='y', labelcolor=colors[0])

ax11.plot([entry['w'][1] for entry in history], color=colors[1], linestyle='--')
ax11.set_ylabel(r"$w_1$", color=colors[1])
ax11.tick_params(axis='y', labelcolor=colors[1])

fig.set_constrained_layout(True)
plt.show()