In [None]:
import os
import pickle
import time

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.notebook import tqdm

from dimp.utils import init_matplotlib


# Project helper imported from dimp.utils; presumably applies the shared
# matplotlib settings used by the figures below — TODO confirm in dimp/utils.
init_matplotlib()

%matplotlib widget

In [None]:
# When True, the training cells below re-run the optimization and overwrite the
# pickled results in `out_dir`; when False, they load the pickled results instead.
rerun = False

# Directory holding the pickled solutions and training histories.
out_dir = "data/pann_clqr_dt"

# Create the directory up front so the later pickle dumps do not fail on a missing path.
os.makedirs(out_dir, exist_ok=True)

## Example Pannocchia 🌽

LTI system with dynamics
$$
\dot{s} = A s + B u
$$

3 states and 1 input.

### Define the problem matrices, the initial state, the time horizon, and the LQR matrices.

In [None]:
n_s = 3         # number of states
n_u = 1         # number of inputs

# Continuous-time LTI dynamics: s' = A s + B u
A = np.array([
    [-0.1,  0.0,  0.0],
    [ 0.0, -2.0, -6.25],
    [ 0.0,  4.0,  0.0],
])
B = np.array([[0.25], [2.0], [0.0]])

# Initial state
s0 = np.array([1.344, -4.585, 5.674])

# Time window, maximum number of OCP timesteps, and the resulting uniform timestep
T = 10.0
N = 1000
dt = T / N

# LQR weights on the state and on the input
Q = np.eye(n_s)
R = 0.1 * np.eye(n_u)

# Bound on the magnitude of the control input
u_max = 1.0

### Create the CVXPY Variables and the Optimization Problem

$$
\begin{align*}
& \min_{\substack{s_{k+1}, u_k \\ k=0, \dots, N-1}} \quad & \sum_{k=0}^{N-1} \Delta t \left( s_{k+1}^T Q s_{k+1} + u_k^T R u_k \right), \\
& \text{s.t.} \quad & s_{k+1} = s_k + \Delta t (A s_k + B u_k), \\
& & s_0 = s_{\text{init}}, \\
& & u_k \in U.
\end{align*}
$$

In [None]:
# Decision variables shared by all the OCPs of this notebook.
# s[0] is the (constant) initial state, s[1..N] and u[0..N-1] are CVXPY variables.
s = [s0]
s += [cp.Variable(3, name=f"s_{i}") for i in range(N)]
u = [cp.Variable(1, name=f"u_{i}") for i in range(N)]


def create_pann_clqr(n: int = N, dt: float = dt):
    """Build the constrained LQR problem on a uniform grid of `n` timesteps.

    The dynamics are discretized with a forward-Euler step of length `dt` and
    the stage costs are weighted by `dt`. The problem is written on the
    module-level variables `s` and `u`, so their `.value` holds the solution
    after `problem.solve()`.

    Args:
        n: number of timesteps (uses s[1..n] and u[0..n-1]).
        dt: length of each timestep.

    Returns:
        The `cp.Problem` with the dynamics constraints followed by the input limits.
    """
    stages = range(n)

    state_cost = cp.sum([cp.quad_form(s[k+1], Q) for k in stages]) * dt
    input_cost = cp.sum([cp.quad_form(u[k], R) for k in stages]) * dt

    constraints = []
    # Forward-Euler dynamics, one equality per stage.
    for k in stages:
        constraints.append(s[k+1] == s[k] + (A @ s[k] + B @ u[k]) * dt)
    # Box constraint on every input.
    for k in stages:
        constraints.append(cp.abs(u[k]) <= u_max)

    return cp.Problem(cp.Minimize(state_cost + input_cost), constraints)


pann_clqr = create_pann_clqr()
assert pann_clqr.is_dpp()

### Study how the solution changes with different number of samples.

With $n=80$ samples the system is unstable.
From $n=160$ timesteps the OCP stabilizes the system.
However, increasing the number of timesteps further improves the solution (see the optimal cost value).

In [None]:
# Increment in the number of timesteps between consecutive problems of the sweep.
interval = 80

# Horizon lengths to compare: the multiples of `interval` that do not exceed N.
sweep_sizes = list(range(interval, N + 1, interval))

# Every problem occupies one column and two stacked rows: states on top, inputs below.
n_rows = 2 * ((len(sweep_sizes) + 1) // 2)
fig, axs = plt.subplots(n_rows, 2, figsize=(6.4, 12.8), constrained_layout=True)

for idx, n_k in enumerate(sweep_sizes):
    # Loop-local names (n_k, dt_k) are used on purpose so that the sweep does not
    # overwrite the notebook-level `n` and `dt` that later cells read.
    dt_k = T / n_k
    pann_clqr = create_pann_clqr(n=n_k, dt=dt_k)

    start_time = time.time()
    pann_clqr.solve()
    solve_time = time.time() - start_time

    print(
        f"n = {n_k}\n"
        f"Optimal cost: {pann_clqr.objective.value:.4f}, "
        f"solve time meas: {solve_time:.4f} s, "
        f"solve time: {pann_clqr.solver_stats.solve_time:.4f} s\n"
    )

    # u_k is applied at t_k = k*dt, while the plotted states are s_1..s_n,
    # i.e. the states reached at t_1..t_n.
    input_times = np.arange(n_k) * dt_k
    state_times = np.arange(1, n_k + 1) * dt_k
    s_vec = np.array([s_k.value for s_k in s[1:n_k+1]])
    u_vec = np.array([u_k.value for u_k in u[:n_k]])

    row, col = 2 * (idx // 2), idx % 2

    axs[row, col].plot(state_times, s_vec, label=['x', 'y', 'z'])
    axs[row, col].set(
        xlabel='Time',
        ylabel='State',
        title=fr"$n={n_k}$",
    )

    axs[row + 1, col].plot(input_times, u_vec)
    axs[row + 1, col].set(
        xlabel='Time',
        ylabel='Input',
    )

## DQP with Auxiliary Variables

### Create the Parametrized CLQR Problem

Discretize then linearize the (nonlinear) dynamics.
$$
s_{k+1} = \bar{s}_k + \bar{\delta}_k (A \bar{s}_k + B \bar{u}_k) + (I + \bar{\delta}_k A) \tilde{s}_k + \bar{\delta}_k B \tilde{u}_k + (A \bar{s}_k + B \bar{u}_k) \tilde{\delta}_k,
$$
where $\tilde{\square} = \square - \bar{\square}$ and $\delta_k$ is an auxiliary variable that represents the time step lengths.

The QP is parametrized by the parameter vector $\bar{\delta} = \begin{bmatrix} \bar{\delta}_1 & \dots & \bar{\delta}_N \end{bmatrix}^T$.
$$
\begin{align*}
& \min_{\substack{s_{k+1}, u_k, \delta_k \\ k=0, \dots, N}} \quad & \sum_{k=0}^{N} \left(\bar{\delta}_k \left( s_k^T Q s_k + u_k^T R u_k \right) + w_\delta \tilde{\delta}_k^2 \right), \\
& \text{s.t.} \quad & s_{k+1} = \bar{s}_k + \bar{\delta}_k (A \bar{s}_k + B \bar{u}_k) + (I + \bar{\delta}_k A) \tilde{s}_k + \bar{\delta}_k B \tilde{u}_k + (A \bar{s}_k + B \bar{u}_k) \tilde{\delta}_k, \\
& & s_0 = s_{\text{init}}, \\
& & u_k \in U.
\end{align*}
$$

In [None]:
n = 200

# Auxiliary timestep variables (optimized by the QP) and the nominal timesteps
# (parameters of the QP) around which the dynamics are linearized.
deltas = [cp.Variable(1, name=f'deltas_{i}') for i in range(n)]
dts = [cp.Parameter(1, nonneg=True, name=f'dts_{i}') for i in range(n)]

def create_pann_param_clqr(s_bar, u_bar, n: int = N, w_delta: float = 10**3):
    """Build the CLQR problem parametrized by the nominal timesteps `dts`.

    The Euler-discretized dynamics are linearized around the trajectory
    (`s_bar`, `u_bar`) and the nominal timesteps `dts`; the actual timesteps
    `deltas` are decision variables that must be non-negative and sum to `T`.
    The problem is written on the module-level `s`, `u`, `deltas` and `dts`.

    Args:
        s_bar: sequence of at least `n` state linearization points.
        u_bar: sequence of at least `n` input linearization points.
        n: number of timesteps. It must not exceed the number of allocated
            `deltas`/`dts` entries nor the length of `s_bar`/`u_bar`.
        w_delta: weight of the penalty on `deltas[i] - dts[i]`.

    Returns:
        The `cp.Problem` (dynamics, input limits, then timestep constraints).

    Raises:
        ValueError: if `n` is larger than the number of available timesteps.
    """
    available = min(len(deltas), len(dts), len(s_bar), len(u_bar), len(u))
    if n > available:
        # Without this check the failure is an IndexError deep inside a comprehension.
        raise ValueError(
            f"n = {n} exceeds the {available} timesteps available in the "
            "variables, parameters and linearization points"
        )

    objective = cp.Minimize(
        cp.sum([cp.quad_form(s[i+1], Q) * dts[i] for i in range(n)]) +
        cp.sum([cp.quad_form(u[i], R) * dts[i] for i in range(n)]) +
        w_delta * cp.sum([cp.square(deltas[i] - dts[i]) for i in range(n)])
    )

    # First-order expansion of s + delta * (A s + B u) around (s_bar, u_bar, dts).
    dynamics_constraints = [
        s[i+1] == s_bar[i] \
            + dts[i] * (A @ s_bar[i] + B @ u_bar[i]) \
            + (np.eye(n_s) + dts[i] * A) @ (s[i] - s_bar[i]) \
            + dts[i] * B @ (u[i] - u_bar[i]) \
            + (deltas[i] - dts[i]) * (A @ s_bar[i] + B @ u_bar[i]) \
            for i in range(n)
    ]

    input_limits = [
        cp.abs(u[i]) <= u_max for i in range(n)
    ]

    # Timesteps are non-negative and cover exactly the time window T.
    timestep_constraints = [
        deltas[i] >= 0 for i in range(n)
    ] + [
        cp.sum(deltas[0:n]) == T
    ]

    constraints = dynamics_constraints + input_limits + timestep_constraints

    problem = cp.Problem(objective, constraints)

    return problem

# Trivial linearization point, only used to check that the problem is DPP.
s_bar = [np.zeros(n_s) for _ in range(n)]
u_bar = [np.zeros(n_u) for _ in range(n)]

pann_param_clqr = create_pann_param_clqr(s_bar, u_bar, n)
assert pann_param_clqr.is_dpp()

### Task Loss

Unscaled
$$
\mathcal{L}_1 = \sum_{k=0}^{N} \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

Time scaled
$$
\mathcal{L}_2 = \sum_{k=0}^{N} \delta_k \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

Time bar scaled
$$
\mathcal{L}_3 = \sum_{k=0}^{N} \bar{\delta}_k \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

In [None]:
loss_methods = ["unscaled", "time scaled", "time bar scaled"]

def task_loss(sol, param, method="unscaled"):
    """Task loss of a solution of the auxiliary-variable CLQR layer.

    Args:
        sol: sequence of tensors laid out as `n` states, `n` inputs and `n`
            timesteps (`n` is the notebook-level number of timesteps).
        param: sequence of tensors holding the nominal timesteps; only read by
            the "time bar scaled" method, where it is detached and therefore
            contributes no direct gradient.
        method: one of `loss_methods`.
            - "unscaled": sum of the quadratic stage costs;
            - "time scaled": stage costs weighted by the timesteps in `sol`;
            - "time bar scaled": stage costs weighted by the nominal timesteps.

    Returns:
        A scalar tensor with the loss.

    Raises:
        ValueError: if `method` is not one of the known methods.
    """
    states = [sol[i] for i in range(n)]
    inputs = [sol[n+i] for i in range(n)]

    device = states[0].device
    Q_th = torch.tensor(Q, dtype=torch.float32, device=device)
    R_th = torch.tensor(R, dtype=torch.float32, device=device)

    if method == 'unscaled':
        weights = None
    elif method == 'time scaled':
        weights = [sol[2*n+i] for i in range(n)]
    elif method == 'time bar scaled':
        # Stay in torch (no numpy round trip) so that this also works when the
        # parameters do not live on the CPU; detach() keeps them constant.
        weights = torch.cat([d_bar.detach().reshape(-1) for d_bar in param]).to(device)
    else:
        raise ValueError(f"Unknown method {method}")

    def stage_cost(vec, mat, i):
        # vec^T mat vec, optionally weighted by the i-th timestep.
        row = vec.t() if weights is None else weights[i] * vec.t()
        return row @ mat @ vec

    return sum([
        stage_cost(states[i], Q_th, i) for i in range(n)
    ]) + sum([
        stage_cost(inputs[i], R_th, i) for i in range(n)
    ])

In [None]:
def plot_timegrid(deltas, x=None, ax=None, ylabel=None, title=None):
    """Draw the time grid defined by `deltas` and, optionally, a signal on it.

    A dashed vertical line is drawn at every cumulative sum of `deltas`; when
    `x` is given it is plotted against those same instants.

    Args:
        deltas: timestep lengths (array, tensor or plain sequence).
        x: optional signal with one sample per timestep.
        ax: axes to draw on; a new figure is created when omitted.
        ylabel: optional label of the y axis.
        title: optional title of the axes.

    Returns:
        The axes that were drawn on (same convention as `plot_colored`).
    """
    # Accept plain sequences as well as objects exposing .tolist().
    steps = deltas.tolist() if hasattr(deltas, "tolist") else deltas
    times = np.cumsum(steps)

    if ax is None:
        _, ax = plt.subplots()
    for t in times:
        ax.axvline(t, color='gray', linestyle='--', alpha=0.5)

    if x is not None:
        ax.plot(times, x)

    ax.set_xlabel("Time")
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)

    return ax

In [None]:
def plot_colored(deltas, x, ax=None):
    """Plot `x` as a staircase whose segments are colored by the timestep length.

    Segment `i` holds the value `x[i]` between the i-th and (i+1)-th cumulative
    sums of `deltas`, followed by a vertical jump to `x[i+1]`; both are colored
    according to `deltas[i]`. A colorbar for the timestep lengths is added.

    Args:
        deltas: timestep lengths (indexable, one entry per sample of `x`).
        x: samples of the signal.
        ax: axes to draw on; a new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    times = np.cumsum(deltas)

    cmap = plt.get_cmap("viridis")
    norm = Normalize(vmin=np.min(deltas), vmax=np.max(deltas))

    if ax is None:
        _, ax = plt.subplots()

    for i in range(len(x)-1):
        color = cmap(norm(deltas[i]))
        # horizontal hold
        ax.hlines(x[i], times[i], times[i+1], colors=color, linewidth=2)
        # vertical jump
        ax.vlines(times[i+1], x[i], x[i+1], colors=color, linewidth=1)

    ax.set_xlabel("Time")

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # Attach the colorbar to the figure that owns `ax`, not to whatever figure
    # happens to be pyplot's current one.
    cbar = ax.figure.colorbar(sm, ax=ax)
    cbar.set_label("deltas")

    return ax

### Training Loop

In [None]:
def plot_training_res(sol, history):
    """Plot the outcome of a training run of the auxiliary-variable layer.

    Args:
        sol: layer output laid out as states, inputs and timesteps, each taking
            one third of the sequence.
        history: list of per-epoch records; the 'loss' entry of each is plotted.
    """
    # The number of timesteps is read from the solution itself, so the plot does
    # not depend on the value the notebook-level `n` happens to hold.
    n_steps = len(sol) // 3

    s_arr = np.array([x.detach().cpu().numpy().tolist() for x in sol[0:n_steps]])
    u_arr = np.array([x.detach().cpu().numpy().tolist() for x in sol[n_steps:2*n_steps]])
    d_arr = np.concatenate([x.detach().cpu().numpy() for x in sol[2*n_steps:3*n_steps]])

    # Constrained layout is requested at creation, as in the sweep figure above,
    # instead of through the deprecated Figure.set_constrained_layout.
    fig, ax = plt.subplots(2, 3, figsize=(9.6, 9.6), constrained_layout=True)

    ax[0, 0].plot([h['loss'] for h in history])
    ax[0, 0].set_xlabel("Epoch")
    ax[0, 0].set_ylabel("Loss")
    ax[0, 0].set_title("Loss Evolution")

    plot_colored(d_arr, u_arr, ax[1, 0])

    plot_timegrid(
        d_arr, s_arr, ax[0, 1],
        ylabel="State",
        title="State Evolution"
    )

    plot_timegrid(
        d_arr, u_arr, ax[1, 1],
        ylabel="Input",
        title="Input Evolution"
    )

    ax[0, 2].plot(np.cumsum(d_arr), d_arr)
    ax[0, 2].set_xlabel("Time")
    ax[0, 2].set_ylabel("Timestep duration")
    ax[0, 2].set_title("Timesteps Evolution")

    # The sixth panel is unused.
    fig.delaxes(ax[1, 2])

In [None]:
if rerun:
    n_epochs = 100

    history_aux = []
    sol_aux = {}

    for method in loss_methods:
        print(f"Method: {method}")

        n = 160
        # Uniform initial timestep; assigned here, before its first use, so the
        # cell does not depend on a `dt` left behind by an earlier cell.
        dt = T / n

        # One learnable nominal timestep per stage, initialized on the uniform grid.
        dts_torch = [torch.nn.Parameter(torch.full((1,), dt)) for _ in range(n)]

        optim = torch.optim.Adam(dts_torch, lr=5e-4)

        # =========================================================================== #

        # Initial linearization point: straight line from s0 to the origin, zero input.
        s_bar = [s0 + (np.zeros(n_s) - s0) * i/n for i in range(n)]
        u_bar = [np.zeros(n_u) for _ in range(n)]
        with tqdm(total=n_epochs) as pbar:
            for epoch in range(n_epochs):
                pbar.update(1)

                optim.zero_grad()

                # The problem is rebuilt every epoch because s_bar and u_bar enter
                # it as constants, not as CVXPY parameters.
                pann_param_clqr = create_pann_param_clqr(s_bar, u_bar, n)
                cvxpylayer = CvxpyLayer(
                    pann_param_clqr,
                    parameters=dts[:n],
                    variables=s[1:n+1] + u[:n] + deltas[:n],
                )
                sol_aux[method] = cvxpylayer(*dts_torch)

                # Relinearize around the latest solution.
                s_bar = [sol_aux[method][i].detach().numpy() for i in range(n)]
                u_bar = [sol_aux[method][n+i].detach().numpy() for i in range(n)]

                loss = task_loss(sol_aux[method], dts_torch, method=method)
                loss.backward()

                optim.step()
                with torch.no_grad():
                    # Project back: clamp every timestep first, then rescale all of
                    # them by one common factor so that they sum to T. (Rescaling
                    # inside the clamping loop would use a different total for each
                    # entry and the sum would not be T.)
                    for d in dts_torch:
                        d.clamp_(min=1e-6, max=0.07)
                    scale = T / torch.stack(dts_torch).sum()
                    for d in dts_torch:
                        d.mul_(scale)

                history_aux.append({
                    'method': method,
                    'loss': loss.item(),
                    # .numpy() shares memory with the parameter, which the optimizer
                    # updates in place: copy to keep a real per-epoch snapshot.
                    'dts': [d.detach().cpu().numpy().copy() for d in dts_torch]
                })

    with open(f"{out_dir}/sol_aux.pkl", "wb") as f:
        pickle.dump(sol_aux, f)
    with open(f"{out_dir}/history_aux.pkl", "wb") as f:
        pickle.dump(history_aux, f)
else:
    # NOTE: pickle.load executes arbitrary code; only load files written by this notebook.
    with open(f"{out_dir}/sol_aux.pkl", "rb") as f:
        sol_aux = pickle.load(f)
    with open(f"{out_dir}/history_aux.pkl", "rb") as f:
        history_aux = pickle.load(f)

In [None]:
# One summary figure per loss method of the auxiliary-variable formulation.
n = 160
for method in loss_methods:
    plot_training_res(
        sol_aux[method],
        list(filter(lambda h: h['method'] == method, history_aux)),
    )

## DQP with Shaped Parameters

In [None]:
def theta_2_dt(theta):
    """Map unconstrained parameters to timesteps that are positive and sum to `T`.

    Each timestep is `eps + (T - m * eps) * softmax(theta)_k`, where `m` is the
    number of entries of `theta`, so every timestep is at least `eps` and the
    timesteps add up to `T`.

    Args:
        theta: tensor of unconstrained parameters (flattened internally).

    Returns:
        A 1-D tensor with one timestep per entry of `theta`.
    """
    eps = 1e-3

    w = torch.softmax(theta.flatten(), dim=0)
    # Use the size of theta rather than the notebook-level `n`, so the timesteps
    # sum to T whatever value `n` currently holds.
    return eps + (T - w.numel() * eps) * w

## DQP With Reparametrized Timesteps

Reparametrize time steps on the simplex:
$$
\Delta t_k = \epsilon + (T - n \epsilon) \frac{e^{\theta_k}}{\sum_j e^{\theta_j}}
$$
This enforces both positivity and the total time constraint $\sum_k \Delta t_k = T$ without discontinuous updates.

The OCP uses $\Delta t_k$ and the optimizer uses $\theta_k$ as optimization parameters.
Gradients do flow through the softmax function.

In [None]:
n = 160


def create_pann_param_clqr_2(n: int = N):
    """Build the CLQR problem whose timesteps are a single vector parameter.

    The forward-Euler dynamics and the stage costs use the entries of the
    non-negative parameter `dts` as timestep lengths. The problem is written on
    the module-level variables `s` and `u`.

    Args:
        n: number of timesteps (length of the `dts` parameter).

    Returns:
        A tuple `(problem, dts)` with the `cp.Problem` and its parameter.
    """
    dts = cp.Parameter(n, nonneg=True, name='dts')
    stages = range(n)

    state_cost = cp.sum([cp.quad_form(s[k+1], Q) * dts[k] for k in stages])
    input_cost = cp.sum([cp.quad_form(u[k], R) * dts[k] for k in stages])

    constraints = []
    # Forward-Euler dynamics with a stage-dependent timestep.
    for k in stages:
        constraints.append(s[k+1] == s[k] + dts[k] * (A @ s[k] + B @ u[k]))
    # Box constraint on every input.
    for k in stages:
        constraints.append(cp.abs(u[k]) <= u_max)

    return cp.Problem(cp.Minimize(state_cost + input_cost), constraints), dts


pann_param_clqr_2, _ = create_pann_param_clqr_2(n)
assert pann_param_clqr_2.is_dpp()

In [None]:
loss_methods_2 = ["unscaled", "time scaled"]


def task_loss_2(sol, param, method="unscaled"):
    """Task loss of a solution laid out as `n` states followed by `n` inputs.

    Args:
        sol: sequence of tensors, states first and inputs after them (`n` is
            the notebook-level number of timesteps).
        param: timesteps; entry `i` weights stage `i` in the "time scaled" loss.
        method: "unscaled" (plain sum of the quadratic stage costs) or
            "time scaled" (stage costs weighted by the timesteps).

    Returns:
        A scalar tensor with the loss.

    Raises:
        ValueError: if `method` is neither of the two known methods.
    """
    states = [sol[k] for k in range(n)]
    inputs = [sol[n+k] for k in range(n)]
    deltas = param

    Q_th = torch.tensor(Q, dtype=torch.float32, device=states[0].device)
    R_th = torch.tensor(R, dtype=torch.float32, device=states[0].device)

    if method == 'unscaled':
        scaled = False
    elif method == 'time scaled':
        scaled = True
    else:
        raise ValueError(f"Unknown method {method}")

    def stage_cost(vec, mat, k):
        # vec^T mat vec, multiplied by the k-th timestep when scaling is on.
        row = deltas[k] * vec.t() if scaled else vec.t()
        return row @ mat @ vec

    state_cost = sum(stage_cost(states[k], Q_th, k) for k in range(n))
    input_cost = sum(stage_cost(inputs[k], R_th, k) for k in range(n))
    return state_cost + input_cost

### Task Loss

Unscaled
$$
\mathcal{L}_1 = \sum_{k=0}^{N} \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

Time scaled
$$
\mathcal{L}_2 = \sum_{k=0}^{N} \delta_k \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

In [None]:
if rerun:
    n_epochs = 200

    history_rep = []
    sol_rep = {}

    for method in loss_methods_2:
        print(f"Method: {method}")

        n = 160

        # Unconstrained parameters; theta_2_dt maps them to the timesteps.
        theta = torch.nn.Parameter(torch.ones(n, 1))

        optim = torch.optim.Adam([theta], lr=1e-2)

        # =========================================================================== #

        # Only the value of the `dts2` parameter changes between epochs, so the
        # problem and its differentiable layer are built once per method
        # instead of once per epoch.
        pann_param_clqr, dts2 = create_pann_param_clqr_2(n)
        cvxpylayer = CvxpyLayer(
            pann_param_clqr,
            parameters=[dts2],
            variables=s[1:n+1] + u[:n],
        )

        with tqdm(total=n_epochs) as pbar:
            for epoch in range(n_epochs):
                pbar.update(1)

                optim.zero_grad(set_to_none=True)
                dts2_torch = theta_2_dt(theta)

                sol_rep[method] = cvxpylayer(dts2_torch)

                loss = task_loss_2(sol_rep[method], dts2_torch, method=method)
                loss.backward()

                optim.step()

                history_rep.append({
                    'method': method,
                    'loss': loss.item(),
                    'dts': dts2_torch.detach().cpu().numpy(),
                })

    with open(f"{out_dir}/history_rep.pkl", "wb") as f:
        pickle.dump(history_rep, f)
    with open(f"{out_dir}/sol_rep.pkl", "wb") as f:
        pickle.dump(sol_rep, f)
else:
    # NOTE: pickle.load executes arbitrary code; only load files written by this notebook.
    with open(f"{out_dir}/history_rep.pkl", "rb") as f:
        history_rep = pickle.load(f)
    with open(f"{out_dir}/sol_rep.pkl", "rb") as f:
        sol_rep = pickle.load(f)

In [None]:
def plot_training_res_2(sol, history):
    """Plot the outcome of a training run whose timesteps live in the history.

    Args:
        sol: layer output laid out as states followed by inputs, each taking
            one half of the sequence.
        history: list of per-epoch records; the 'loss' of each is plotted and
            the 'dts' of the last one provides the time grid.
    """
    # The number of timesteps is read from the solution itself, so the plot does
    # not depend on the value the notebook-level `n` happens to hold.
    n_steps = len(sol) // 2

    s_arr = np.array([x.detach().cpu().numpy().tolist() for x in sol[0:n_steps]])
    u_arr = np.array([x.detach().cpu().numpy().tolist() for x in sol[n_steps:2*n_steps]])
    d_arr = history[-1]['dts']

    # Constrained layout is requested at creation, as in the sweep figure above,
    # instead of through the deprecated Figure.set_constrained_layout.
    fig, ax = plt.subplots(2, 3, figsize=(9.6, 9.6), constrained_layout=True)

    ax[0, 0].plot([h['loss'] for h in history])
    ax[0, 0].set_xlabel("Epoch")
    ax[0, 0].set_ylabel("Loss")
    ax[0, 0].set_title("Loss Evolution")

    plot_colored(d_arr, u_arr, ax[1, 0])

    plot_timegrid(
        d_arr, s_arr, ax[0, 1],
        ylabel="State",
        title="State Evolution"
    )

    plot_timegrid(
        d_arr, u_arr, ax[1, 1],
        ylabel="Input",
        title="Input Evolution"
    )

    ax[0, 2].plot(np.cumsum(d_arr), d_arr)
    ax[0, 2].set_xlabel("Time")
    ax[0, 2].set_ylabel("Timestep duration")
    ax[0, 2].set_title("Timesteps Evolution")

    # The sixth panel is unused.
    fig.delaxes(ax[1, 2])

In [None]:
# One summary figure per loss method of the reparametrized formulation.
n = 160
for method in loss_methods_2:
    plot_training_res_2(
        sol_rep[method],
        [record for record in history_rep if record['method'] == method],
    )

## DQP With Exact Discretization

Use the exact discretization of the LTI system with zero-order hold (ZOH).

In [None]:
n = 160

# Per-stage discrete-time matrices, supplied to the QP as parameters.
Aps = [cp.Parameter((n_s, n_s), name=f"Ad_{k}") for k in range(n)]
Bps = [cp.Parameter((n_s, n_u), name=f"Bd_{k}") for k in range(n)]


def create_exact_param_pann_clqr(n: int = N):
    """Build the CLQR problem with parametric discrete-time dynamics.

    Stage `k` evolves as `s[k+1] == Aps[k] @ s[k] + Bps[k] @ u[k]` and the
    stage costs are not weighted by the timestep. The problem is written on
    the module-level `s`, `u`, `Aps` and `Bps`.

    Args:
        n: number of timesteps; `Aps` and `Bps` must hold at least `n` entries.

    Returns:
        The `cp.Problem` with the dynamics constraints followed by the input limits.
    """
    stages = range(n)

    stage_costs = [cp.quad_form(s[k+1], Q) + cp.quad_form(u[k], R) for k in stages]

    dynamics = [s[k+1] == Aps[k] @ s[k] + Bps[k] @ u[k] for k in stages]
    input_limits = [cp.abs(u[k]) <= u_max for k in stages]

    return cp.Problem(cp.Minimize(cp.sum(stage_costs)), dynamics + input_limits)


prob = create_exact_param_pann_clqr(n)
assert prob.is_dpp()

# Differentiable layer mapping the (Ad_k, Bd_k) matrices to the optimal trajectory.
layer = CvxpyLayer(prob, parameters=Aps + Bps, variables=s[1:n+1] + u[:n])

def Ad_Bd_from_dt(dt):
    """Zero-order-hold discretization of (A, B) for a timestep `dt`.

    The matrix exponential of the block matrix [[A*dt, B*dt], [0, 0]] is
    computed (Van Loan's construction); its top-left and top-right blocks are
    the discrete-time state and input matrices.

    Args:
        dt: tensor holding the timestep; the result is built on its device.

    Returns:
        A tuple `(Ad, Bd)` of float32 tensors of shape (n_s, n_s) and (n_s, n_u).
    """
    size = n_s + n_u
    A_th = torch.tensor(A, dtype=torch.float32, device=dt.device)
    B_th = torch.tensor(B, dtype=torch.float32, device=dt.device)

    # Fill the top block row [A*dt, B*dt]; the bottom block row stays zero.
    block = torch.zeros(size, size, dtype=torch.float32, device=dt.device)
    block[:n_s, :n_s] = A_th * dt
    block[:n_s, n_s:] = B_th * dt

    expm = torch.matrix_exp(block)
    Ad = expm[:n_s, :n_s]
    Bd = expm[:n_s, n_s:]
    return Ad, Bd

In [None]:
# Loss variants trained with the exact (ZOH) discretization.
loss_methods_zoh = ["unscaled", "time scaled"]

# Train (rerun=True) or load the pickled results (rerun=False).
# NOTE: `dts_torch` and `method` keep their last value after this cell and are
# only assigned in the rerun=True branch.
if rerun:
    n_epochs_zoh = 300
    history_zoh = []
    sol_zoh = {}

    for method in loss_methods_zoh:
        print(f"ZOH Method: {method}")

        # Horizon used for training. The global `layer` reused below was built
        # for n = 160, so changing this value alone is not enough.
        n = 160

        # Softmax-reparam timesteps: theta -> dt on simplex (uses theta_2_dt)
        theta = torch.nn.Parameter(torch.ones(n, 1))
        optim = torch.optim.Adam([theta], lr=1e-2)

        # Reuse the ZOH layer built in the previous cell (it is not rebuilt here)
        zoh_layer = layer

        with tqdm(total=n_epochs_zoh) as pbar:
            for epoch in range(n_epochs_zoh):
                pbar.update(1)
                optim.zero_grad(set_to_none=True)

                # Current timesteps from parameters (positive, sum to T)
                dts_torch = theta_2_dt(theta)  # shape: (n,)

                # Build (Ad,Bd) lists for the layer call, one pair per timestep
                Ad_list, Bd_list = zip(*[Ad_Bd_from_dt(dt_k) for dt_k in dts_torch])

                # Solve QP via CVXPYLayer forward pass (all Ad first, then all Bd,
                # matching parameters=Aps + Bps)
                sol_zoh[method] = zoh_layer(*Ad_list, *Bd_list)

                # Compute task loss (unscaled or time-scaled)
                loss = task_loss_2(sol_zoh[method], dts_torch, method=method)
                loss.backward()
                optim.step()

                # Log history
                history_zoh.append({
                    "method": method,
                    "epoch": epoch,
                    "loss": float(loss.item()),
                    "dts": dts_torch.detach().cpu().numpy(),  # timesteps used in this forward pass
                })
                
    with open(f"{out_dir}/sol_zoh.pkl", "wb") as f:
        pickle.dump(sol_zoh, f)
    with open(f"{out_dir}/history_zoh.pkl", "wb") as f:
        pickle.dump(history_zoh, f)
else:
    # NOTE(review): pickle.load executes arbitrary code; only load files written by this notebook.
    with open(f"{out_dir}/sol_zoh.pkl", "rb") as f:
        sol_zoh = pickle.load(f)
    with open(f"{out_dir}/history_zoh.pkl", "rb") as f:
        history_zoh = pickle.load(f)

In [None]:
# One summary figure per loss method of the exact-discretization formulation.
n = 160
for method in loss_methods_zoh:
    plot_training_res_2(
        sol_zoh[method],
        [entry for entry in history_zoh if entry['method'] == method],
    )

### Sanity Check

Check that the ZOH discretization is correct.

In [None]:
# Check the last trained ZOH method; it is selected explicitly instead of relying
# on the loop variable left behind by the plotting cell.
method = loss_methods_zoh[-1]

states = [sol_zoh[method][i] for i in range(n)]
inputs = [sol_zoh[method][n+i] for i in range(n)]
# `dts_torch` only exists when rerun=True. The last history record of the method
# stores the timesteps of its final forward pass, and it is available in both
# the rerun=True and the rerun=False (loaded from pickle) case.
deltas = np.asarray([h for h in history_zoh if h["method"] == method][-1]["dts"])

def shoot_dyn(s0, inputs, deltas, dt: float = 0.00001):
    """Simulate the continuous-time dynamics under a piecewise-constant input.

    The system `s' = A s + B u` is integrated with forward-Euler steps of
    length `dt` from time 0 while the time does not exceed `T`. Input `k` is
    held during the k-th interval of the grid defined by `deltas`; after the
    end of the grid the last input is held.

    Args:
        s0: initial state.
        inputs: sequence of inputs (tensors or arrays), one per interval.
        deltas: lengths of the intervals.
        dt: integration step.

    Returns:
        The list of visited states, starting with `s0`.
    """
    # End time of every interval: interval k covers [ends[k-1], ends[k]).
    ends = np.cumsum(deltas)

    # Convert the inputs once, not at every integration step.
    u_np = [
        u_k.detach().cpu().numpy() if torch.is_tensor(u_k) else np.asarray(u_k)
        for u_k in inputs
    ]
    last = min(len(u_np), len(ends)) - 1

    traj = [s0]
    t = 0.0
    k = 0
    while t <= T:
        # Move to the interval containing t: k is the number of interval ends
        # that are <= t (capped at the last interval). Looking up
        # `searchsorted(...) - 1` instead would select the *last* input during
        # the first interval and lag one interval behind afterwards.
        while k < last and t >= ends[k]:
            k += 1

        traj.append(traj[-1] + (A @ traj[-1] + B @ u_np[k]) * dt)

        t = t + dt

    return traj
    
# Integrate the continuous-time dynamics under the optimized inputs and show
# every state component against the integration-step index.
s_hist = shoot_dyn(s0, inputs, deltas)

_, ax_check = plt.subplots()
ax_check.plot(np.array(s_hist))
plt.show()