In [None]:
import os
import pickle
import time

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.notebook import tqdm

from dimp.utils import init_matplotlib, get_colors


# Apply the project's plotting defaults from dimp.utils (presumably style /
# rcParams settings — see dimp.utils.init_matplotlib for what it changes).
init_matplotlib()

# Color palette from dimp.utils. NOTE(review): `colors` is not referenced in
# the cells visible here — confirm it is still needed.
colors = get_colors()

# Interactive (ipympl) figure backend; requires the ipympl extension.
%matplotlib widget

In [None]:
#! Change rerun to True if you want to re-run the experiments and save the results.
# With rerun == False the training cells below skip training and instead
# unpickle previously saved results from `out_dir`.
rerun: bool = False

# Directory holding the pickled solutions / training histories.
out_dir = "data/pann_clqr_dt"
os.makedirs(out_dir, exist_ok=True)  # no-op when it already exists

## Example Pannocchia 🌽

Continuous-time LTI system with dynamics
$$
\dot{s} = A s + B u
$$

3 states and 1 input.

### Define the problem matrices, the initial state, the time horizon, and the LQR matrices.

In [None]:
# Continuous-time LTI dynamics  ds/dt = A s + B u.
A = np.array(
    [
        [-0.1, 0.0, 0.0],
        [0.0, -2.0, -6.25],
        [0.0, 4.0, 0.0],
    ]
)
B = np.array([[0.25], [2.0], [0.0]])

# Problem dimensions, read off the input matrix: 3 states, 1 input.
n_s, n_u = B.shape

s0 = np.array([1.344, -4.585, 5.674])   # initial state

T = 10.0        # time window
N = 1000        # max number of timesteps of the OCP
dt = T / N      # uniform step of the reference discretization

# LQR stage-cost weights.
Q = 1.0 * np.eye(n_s)
R = 0.1 * np.eye(n_u)

u_max = 1.0     # max control input, i.e. |u| <= u_max

### Create the CVXPY Variables and the Optimization Problem

Classic constrained LQR problem, discretized with first-order (forward) Euler:
$$
\begin{align*}
& \min_{\substack{s_{k+1}, u_k \\ k=0, \dots, N-1}} \quad & \Delta t \sum_{k=0}^{N-1} \left( s_{k+1}^T Q s_{k+1} + u_k^T R u_k \right), \\
& \text{s.t.} \quad & s_{k+1} = s_k + \Delta t (A s_k + B u_k), \quad k = 0, \dots, N-1, \\
& & s_0 = s_{\text{init}}, \\
& & u_k \in U = \{ u : \lvert u \rvert \le u_{\max} \}.
\end{align*}
$$

In [None]:
# Decision variables shared by every OCP built in this notebook:
# s[0] is the fixed initial state (a NumPy array, not a variable),
# s[1..N] are the state variables and u[0..N-1] the input variables.
s = [s0]
s.extend(cp.Variable(3, name=f"s_{k}") for k in range(N))
u = [cp.Variable(1, name=f"u_{k}") for k in range(N)]

def create_pann_clqr(n: int = N, dt: float = dt):
    """Forward-Euler constrained LQR over the first ``n`` shared steps.

    Minimizes ``dt * sum_k s_{k+1}^T Q s_{k+1} + dt * sum_k u_k^T R u_k``
    subject to ``s_{k+1} = s_k + dt (A s_k + B u_k)`` and ``|u_k| <= u_max``
    for ``k = 0..n-1``. The default of ``dt`` is bound when this cell runs.

    Returns:
        A ``cp.Problem`` over ``s[1..n]`` and ``u[0..n-1]`` (requires ``n <= N``).
    """
    horizon = range(n)

    state_cost = cp.sum([cp.quad_form(s[k + 1], Q) for k in horizon]) * dt
    input_cost = cp.sum([cp.quad_form(u[k], R) for k in horizon]) * dt

    constraints = [s[k + 1] == s[k] + (A @ s[k] + B @ u[k]) * dt for k in horizon]
    constraints += [cp.abs(u[k]) <= u_max for k in horizon]

    return cp.Problem(cp.Minimize(state_cost + input_cost), constraints)

pann_clqr = create_pann_clqr()
assert pann_clqr.is_dpp()

### Study how the solution changes with the number of timesteps.

With $n=80$ timesteps the system is unstable.
From $n=160$ timesteps on, the OCP stabilizes the system.
However, increasing the number of timesteps further still improves the solution (see the optimal cost value).

In [None]:
# Number of timesteps of the trajectory optimization problem: sweep
# n = interval, 2*interval, ..., <= N over the fixed horizon T.
interval = 80
n_values = list(range(interval, N + 1, interval))

# Each horizon gets a state panel with its input panel directly below;
# two horizons share a pair of rows.
n_cols = 2
n_rows = 2 * ((len(n_values) + n_cols - 1) // n_cols)

fig, axs = plt.subplots(
    n_rows, n_cols,
    figsize=(6.4, 12.8 * n_rows / 12),
    constrained_layout=True,
    squeeze=False,
)

for i, n in enumerate(n_values):
    # Local step size: reassigning the global `dt` here would silently change
    # what later cells (and re-runs of the cell defining create_pann_clqr) see.
    dt_n = T / n
    pann_clqr = create_pann_clqr(n=n, dt=dt_n)

    # perf_counter is the monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()
    pann_clqr.solve()
    solve_time = time.perf_counter() - start_time

    solver_time = pann_clqr.solver_stats.solve_time
    solver_time_str = f"{solver_time:.4f} s" if solver_time is not None else "n/a"
    print(
        f"n = {n}\n"
        f"Optimal cost: {pann_clqr.objective.value:.4f}, "
        f"solve time meas: {solve_time:.4f} s, "
        f"solve time: {solver_time_str}\n"
    )

    # s[k+1] is the state at the END of step k (t = (k+1) dt_n), while u[k]
    # is applied from the START of step k (t = k dt_n).
    state_times = np.arange(1, n + 1) * dt_n
    input_times = np.arange(n) * dt_n
    s_vec = np.array([var.value for var in s[1:n + 1]])
    u_vec = np.array([var.value for var in u[:n]])

    row, col = 2 * (i // n_cols), i % n_cols
    ax_state, ax_input = axs[row, col], axs[row + 1, col]

    ax_state.plot(state_times, s_vec, label=['x', 'y', 'z'])
    ax_state.set(
        xlabel='Time',
        ylabel='State',
        title=fr"$n={n}$",
    )

    ax_input.plot(input_times, u_vec)
    ax_input.set(
        xlabel='Time',
        ylabel='Input',
    )

## DQP with Auxiliary Variables

### Create the Parametrized CLQR Problem

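Discretize the dynamics with forward Euler and a variable timestep $\delta_k$, then linearize the resulting update, which is bilinear in $\delta_k$ and $(s_k, u_k)$: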
$$
s_{k+1} = \bar{s}_k + \bar{\delta}_k (A \bar{s}_k + B \bar{u}_k) + (I + \bar{\delta}_k A) \tilde{s}_k + \bar{\delta}_k B \tilde{u}_k + (A \bar{s}_k + B \bar{u}_k) \tilde{\delta}_k,
$$
where $\tilde{\square} = \square - \bar{\square}$ and $\delta_k$ is an auxiliary variable that represents the time step lengths.

The optimization vector of the QP is $\begin{bmatrix} \tilde{s}_{k=1,\dots,N} & \tilde{u}_{k=0,\dots,N-1} & \tilde{\delta}_{k=0,\dots,N-1}\end{bmatrix}$.

The QP is parametrized by the parameter vector $\bar{\delta} = \begin{bmatrix} \bar{\delta}_0 & \dots & \bar{\delta}_{N-1} \end{bmatrix}^T$.
$$
\begin{align*}
& \min_{\substack{s_{k+1}, u_k, \delta_k \\ k=0, \dots, N}} \quad & \sum_{k=0}^{N} \left(\bar{\delta}_k \left( s_k^T Q s_k + u_k^T R u_k \right) + w_\delta \tilde{\delta}_k^2 \right), \\
& \text{s.t.} \quad & s_{k+1} = \bar{s}_k + \bar{\delta}_k (A \bar{s}_k + B \bar{u}_k) + (I + \bar{\delta}_k A) \tilde{s}_k + \bar{\delta}_k B \tilde{u}_k + (A \bar{s}_k + B \bar{u}_k) \tilde{\delta}_k, \\
& & s_0 = s_{\text{init}}, \\
& & u_k \in U.
\end{align*}
$$

In [None]:
n = 200

# Auxiliary timestep variables delta_k and timestep parameters delta_bar_k,
# one 1-element object per step. Later cells use a prefix of these lists.
deltas = [cp.Variable(1, name=f'deltas_{i}') for i in range(n)]
dts = [cp.Parameter(1, nonneg=True, name=f'dts_{i}') for i in range(n)]

def create_pann_param_clqr(s_bar, u_bar, n=None, w_delta=10**3):
    """Build the linearized CLQR with auxiliary timestep variables.

    The Euler step ``f(s, u, delta) = s + delta (A s + B u)`` is linearized
    around ``(s_bar[k], u_bar[k], dts[k])`` for every step ``k``; ``delta_k``
    is free but penalized towards the parameter ``dts[k]`` and the deltas must
    sum to the horizon ``T``.

    Args:
        s_bar: linearization points; ``s_bar[k]`` pairs with the state ``s_k``
            (so ``s_bar[0]`` should equal the fixed initial state).
        u_bar: linearization points for the inputs ``u_k``.
        n: number of steps; defaults to ``len(dts)``.
        w_delta: weight of ``sum_k (delta_k - dts_k)^2``.

    Returns:
        The DPP ``cp.Problem`` parametrized by ``dts[:n]``.

    Raises:
        ValueError: if ``n`` exceeds the available variables/parameters or
            the number of linearization points.
    """
    if n is None:
        n = len(dts)
    available = min(len(dts), len(deltas), len(u))
    if n > available or len(s_bar) < n or len(u_bar) < n:
        raise ValueError(
            f"n={n} exceeds the available steps ({available}) or linearization "
            f"points (len(s_bar)={len(s_bar)}, len(u_bar)={len(u_bar)})"
        )

    steps = range(n)

    objective = cp.Minimize(
        cp.sum([cp.quad_form(s[i+1], Q) * dts[i] for i in steps])
        + cp.sum([cp.quad_form(u[i], R) * dts[i] for i in steps])
        + w_delta * cp.sum([cp.square(deltas[i] - dts[i]) for i in steps])
    )

    dynamics_constraints = []
    for i in steps:
        # Drift at the linearization point (a NumPy constant).
        drift_bar = A @ s_bar[i] + B @ u_bar[i]
        dynamics_constraints.append(
            s[i+1] == s_bar[i]
            + dts[i] * drift_bar
            + (np.eye(n_s) + dts[i] * A) @ (s[i] - s_bar[i])
            + dts[i] * B @ (u[i] - u_bar[i])
            + (deltas[i] - dts[i]) * drift_bar
        )

    input_limits = [cp.abs(u[i]) <= u_max for i in steps]

    timestep_constraints = [deltas[i] >= 0 for i in steps] + [
        cp.sum(deltas[0:n]) == T
    ]

    constraints = dynamics_constraints + input_limits + timestep_constraints

    return cp.Problem(objective, constraints)

s_bar = [np.zeros(n_s) for _ in range(n)]
u_bar = [np.zeros(n_u) for _ in range(n)]

pann_param_clqr = create_pann_param_clqr(s_bar, u_bar, n)
assert pann_param_clqr.is_dpp()

### Task Loss

Unscaled
$$
\mathcal{L}_1 = \sum_{k=0}^{N} \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

Time scaled
$$
\mathcal{L}_2 = \sum_{k=0}^{N} \delta_k \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

Time bar scaled
$$
\mathcal{L}_3 = \sum_{k=0}^{N} \bar{\delta}_k \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

In [None]:
# loss_methods = ["unscaled", "time scaled", "time bar scaled"]
loss_methods = ["time scaled", "time bar scaled"]

def task_loss(sol, param, method="unscaled"):
    """Task loss evaluated on the solution of the auxiliary-variable QP.

    Args:
        sol: sequence of 3*n tensors returned by the CvxpyLayer, ordered as
            ``s_1..s_n``, ``u_0..u_{n-1}``, ``delta_0..delta_{n-1}``.
        param: sequence of the n timestep parameters ``delta_bar_k`` (each a
            1-element tensor). Its length sets n, so the loss does not depend
            on the notebook-global ``n``.
        method: ``"unscaled"``, ``"time scaled"`` (weights by the QP's
            ``delta_k``) or ``"time bar scaled"`` (weights by ``delta_bar_k``,
            used as constants — no gradient through the weights).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: for an unknown ``method``.
    """
    n_steps = len(param)
    states = [sol[i] for i in range(n_steps)]
    inputs = [sol[n_steps + i] for i in range(n_steps)]
    step_lengths = [sol[2 * n_steps + i] for i in range(n_steps)]
    deltas_bar = np.concatenate([d_bar.detach().numpy() for d_bar in param])

    Q_th = torch.tensor(Q, dtype=torch.float32, device=states[0].device)
    R_th = torch.tensor(R, dtype=torch.float32, device=states[0].device)

    if method == 'unscaled':
        return sum(si @ Q_th @ si for si in states) + sum(
            ui @ R_th @ ui for ui in inputs
        )
    if method == 'time scaled':
        return sum(
            step_lengths[i] * states[i] @ Q_th @ states[i] for i in range(n_steps)
        ) + sum(
            step_lengths[i] * inputs[i] @ R_th @ inputs[i] for i in range(n_steps)
        )
    if method == 'time bar scaled':
        return sum(
            deltas_bar[i] * states[i] @ Q_th @ states[i] for i in range(n_steps)
        ) + sum(
            deltas_bar[i] * inputs[i] @ R_th @ inputs[i] for i in range(n_steps)
        )

    raise ValueError(f"Unknown method {method}")

In [None]:
def plot_timegrid(deltas, x=None, ax=None, ylabel=None, title=None):
    """Mark a non-uniform time grid and optionally plot values on it.

    Args:
        deltas: array-like of step durations (flattened); the grid points are
            their cumulative sums, i.e. the end time of each step.
        x: optional values (one row per step) plotted against those end times.
        ax: Axes to draw on; a new figure is created when None.
        ylabel, title: optional Axes labels.

    Returns:
        The Axes that was drawn on.
    """
    times = np.cumsum(np.asarray(deltas, dtype=float))

    if ax is None:
        _, ax = plt.subplots()
    for t in times:
        ax.axvline(t, color='gray', linestyle='--', alpha=0.25)

    if x is not None:
        ax.plot(times, x)

    ax.set_xlabel("Time")
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)

    return ax

In [None]:
def plot_colored(deltas, x, ax=None):
    """Draw a zero-order-hold staircase of ``x`` on a non-uniform time grid.

    Value ``x[i]`` is held over step ``i``, i.e. on ``[t_i, t_i + deltas[i]]``
    with ``t_0 = 0``; each horizontal segment is colored by its own duration,
    each vertical jump by the duration of the step it leads into.

    Args:
        deltas: array-like of step durations (flattened).
        x: one value (or 1-element row) per step.
        ax: Axes to draw on; a new figure is created when None.

    Returns:
        The Axes that was drawn on.
    """
    deltas = np.asarray(deltas, dtype=float).ravel()
    ends = np.cumsum(deltas)
    starts = np.concatenate(([0.0], ends[:-1]))

    cmap = plt.get_cmap("viridis")
    norm = Normalize(vmin=np.min(deltas), vmax=np.max(deltas))

    if ax is None:
        _, ax = plt.subplots()

    for i in range(len(x)):
        # horizontal hold over step i
        ax.hlines(x[i], starts[i], ends[i],
                  colors=cmap(norm(deltas[i])), linewidth=2)
        # vertical jump into step i+1
        if i + 1 < len(x):
            ax.vlines(ends[i], x[i], x[i + 1],
                      colors=cmap(norm(deltas[i + 1])), linewidth=1)

    ax.set_xlabel("Time")
    ax.set_ylabel("Input")

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = ax.figure.colorbar(sm, ax=ax)
    cbar.set_label("deltas")

    return ax

### Training Loop

In [None]:
def plot_training_res(sol, history, method: int):
    """Plot the loss curve, trajectories and learned timesteps of one run.

    Args:
        sol: CvxpyLayer solution. ``method == 1`` expects 3*n tensors
            (states, inputs, auxiliary deltas); ``method == 2`` expects 2*n
            tensors (states, inputs). n is derived from ``len(sol)``.
        history: list of per-epoch dicts with a ``'loss'`` entry and, for
            ``method == 2``, a ``'dts'`` entry.
        method: 1 -> timesteps read from the auxiliary deltas in ``sol``;
            2 -> timesteps read from ``history[-1]['dts']``.

    Returns:
        The created Figure.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method == 1:
        n_steps = len(sol) // 3
        d_arr = np.concatenate(
            [d.detach().numpy() for d in sol[2 * n_steps:3 * n_steps]]
        ).astype(float)
    elif method == 2:
        n_steps = len(sol) // 2
        d_arr = np.asarray(history[-1]['dts'], dtype=float)
    else:
        raise ValueError(f"Unknown method {method}")

    s_arr = np.stack([t.detach().numpy() for t in sol[:n_steps]]).astype(float)
    u_arr = np.stack([t.detach().numpy() for t in sol[n_steps:2 * n_steps]]).astype(float)

    # constrained_layout at creation replaces the deprecated set_constrained_layout.
    fig, ax = plt.subplots(2, 3, figsize=(9.6, 9.6), constrained_layout=True)

    ax[0, 0].plot([h['loss'] for h in history])
    ax[0, 0].set_xlabel("Epoch")
    ax[0, 0].set_ylabel("Loss")
    ax[0, 0].set_title("Loss Evolution")

    plot_colored(d_arr, u_arr, ax[1, 0])

    plot_timegrid(
        d_arr, s_arr, ax[0, 1],
        ylabel="State",
        title="State Evolution"
    )

    plot_timegrid(
        d_arr, u_arr, ax[1, 1],
        ylabel="Input",
        title="Input Evolution"
    )

    ax[0, 2].plot(np.cumsum(d_arr), d_arr)
    ax[0, 2].set_xlabel("Time")
    ax[0, 2].set_ylabel("Timestep duration")
    ax[0, 2].set_title("Timesteps Evolution")

    # Only five panels are used; drop the sixth.
    fig.delaxes(ax[1, 2])

    return fig

In [None]:
def save_training_res(exp_name: str, sol, history, method: int):
    """Save input, timestep and loss figures of one run as PDF files.

    Files are written to ``out/<exp_name>/input.pdf``, ``timesteps.pdf`` and
    ``loss.pdf``.

    Args:
        exp_name: sub-directory name under ``out/``.
        sol: CvxpyLayer solution; 3*n tensors for ``method == 1``, 2*n for
            ``method == 2`` (n is derived from ``len(sol)``).
        history: per-epoch dicts with ``'loss'`` (and ``'dts'`` for method 2).
        method: 1 -> timesteps from the auxiliary deltas in ``sol``;
            2 -> timesteps from ``history[-1]['dts']``.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method == 1:
        n_steps = len(sol) // 3
        d_arr = np.concatenate(
            [d.detach().numpy() for d in sol[2 * n_steps:3 * n_steps]]
        ).astype(float)
    elif method == 2:
        n_steps = len(sol) // 2
        d_arr = np.asarray(history[-1]['dts'], dtype=float)
    else:
        raise ValueError(f"Unknown method {method}")

    u_arr = np.stack([t.detach().numpy() for t in sol[n_steps:2 * n_steps]]).astype(float)

    exp_dir = f"out/{exp_name}"
    os.makedirs(exp_dir, exist_ok=True)

    def _save_figure(name, draw):
        # Draw on a fresh small figure, save it, and always close it so
        # figures do not accumulate (even if drawing fails).
        fig, ax = plt.subplots(1, 1, figsize=(3.2, 3.2))
        try:
            draw(ax)
            fig.savefig(f"{exp_dir}/{name}.pdf", bbox_inches='tight')
        finally:
            plt.close(fig)

    def _draw_timesteps(ax):
        ax.plot(np.cumsum(d_arr), d_arr)
        ax.set_xlabel("Time")
        ax.set_ylabel("Timestep duration")

    def _draw_loss(ax):
        ax.plot([h['loss'] for h in history])
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")

    _save_figure("input", lambda ax: plot_colored(d_arr, u_arr, ax))
    _save_figure("timesteps", _draw_timesteps)
    _save_figure("loss", _draw_loss)

In [None]:
if rerun:
    n_epochs = 100

    history_aux = []
    sol_aux = {}

    for method in loss_methods:
        print(f"Method: {method}")

        n = 160
        # Uniform initial step; a local name so the global `dt` is not mutated.
        dt_init = T / n

        # One trainable 1-element timestep per step (the layer's delta_bar parameters).
        dts_torch = [torch.nn.Parameter(torch.full((1,), dt_init)) for _ in range(n)]
        optim = torch.optim.Adam(dts_torch, lr=5e-4)

        # Initial linearization point: straight line from s0 towards the origin.
        # s_bar[k] pairs with the state s_k, so s_bar[0] == s0.
        s_bar = [s0 + (np.zeros(n_s) - s0) * i/n for i in range(n)]
        u_bar = [np.zeros(n_u) for _ in range(n)]

        for _ in tqdm(range(n_epochs)):
            optim.zero_grad()

            # s_bar/u_bar enter the QP as constants, so the problem and its
            # layer are rebuilt around the latest solution every epoch.
            pann_param_clqr = create_pann_param_clqr(s_bar, u_bar, n)
            cvxpylayer = CvxpyLayer(
                pann_param_clqr,
                parameters=dts[:n],
                variables=s[1:n+1] + u[:n] + deltas[:n],
            )
            sol_aux[method] = cvxpylayer(*dts_torch)
            sol_epoch = sol_aux[method]

            # sol_epoch[:n] holds s_1..s_n, while constraint k linearizes
            # around s_k: shift by one step and prepend the fixed s_0.
            s_bar = [s0] + [sol_epoch[i].detach().numpy() for i in range(n - 1)]
            u_bar = [sol_epoch[n + i].detach().numpy() for i in range(n)]

            loss = task_loss(sol_epoch, dts_torch, method=method)
            loss.backward()
            optim.step()

            with torch.no_grad():
                # Clamp every step first, then rescale all steps by ONE common
                # factor so that they sum exactly to T.
                # NOTE: the rescale may push a step slightly above the 0.07 cap.
                for d in dts_torch:
                    d.clamp_(min=1e-6, max=0.07)
                scale = T / sum(d.item() for d in dts_torch)
                for d in dts_torch:
                    d.mul_(scale)

            history_aux.append({
                'method': method,
                'loss': loss.item(),
                # clone(): detach().numpy() alone shares storage with the
                # parameter, which later in-place updates would overwrite.
                'dts': [d.detach().clone().numpy() for d in dts_torch],
            })

    with open(f"{out_dir}/sol_aux.pkl", "wb") as f:
        pickle.dump(sol_aux, f)
    with open(f"{out_dir}/history_aux.pkl", "wb") as f:
        pickle.dump(history_aux, f)
else:
    # Results cached by this notebook; only unpickle files you trust.
    with open(f"{out_dir}/sol_aux.pkl", "rb") as f:
        sol_aux = pickle.load(f)
    with open(f"{out_dir}/history_aux.pkl", "rb") as f:
        history_aux = pickle.load(f)

In [None]:
n = 160
for method in loss_methods:
    # method=1: the timesteps are the auxiliary delta variables inside sol.
    runs = [entry for entry in history_aux if entry['method'] == method]
    plot_training_res(sol_aux[method], runs, method=1)
    save_training_res(method.replace(" ", "_"), sol_aux[method], runs, method=1)

## DQP with Reparametrized Parameters

In [None]:
def theta_2_dt(theta, horizon=None, eps=1e-3):
    """Map unconstrained logits to timesteps that are >= eps and sum to the horizon.

    ``dt_k = eps + (horizon - n*eps) * softmax(theta)_k`` with ``n`` the number
    of elements of ``theta`` (not the notebook-global ``n``), so the result
    always sums to ``horizon``. Differentiable w.r.t. ``theta``.

    Args:
        theta: tensor of logits (any shape; it is flattened).
        horizon: total time; defaults to the global horizon ``T``.
        eps: minimum timestep length.

    Returns:
        1-D tensor of ``theta.numel()`` timesteps.

    Raises:
        ValueError: if ``horizon <= n * eps`` (no positive budget left).
    """
    if horizon is None:
        horizon = T

    logits = theta.flatten()
    n_steps = logits.numel()
    budget = horizon - n_steps * eps
    if budget <= 0:
        raise ValueError(
            f"horizon={horizon} too short for {n_steps} steps of at least eps={eps}"
        )

    w = torch.softmax(logits, dim=0)
    return eps + budget * w

### DQP With Reparametrized Timesteps

Reparametrize time steps on the simplex:
$$
\Delta t_k = \epsilon + (T - n \epsilon) \frac{e^{\theta_k}}{\sum_j e^{\theta_j}}
$$
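This enforces both a positive lower bound on every step ($\Delta t_k > \epsilon$, provided $T > n \epsilon$) and the total time constraint $\sum_k \Delta t_k = T$, without discontinuous projection/rescaling updates.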

The optimization vector of the QP is $\begin{bmatrix} \tilde{s}_{k=1,\dots,N} & \tilde{u}_{k=0,\dots,N-1}\end{bmatrix}$.

$$
\begin{align*}
& \min_{\substack{s_{k+1}, u_k \\ k=0, \dots, N}} \quad & \sum_{k=0}^{N} \left(\Delta t_k \left( s_k^T Q s_k + u_k^T R u_k \right)\right), \\
& \text{s.t.} \quad & s_{k+1} = \bar{s}_k + \Delta t_k (A \bar{s}_k + B \bar{u}_k) + (I + \Delta t_k A) \tilde{s}_k + \Delta t_k B \tilde{u}_k, \\
& & s_0 = s_{\text{init}}, \\
& & u_k \in U.
\end{align*}
$$

The OCP uses $\Delta t_k$ as a parameter.
(The optimizer uses $\theta_k$ as optimization parameters.)
Gradients do flow through the softmax function.

In [None]:
n = 160

def create_pann_param_clqr_2(n: int = N):
    """Forward-Euler CLQR whose step durations form one CVXPY parameter.

    Objective: ``sum_k dts[k] * (s_{k+1}^T Q s_{k+1} + u_k^T R u_k)``;
    constraints: ``s_{k+1} = s_k + dts[k] (A s_k + B u_k)`` and
    ``|u_k| <= u_max`` for ``k = 0..n-1`` (uses the shared ``s``/``u``).

    Returns:
        ``(problem, dts)`` where ``dts`` is the nonnegative length-``n``
        parameter to feed to a CvxpyLayer.
    """
    dts = cp.Parameter(n, nonneg=True, name='dts')
    steps = range(n)

    weighted_state_cost = cp.sum([cp.quad_form(s[k + 1], Q) * dts[k] for k in steps])
    weighted_input_cost = cp.sum([cp.quad_form(u[k], R) * dts[k] for k in steps])

    constraints = [s[k + 1] == s[k] + dts[k] * (A @ s[k] + B @ u[k]) for k in steps]
    constraints.extend(cp.abs(u[k]) <= u_max for k in steps)

    problem = cp.Problem(cp.Minimize(weighted_state_cost + weighted_input_cost), constraints)
    return problem, dts

pann_param_clqr_2, _ = create_pann_param_clqr_2(n)
assert pann_param_clqr_2.is_dpp()

In [None]:
# loss_methods_2 = ["unscaled", "time scaled"]
loss_methods_2 = ["time scaled"]

def task_loss_2(sol, param, method="unscaled"):
    """Task loss on the solution of the reparametrized QP (no auxiliary deltas).

    Args:
        sol: sequence of 2*n tensors from the CvxpyLayer, ordered as
            ``s_1..s_n`` followed by ``u_0..u_{n-1}``.
        param: length-n tensor of timesteps ``Delta t_k`` (e.g. from
            ``theta_2_dt``). Its length sets n, so the loss does not depend on
            the notebook-global ``n``.
        method: ``"unscaled"`` or ``"time scaled"``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: for an unknown ``method``.
    """
    n_steps = len(param)
    states = [sol[i] for i in range(n_steps)]
    inputs = [sol[n_steps + i] for i in range(n_steps)]
    step_lengths = param

    Q_th = torch.tensor(Q, dtype=torch.float32, device=states[0].device)
    R_th = torch.tensor(R, dtype=torch.float32, device=states[0].device)

    if method == 'unscaled':
        return sum(si @ Q_th @ si for si in states) + sum(
            ui @ R_th @ ui for ui in inputs
        )
    if method == 'time scaled':
        return sum(
            step_lengths[i] * states[i] @ Q_th @ states[i] for i in range(n_steps)
        ) + sum(
            step_lengths[i] * inputs[i] @ R_th @ inputs[i] for i in range(n_steps)
        )

    raise ValueError(f"Unknown method {method}")

### Task Loss

Unscaled
$$
\mathcal{L}_1 = \sum_{k=0}^{N} \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

Time scaled
$$
\mathcal{L}_2 = \sum_{k=0}^{N} \Delta t_k \left(\lVert s_k \rVert^2_Q  + \lVert u_k \rVert^2_R \right)
$$

In [None]:
if rerun:
    n_epochs = 200

    history_rep = []
    sol_rep = {}

    for method in loss_methods_2:
        print(f"Method: {method}")

        n = 160

        # Equal logits -> uniform initial timesteps via theta_2_dt.
        theta = torch.nn.Parameter(torch.ones(n, 1))
        optim = torch.optim.Adam([theta], lr=1e-2)

        # This QP has no linearization point: it depends on the timesteps only
        # through the `dts2` parameter, so the problem and layer are built once.
        rep_problem, dts2 = create_pann_param_clqr_2(n)
        cvxpylayer = CvxpyLayer(
            rep_problem,
            parameters=[dts2],
            variables=s[1:n+1] + u[:n],
        )

        for _ in tqdm(range(n_epochs)):
            optim.zero_grad(set_to_none=True)
            dts2_torch = theta_2_dt(theta)

            sol_rep[method] = cvxpylayer(dts2_torch)

            loss = task_loss_2(sol_rep[method], dts2_torch, method=method)
            loss.backward()
            optim.step()

            history_rep.append({
                'method': method,
                'loss': loss.item(),
                # dts2_torch is recomputed each epoch (not updated in place),
                # so this array is not overwritten later.
                'dts': dts2_torch.detach().numpy(),
            })

    with open(f"{out_dir}/history_rep.pkl", "wb") as f:
        pickle.dump(history_rep, f)
    with open(f"{out_dir}/sol_rep.pkl", "wb") as f:
        pickle.dump(sol_rep, f)
else:
    # Results cached by this notebook; only unpickle files you trust.
    with open(f"{out_dir}/history_rep.pkl", "rb") as f:
        history_rep = pickle.load(f)
    with open(f"{out_dir}/sol_rep.pkl", "rb") as f:
        sol_rep = pickle.load(f)

In [None]:
n = 160
for method in loss_methods_2:
    # method=2: the timesteps come from the last history entry.
    runs = [entry for entry in history_rep if entry['method'] == method]
    plot_training_res(sol_rep[method], runs, method=2)
    save_training_res(method.replace(" ", "_") + "_rep", sol_rep[method], runs, method=2)

## Loss Hyper Sampling

### LHS - Resampling

In [None]:
def zoh_discretize(dt, A, B):
    """Exact zero-order-hold discretization via one matrix exponential.

    Uses ``expm([[A, B], [0, 0]] * dt) = [[Ad, Bd], [0, I]]``; everything is
    torch, so gradients flow to ``dt`` (and ``A``/``B``) via
    ``torch.matrix_exp``.

    Args:
        dt: scalar torch tensor, the step duration (also fixes the device).
        A: continuous-time state matrix, shape (n_s, n_s).
        B: continuous-time input matrix, shape (n_s, n_u).

    Returns:
        (Ad, Bd) with shapes (n_s, n_s) and (n_s, n_u).
    """
    n_state = A.shape[0]
    n_input = B.shape[1]
    size = n_state + n_input

    block = torch.zeros(size, size, device=dt.device, dtype=A.dtype)
    block[:n_state, :n_state] = A * dt
    block[:n_state, n_state:] = B * dt

    expm = torch.matrix_exp(block)
    Ad = expm[:n_state, :n_state]
    Bd = expm[:n_state, n_state:]
    return Ad, Bd


def uniform_resampling_loss(
    inputs_qp,      # list of n tensors from QP solution, each shape (n_u,)
    dts_torch,      # shape (n,), timesteps summing to T
    s0,             # initial state, shape (n_s,)
    A, B,           # system matrices (torch tensors)
    Q, R,           # cost matrices (torch tensors)
    T,              # total horizon
    n_res=1000,     # number of uniform grid points
    use_exact=False # if True, use matrix_exp; else Euler
):
    """
    Evaluate LQR cost on a dense uniform time grid.

    This loss function:
    1. Creates a uniform grid of n_res points over [0, T)
    2. Interpolates inputs using Zero-Order Hold from QP solution: input k is
       active on [t_{k-1}, t_k), so a grid point exactly on a step boundary t_k
       already uses input k+1 (same convention as evaluate_continuous_cost)
    3. Simulates state forward using Euler (or exact) integration
    4. Computes Riemann sum approximation to continuous cost integral

    Gradients flow through inputs_qp -> QP solution -> dts_torch -> theta.
    The interval indices are computed with detached cumsum (discrete, no grad needed).

    Returns:
        Scalar loss tensor.
    """
    device = dts_torch.device
    dtype = dts_torch.dtype
    n = len(dts_torch)

    # Ensure tensors are on correct device
    A = torch.as_tensor(A, dtype=dtype, device=device)
    B = torch.as_tensor(B, dtype=dtype, device=device)
    Q = torch.as_tensor(Q, dtype=dtype, device=device)
    R = torch.as_tensor(R, dtype=dtype, device=device)
    s0 = torch.as_tensor(s0, dtype=dtype, device=device)

    # Dense uniform time grid
    dt_uniform = T / n_res
    t_uniform = torch.linspace(0, T - dt_uniform, n_res, device=device, dtype=dtype)

    # Cumulative times (end of each interval)
    t_cumsum = torch.cumsum(dts_torch, dim=0)

    # ZOH interpolation: index of the interval containing each grid point.
    # right=True: a point equal to t_cumsum[k] is the START of interval k+1.
    # Detached: indices are discrete, no gradient needed.
    indices = torch.searchsorted(t_cumsum.detach(), t_uniform, right=True)
    indices = torch.clamp(indices, 0, n - 1)

    # Stack inputs and index into them (gradients flow through u_stack)
    u_stack = torch.stack(inputs_qp, dim=0)  # (n, n_u)
    u_interp = u_stack[indices]  # (n_res, n_u)

    # Forward simulation on uniform grid
    s_list = []
    s_current = s0.clone()

    if use_exact:
        # Pre-compute discrete matrices for uniform timestep
        dt_uniform_t = torch.tensor(dt_uniform, device=device, dtype=dtype)
        Ad_uniform, Bd_uniform = zoh_discretize(dt_uniform_t, A, B)
        for j in range(n_res):
            s_list.append(s_current)
            s_current = Ad_uniform @ s_current + Bd_uniform @ u_interp[j]
    else:
        # Euler integration
        for j in range(n_res):
            s_list.append(s_current)
            s_current = s_current + dt_uniform * (A @ s_current + B @ u_interp[j])

    s_stack = torch.stack(s_list, dim=0)  # (n_res, n_s)

    # Compute loss: L = dt_uniform * sum_j (s_j^T Q s_j + u_j^T R u_j)
    state_cost = torch.sum((s_stack @ Q) * s_stack)
    input_cost = torch.sum((u_interp @ R) * u_interp)

    loss = dt_uniform * (state_cost + input_cost)

    return loss

### LHS - Substeps

In [None]:
def substep_loss(
    inputs_qp,
    dts_torch,
    s0,
    A, B,
    Q, R,
    n_sub=10,
    use_exact=False
):
    """LQR cost integrated with ``n_sub`` equal substeps inside every interval.

    Interval ``k`` (duration ``dts_torch[k]``) is split into ``n_sub`` substeps
    of length ``h_k = dts_torch[k] / n_sub``, during which the input
    ``inputs_qp[k]`` is held constant. The state is propagated with explicit
    Euler (or, with ``use_exact=True``, an exact ZOH step per interval) and the
    cost is the left Riemann sum ``sum_j h_j (s_j^T Q s_j + u_j^T R u_j)``.

    Gradients reach the timesteps both through the QP inputs and directly
    through the substep lengths ``h_k``.

    Args:
        inputs_qp: list of n input tensors, each of shape (n_u,).
        dts_torch: tensor of shape (n,) with the interval durations.
        s0: initial state, shape (n_s,).
        A, B: system matrices; Q, R: cost matrices (converted to dts' dtype/device).
        n_sub: number of substeps per interval.
        use_exact: use matrix-exponential steps instead of Euler.

    Returns:
        Scalar loss tensor.
    """
    device, dtype = dts_torch.device, dts_torch.dtype
    n_intervals = len(dts_torch)

    A, B, Q, R, s0 = (
        torch.as_tensor(m, dtype=dtype, device=device) for m in (A, B, Q, R, s0)
    )

    h = dts_torch / n_sub                                                # (n,)
    u_all = torch.stack(inputs_qp, dim=0).repeat_interleave(n_sub, dim=0)  # (n*n_sub, n_u)
    h_all = h.repeat_interleave(n_sub)                                   # (n*n_sub,)

    trajectory = []
    state = s0.clone()

    if use_exact:
        # One exact discretization per interval, reused for all its substeps.
        for k in range(n_intervals):
            Ad_k, Bd_k = zoh_discretize(h[k], A, B)
            for _ in range(n_sub):
                trajectory.append(state)
                state = Ad_k @ state + Bd_k @ inputs_qp[k]
    else:
        # Sequential explicit-Euler steps over all substeps.
        for j in range(n_intervals * n_sub):
            trajectory.append(state)
            state = state + h_all[j] * (A @ state + B @ u_all[j])

    states = torch.stack(trajectory, dim=0)  # (n*n_sub, n_s)

    state_terms = torch.sum((states @ Q) * states, dim=1)
    input_terms = torch.sum((u_all @ R) * u_all, dim=1)
    return torch.sum(h_all * (state_terms + input_terms))

In [None]:
# Verify gradient flow for hyper-sampling losses

def test_gradient_flow(seed: int = 0):
    """Check that gradients reach the timestep logits through both HS losses.

    Timesteps come from ``theta_2_dt`` applied to a small logit vector; the QP
    inputs are replaced by random leaf tensors drawn from a local, seeded
    generator, so the printed numbers are reproducible and torch's global RNG
    is left untouched. For each loss the value and the norm of theta's
    gradient are printed.

    Args:
        seed: seed of the local generator for the stand-in inputs.

    Raises:
        AssertionError: if theta receives no gradient from one of the losses.
    """
    n_test = 20
    theta_test = torch.nn.Parameter(torch.ones(n_test, 1))

    A_torch, B_torch, Q_torch, R_torch, s0_torch = (
        torch.tensor(m, dtype=torch.float32) for m in (A, B, Q, R, s0)
    )

    # Stand-ins for the QP output (leaf tensors that require grad).
    gen = torch.Generator().manual_seed(seed)
    inputs_test = [
        torch.randn(n_u, generator=gen, requires_grad=True) for _ in range(n_test)
    ]

    def _resampling(dts):
        return uniform_resampling_loss(
            inputs_test, dts, s0_torch,
            A_torch, B_torch, Q_torch, R_torch,
            T=T, n_res=100, use_exact=False,
        )

    def _substeps(dts):
        return substep_loss(
            inputs_test, dts, s0_torch,
            A_torch, B_torch, Q_torch, R_torch,
            n_sub=10, use_exact=False,
        )

    approaches = [
        ("Approach 1 (Uniform Resampling - Euler):", _resampling),
        ("\nApproach 2 (Substeps - Euler):", _substeps),
    ]
    for header, loss_fn in approaches:
        # Clear gradients and build a fresh graph for each loss.
        theta_test.grad = None
        for inp in inputs_test:
            inp.grad = None

        loss = loss_fn(theta_2_dt(theta_test))
        loss.backward()

        print(header)
        print(f"  Loss value: {loss.item():.6f}")
        print(f"  theta.grad exists: {theta_test.grad is not None}")
        assert theta_test.grad is not None, f"no gradient reached theta: {header.strip()}"
        print(f"  theta.grad norm: {theta_test.grad.norm().item():.6f}")

test_gradient_flow()

### Training with Hyper-Sampling Losses

In [None]:
loss_methods_hs = ["uniform_resample", "substeps"]

if rerun:
    n_epochs_hs = 400
    n_res_hs = 1000  # for uniform resampling
    n_sub_hs = 10    # for substeps

    history_hs = []
    sol_hs = {}

    # Convert system matrices to torch tensors
    A_torch = torch.tensor(A, dtype=torch.float32)
    B_torch = torch.tensor(B, dtype=torch.float32)
    Q_torch = torch.tensor(Q, dtype=torch.float32)
    R_torch = torch.tensor(R, dtype=torch.float32)
    s0_torch = torch.tensor(s0, dtype=torch.float32)

    for method in loss_methods_hs:
        print(f"Training with method: {method}")

        n = 160

        # Softmax-reparametrized timesteps
        theta = torch.nn.Parameter(torch.ones(n, 1))
        optim = torch.optim.Adam([theta], lr=1e-2)

        # The QP depends on the timesteps only through the `dts2` parameter,
        # so the problem and its layer are built once per method.
        hs_problem, dts2 = create_pann_param_clqr_2(n)
        cvxpylayer = CvxpyLayer(
            hs_problem,
            parameters=[dts2],
            variables=s[1:n+1] + u[:n],
        )

        for epoch in tqdm(range(n_epochs_hs)):
            optim.zero_grad(set_to_none=True)

            # Get timesteps from simplex transformation
            dts_torch = theta_2_dt(theta)

            # Solve QP via CVXPyLayer
            sol_hs[method] = cvxpylayer(dts_torch)

            # Extract inputs from QP solution
            inputs_qp = [sol_hs[method][n+i] for i in range(n)]

            # Compute hyper-sampling loss
            if method == "uniform_resample":
                loss = uniform_resampling_loss(
                    inputs_qp, dts_torch, s0_torch,
                    A_torch, B_torch, Q_torch, R_torch,
                    T=T, n_res=n_res_hs, use_exact=False
                )
            elif method == "substeps":
                loss = substep_loss(
                    inputs_qp, dts_torch, s0_torch,
                    A_torch, B_torch, Q_torch, R_torch,
                    n_sub=n_sub_hs, use_exact=False
                )
            else:
                raise ValueError(f"Unknown method: {method}")

            loss.backward()
            optim.step()

            history_hs.append({
                'method': method,
                'epoch': epoch,
                'loss': loss.item(),
                'dts': dts_torch.detach().cpu().numpy(),
            })

        print(f"  Final loss: {history_hs[-1]['loss']:.6f}\n")

    with open(f"{out_dir}/sol_hs.pkl", "wb") as f:
        pickle.dump(sol_hs, f)
    with open(f"{out_dir}/history_hs.pkl", "wb") as f:
        pickle.dump(history_hs, f)
else:
    # Results cached by this notebook; only unpickle files you trust.
    with open(f"{out_dir}/sol_hs.pkl", "rb") as f:
        sol_hs = pickle.load(f)
    with open(f"{out_dir}/history_hs.pkl", "rb") as f:
        history_hs = pickle.load(f)

In [None]:
# Plot results for hyper-sampling methods
# (method=2: the timesteps are read from the last history entry)
n = 160

for method in loss_methods_hs:
    runs = [entry for entry in history_hs if entry['method'] == method]
    plot_training_res(sol_hs[method], runs, method=2)
    save_training_res(method, sol_hs[method], runs, method=2)

In [None]:
# Comparison plot: uniform resampling vs substeps
fig, axs = plt.subplots(2, 2, figsize=(9.6, 6.4), constrained_layout=True)

# Per-method training history, filtered once.
histories_hs = {
    method: [h for h in history_hs if h['method'] == method]
    for method in loss_methods_hs
}

# Loss curves
for method, history_method in histories_hs.items():
    axs[0, 0].plot([h['loss'] for h in history_method], label=method)
axs[0, 0].set_xlabel("Epoch")
axs[0, 0].set_ylabel("Loss")
axs[0, 0].set_title("Loss Convergence")
axs[0, 0].legend()

# Final timestep distributions
for method, history_method in histories_hs.items():
    d_arr = np.asarray(history_method[-1]['dts']).ravel()
    axs[0, 1].plot(np.cumsum(d_arr), d_arr, label=method)
axs[0, 1].set_xlabel("Time")
axs[0, 1].set_ylabel("Timestep duration")
axs[0, 1].set_title("Final Timestep Distributions")
axs[0, 1].legend()

# Timestep histograms (one bottom-row panel per method; the grid has 2 columns)
for col, (method, history_method) in enumerate(histories_hs.items()):
    d_arr = np.asarray(history_method[-1]['dts']).ravel()
    # Uniform reference for THIS run's step count, not the global `n`.
    dt_uniform = T / len(d_arr)
    ax = axs[1, col]
    ax.hist(d_arr, bins=30, alpha=0.7, edgecolor='black')
    ax.set_xlabel("Timestep duration")
    ax.set_ylabel("Count")
    ax.set_title(f"Histogram: {method}")
    ax.axvline(dt_uniform, color='r', linestyle='--', label=f"uniform={dt_uniform:.4f}")
    ax.legend()

fig.suptitle("Comparison: Uniform Resampling vs Substeps", fontsize=14)
plt.show()

In [None]:
### Evaluate "True" Continuous-Time Cost

def evaluate_continuous_cost(inputs_qp, dts, s0, A, B, Q, R, T, n_eval=10000):
    """
    Evaluate the trajectory on a very dense grid to approximate true continuous cost.
    This is the "ground truth" for comparison.

    Input ``inputs_qp[k]`` is held constant on the k-th interval, whose end time
    is ``cumsum(dts)[k]``. Past the last interval end, the last input is held.
    The state is propagated with forward Euler on the uniform grid
    ``t_j = j * T / n_eval`` (j = 0 .. n_eval-1). The running cost
    ``s^T Q s + u^T R u`` is accumulated with a left-endpoint Riemann sum.

    All arithmetic is done in float64 NumPy. Torch tensors of any float dtype
    (with or without autograd history) and plain array-likes are accepted for
    every argument, so mixed float32/float64 inputs cannot cause dtype errors.

    Args:
        inputs_qp: sequence of per-interval inputs (tensors or array-likes).
        dts: interval durations (tensor or array-like; flattened).
        s0: initial state.
        A, B: continuous-time dynamics matrices (ds/dt = A s + B u).
        Q, R: state and input cost matrices.
        T: horizon length covered by the dense grid.
        n_eval: number of dense grid steps.

    Returns:
        float: the approximate continuous-time cost.

    Raises:
        ValueError: if ``inputs_qp`` is empty or ``n_eval`` is not positive.
    """
    def _as_numpy(x):
        # Detach torch tensors (they may carry autograd history) and move to CPU.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.asarray(x, dtype=np.float64)

    if len(inputs_qp) == 0:
        raise ValueError("inputs_qp must contain at least one input")
    if n_eval <= 0:
        raise ValueError("n_eval must be a positive integer")

    A_np, B_np, Q_np, R_np = (_as_numpy(M) for M in (A, B, Q, R))
    # One row per interval: shape (len(inputs_qp), n_u)
    u_all = np.stack([np.atleast_1d(_as_numpy(u)).ravel() for u in inputs_qp])
    t_cumsum = np.cumsum(_as_numpy(dts).ravel())

    dt_eval = T / n_eval
    t_grid = np.arange(n_eval) * dt_eval
    # Active interval at each grid time (interval k covers [t_cumsum[k-1], t_cumsum[k]));
    # clamp so that times past the last interval end reuse the last input.
    interval_idx = np.minimum(
        np.searchsorted(t_cumsum, t_grid, side='right'), len(u_all) - 1
    )

    s_current = _as_numpy(s0).copy()
    total_cost = 0.0
    for k in interval_idx:
        u_k = u_all[k]
        # Accumulate cost at the left endpoint, then take an Euler step
        total_cost += dt_eval * (float(s_current @ Q_np @ s_current) + float(u_k @ R_np @ u_k))
        s_current = s_current + dt_eval * (A_np @ s_current + B_np @ u_k)

    return total_cost


def _report_true_costs(label, methods, sol_dict, history_list):
    """Print the dense-grid ("true") continuous cost of each method in one family.

    Reads the module-level horizon ``n`` and problem data (s0, A, B, Q, R, T).
    """
    for method in methods:
        history_method = [h for h in history_list if h['method'] == method]
        dts_final = history_method[-1]['dts']
        inputs_qp = [sol_dict[method][n + i] for i in range(n)]
        true_cost = evaluate_continuous_cost(inputs_qp, dts_final, s0, A, B, Q, R, T)
        print(f"{label} ({method}): true_cost = {true_cost:.4f}")


# Compare the two hyper-sampling methods on "true" continuous cost
print("=== True Continuous-Time Cost Comparison ===\n")

for method in loss_methods_hs:
    sol = sol_hs[method]
    history_method = [h for h in history_hs if h['method'] == method]
    dts_final = history_method[-1]['dts']

    inputs_qp = [sol[n+i] for i in range(n)]

    true_cost = evaluate_continuous_cost(inputs_qp, dts_final, s0, A, B, Q, R, T)

    print(f"{method}:")
    print(f"  Training loss (final): {history_method[-1]['loss']:.4f}")
    print(f"  True continuous cost:  {true_cost:.4f}")
    print(f"  dt range: [{np.min(dts_final):.5f}, {np.max(dts_final):.5f}]")
    print(f"  dt std:   {np.std(dts_final):.5f}")
    print()

# Also compare with other method families when their results are in memory
# (e.g. the ZOH section below has not run yet on a fresh top-to-bottom run).
# Failures are reported instead of silently swallowed, so genuine errors stay visible.
print("=== Comparison with Other Methods ===\n")

try:
    _report_true_costs("Reparametrized", loss_methods_2, sol_rep, history_rep)
except Exception as e:
    print(f"Skipping Reparametrized methods: {e}")

try:
    _report_true_costs("ZOH", loss_methods_zoh, sol_zoh, history_zoh)
except Exception as e:
    print(f"Skipping ZOH methods: {e}")


## DQP With ZOH Exact Discretization

Use the exact discretization of the LTI system with zero-order hold (ZOH).

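OCP parameters: the per-interval discrete dynamics $(A_{d, k}, B_{d, k}) = \operatorname{ZOH}(A, B, \Delta t_k)$ are read off the top block row of the matrix exponential
$$
\exp \left( \begin{bmatrix} A & B \\ 0 & 0 \end{bmatrix} \Delta t_k \right)
= \begin{bmatrix} A_{d, k} & B_{d, k} \\ 0 & I \end{bmatrix},
$$
so that $s_{k+1} = A_{d, k} s_k + B_{d, k} u_k$, with $u_k$ held constant over the interval of length $\Delta t_k$.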

### No Cost Scaling With Interval Duration

In [None]:
n = 160  # number of ZOH intervals in the horizon

# Per-interval discrete dynamics (Ad_k, Bd_k); values are supplied at solve time.
Aps = [cp.Parameter(shape=(n_s, n_s), name=f"Ad_{idx}") for idx in range(n)]
Bps = [cp.Parameter(shape=(n_s, n_u), name=f"Bd_{idx}") for idx in range(n)]

def create_exact_param_pann_clqr(n: int = N):
    """Build the ZOH-parametrized constrained LQR problem (no dt cost scaling).

    Stage cost k is s_{k+1}^T Q s_{k+1} + u_k^T R u_k. The dynamics are
    s_{k+1} = Aps[k] s_k + Bps[k] u_k, and the inputs satisfy |u_k| <= u_max.

    Args:
        n: number of intervals to include.

    Returns:
        cp.Problem: the (DPP-compliant) minimization problem.
    """
    stage_costs = [cp.quad_form(s[k + 1], Q) + cp.quad_form(u[k], R) for k in range(n)]
    dynamics = [s[k + 1] == Aps[k] @ s[k] + Bps[k] @ u[k] for k in range(n)]
    input_bounds = [cp.abs(u[k]) <= u_max for k in range(n)]

    return cp.Problem(cp.Minimize(cp.sum(stage_costs)), dynamics + input_bounds)

# Build the problem for the current horizon; CvxpyLayer requires DPP.
prob = create_exact_param_pann_clqr(n)
assert prob.is_dpp()

# Layer inputs: Ad_0..Ad_{n-1}, Bd_0..Bd_{n-1}; outputs: s_1..s_n, then u_0..u_{n-1}.
layer = CvxpyLayer(
    prob,
    parameters=[*Aps, *Bps],
    variables=[*s[1:n + 1], *u[:n]],
)

def Ad_Bd_from_dt(dt):
    """Exact zero-order-hold discretization of (A, B) over a step of length dt.

    Uses the augmented matrix exponential
        expm([[A, B], [0, 0]] * dt) = [[Ad, Bd], [0, I]]
    and reads Ad, Bd off its top block row. Everything is built from torch ops
    (including torch.matrix_exp), so gradients can flow back to dt.

    Args:
        dt: torch tensor holding the step length (its device is reused).

    Returns:
        (Ad, Bd): float32 tensors of shape (n_s, n_s) and (n_s, n_u).
    """
    dev = dt.device
    n_aug = n_s + n_u

    aug = torch.zeros(n_aug, n_aug, dtype=torch.float32, device=dev)
    aug[:n_s, :n_s] = torch.tensor(A, dtype=torch.float32, device=dev) * dt
    aug[:n_s, n_s:] = torch.tensor(B, dtype=torch.float32, device=dev) * dt

    expm = torch.matrix_exp(aug)
    Ad = expm[:n_s, :n_s]
    Bd = expm[:n_s, n_s:]
    return Ad, Bd

In [None]:
# Smoke test: a single epoch of ZOH training (no cost scaling) plus a plot.
# sol_zoh / history_zoh are reassigned again by the next cell (retrained or
# loaded from disk).
# Available loss methods: ["unscaled", "time scaled"]; only "time scaled" is run.
loss_methods_zoh = ["time scaled"]

n_epochs_zoh = 1
history_zoh = []
sol_zoh = {}

for method in loss_methods_zoh:
    print(f"ZOH Method: {method}")

    # Training horizon. `layer` was built for a fixed number of intervals
    # (len(Aps)); changing n requires rebuilding Aps/Bps/layer first.
    n = 160
    assert n == len(Aps), f"ZOH layer expects {len(Aps)} intervals, got n = {n}"

    # Softmax-reparametrized timesteps: theta -> dt via theta_2_dt
    theta = torch.nn.Parameter(torch.ones(n, 1))
    optim = torch.optim.Adam([theta], lr=1e-2)

    # Reuse the differentiable ZOH layer built above (it is not rebuilt here)
    zoh_layer = layer

    with tqdm(total=n_epochs_zoh) as pbar:
        for epoch in range(n_epochs_zoh):
            pbar.update(1)
            optim.zero_grad(set_to_none=True)

            # Current timesteps from parameters (positive, sum to T)
            dts_torch = theta_2_dt(theta)  # one duration per interval

            # Exact ZOH (Ad_k, Bd_k) for each interval, as layer inputs
            Ad_list, Bd_list = zip(*[Ad_Bd_from_dt(dt_k) for dt_k in dts_torch])

            # Solve QP via CVXPYLayer forward pass
            sol_zoh[method] = zoh_layer(*Ad_list, *Bd_list)

            # Compute task loss (unscaled or time-scaled)
            loss = task_loss_2(sol_zoh[method], dts_torch, method=method)
            loss.backward()
            optim.step()

            # Log history
            history_zoh.append({
                "method": method,
                "epoch": epoch,
                "loss": float(loss.item()),
                "dts": dts_torch.detach().cpu().numpy(),  # keep full grid
            })

n = 160
for method in loss_methods_zoh:
    plot_training_res(sol_zoh[method], [h for h in history_zoh if h['method'] == method], method=2)

In [None]:
# Full ZOH training run (no cost scaling). With rerun=False, the cached results
# are loaded from out_dir instead.
# Available loss methods: ["unscaled", "time scaled"]; only "time scaled" is used.
loss_methods_zoh = ["time scaled"]

if rerun:
    n_epochs_zoh = 300
    history_zoh = []
    sol_zoh = {}

    for method in loss_methods_zoh:
        print(f"ZOH Method: {method}")

        # Training horizon. `layer` was built for a fixed number of intervals
        # (len(Aps)); changing n requires rebuilding Aps/Bps/layer first.
        n = 160
        assert n == len(Aps), f"ZOH layer expects {len(Aps)} intervals, got n = {n}"

        # Softmax-reparametrized timesteps: theta -> dt via theta_2_dt
        theta = torch.nn.Parameter(torch.ones(n, 1))
        optim = torch.optim.Adam([theta], lr=1e-2)

        # Reuse the differentiable ZOH layer built above (it is not rebuilt here)
        zoh_layer = layer

        with tqdm(total=n_epochs_zoh) as pbar:
            for epoch in range(n_epochs_zoh):
                pbar.update(1)
                optim.zero_grad(set_to_none=True)

                # Current timesteps from parameters (positive, sum to T)
                dts_torch = theta_2_dt(theta)  # one duration per interval

                # Exact ZOH (Ad_k, Bd_k) for each interval, as layer inputs
                Ad_list, Bd_list = zip(*[Ad_Bd_from_dt(dt_k) for dt_k in dts_torch])

                # Solve QP via CVXPYLayer forward pass
                sol_zoh[method] = zoh_layer(*Ad_list, *Bd_list)

                # Compute task loss (unscaled or time-scaled)
                loss = task_loss_2(sol_zoh[method], dts_torch, method=method)
                loss.backward()
                optim.step()

                # Log history
                history_zoh.append({
                    "method": method,
                    "epoch": epoch,
                    "loss": float(loss.item()),
                    "dts": dts_torch.detach().cpu().numpy(),  # keep full grid
                })

    with open(f"{out_dir}/sol_zoh.pkl", "wb") as f:
        pickle.dump(sol_zoh, f)
    with open(f"{out_dir}/history_zoh.pkl", "wb") as f:
        pickle.dump(history_zoh, f)
else:
    # Cached results written by this notebook (trusted local pickles)
    with open(f"{out_dir}/sol_zoh.pkl", "rb") as f:
        sol_zoh = pickle.load(f)
    with open(f"{out_dir}/history_zoh.pkl", "rb") as f:
        history_zoh = pickle.load(f)

In [None]:
n = 160  # horizon; presumably read as a global by the helpers below

for method in loss_methods_zoh:
    method_history = [entry for entry in history_zoh if entry['method'] == method]
    plot_training_res(sol_zoh[method], method_history, method=2)
    save_training_res(
        method.replace(" ", "_") + "_zoh",
        sol_zoh[method],
        method_history,
        method=2,
    )

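#### Sanity Check: ZOH Discretization

Check that the ZOH discretization is correct by re-simulating the continuous dynamics with the optimized inputs and timesteps (the check below is currently commented out).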

In [None]:
# NOTE(review): disabled sanity check, kept for reference only.
# If re-enabled, the interval lookup inside shoot_dyn is off by one:
# `times = np.cumsum(deltas)` holds interval *end* times, so the active
# interval at `time` is `np.searchsorted(times, time, side="right")`
# (no `- 1`; clamp to len(inputs) - 1). As written, every time before
# times[0] selects inputs[-1] (the last input), and each later interval is
# shifted by one. Also, dt = 1e-5 over the horizon T means roughly T / 1e-5
# Python-level steps, which is slow.
# states = [sol_zoh[method][i] for i in range(n)]
# inputs = [sol_zoh[method][n+i] for i in range(n)]
# deltas = dts_torch.detach().cpu().numpy()

# def shoot_dyn(s0, inputs, deltas):
#     dt = 0.00001
    
#     time = 0.0
    
#     times = np.cumsum(deltas)
    
#     s = [s0]
    
#     while time <= T:
#         i = np.searchsorted(times, time, side="right") - 1

#         s.append(s[-1] + (A @ s[-1] + B @ inputs[i].detach().cpu().numpy()) * dt)

#         time = time + dt
        
#     return s
    
# s_hist = shoot_dyn(s0, inputs, deltas)

# plt.figure()
# plt.plot(np.array(s_hist))
# plt.show()

### With Cost Scaling With Interval Duration

In [None]:
n = 160  # number of ZOH intervals in the horizon

# Per-interval discrete dynamics (Ad_k, Bd_k); values are supplied at solve time.
Aps = [cp.Parameter(shape=(n_s, n_s), name=f"Ad_{idx}") for idx in range(n)]
Bps = [cp.Parameter(shape=(n_s, n_u), name=f"Bd_{idx}") for idx in range(n)]

# Per-interval cost factors: stage cost k is ||LQs[k] s_{k+1}||^2 + ||LRs[k] u_k||^2.
# NOTE(review): these receive sqrt(dt)-scaled Cholesky factors (LQs_LRs_from_dt),
# which are triangular rather than symmetric unless Q and R are diagonal, so
# PSD=True is only accurate for diagonal weights. Confirm how cvxpy/cvxpylayers
# treat the PSD attribute before using non-diagonal Q or R.
LQs = [cp.Parameter(shape=(n_s, n_s), PSD=True, name=f"Q_{idx}") for idx in range(n)]
LRs = [cp.Parameter(shape=(n_u, n_u), PSD=True, name=f"R_{idx}") for idx in range(n)]

def create_exact_param_pann_clqr_2(n: int = N):
    """Build the ZOH-parametrized constrained LQR problem with per-interval cost factors.

    Stage cost k is ||LQs[k] s_{k+1}||^2 + ||LRs[k] u_k||^2. The dynamics are
    s_{k+1} = Aps[k] s_k + Bps[k] u_k, and the inputs satisfy |u_k| <= u_max.

    Args:
        n: number of intervals to include.

    Returns:
        cp.Problem: the (DPP-compliant) minimization problem.
    """
    stage_costs = [
        cp.sum_squares(LQs[k] @ s[k + 1]) + cp.sum_squares(LRs[k] @ u[k])
        for k in range(n)
    ]
    dynamics = [s[k + 1] == Aps[k] @ s[k] + Bps[k] @ u[k] for k in range(n)]
    input_bounds = [cp.abs(u[k]) <= u_max for k in range(n)]

    return cp.Problem(cp.Minimize(cp.sum(stage_costs)), dynamics + input_bounds)

# Build the cost-scaled problem for the current horizon; CvxpyLayer requires DPP.
prob_2 = create_exact_param_pann_clqr_2(n)
assert prob_2.is_dpp()

# Layer inputs: Ad_*, Bd_*, LQ_*, LR_* (in that order); outputs: s_1..s_n, then u_0..u_{n-1}.
# NOTE: this rebinds `layer`, replacing the unscaled ZOH layer built earlier.
layer = CvxpyLayer(
    prob_2,
    parameters=[*Aps, *Bps, *LQs, *LRs],
    variables=[*s[1:n + 1], *u[:n]],
)

def LQs_LRs_from_dt(dts, Q_mat=None, R_mat=None):
    """Per-interval cost factors scaled by the interval duration.

    For each dt_k, returns ``LQ_k = sqrt(dt_k) * U_Q`` and ``LR_k = sqrt(dt_k) * U_R``,
    where ``U`` is the upper-triangular Cholesky factor (``M = U^T U``). Then
    ``||LQ_k s||^2 = dt_k * s^T Q s`` and ``||LR_k u||^2 = dt_k * u^T R u``,
    matching the ``cp.sum_squares(LQs[k] @ s[k+1])`` form of the objective.

    Args:
        dts: iterable of interval durations (torch tensors).
        Q_mat, R_mat: optional weight matrices. They default to the
            module-level ``Q`` and ``R``.

    Returns:
        (LQs, LRs): lists of float32 tensors, one per interval.
    """
    Q_t = torch.as_tensor(Q if Q_mat is None else Q_mat, dtype=torch.float32)
    R_t = torch.as_tensor(R if R_mat is None else R_mat, dtype=torch.float32)
    # Use the upper factor U (M = U^T U). The lower factor L (M = L L^T) would give
    # ||L x||^2 = x^T L^T L x, which differs from x^T M x unless M is diagonal.
    UQ = torch.linalg.cholesky(Q_t).T  # constant
    UR = torch.linalg.cholesky(R_t).T  # constant
    LQs = [torch.sqrt(dt) * UQ for dt in dts]
    LRs = [torch.sqrt(dt) * UR for dt in dts]
    return LQs, LRs

In [None]:
# ZOH training with dt-scaled stage costs. With rerun=False, the cached results
# are loaded from out_dir instead. Results are then plotted and saved.
loss_methods_zoh_2 = ["time scaled"]

if rerun:
    n_epochs_zoh = 500
    history_zoh_2 = []
    sol_zoh_2 = {}

    for method in loss_methods_zoh_2:
        print(f"ZOH Method: {method}")

        # Training horizon. `layer` was built for a fixed number of intervals
        # (len(Aps)); changing n requires rebuilding the parameters and layer first.
        n = 160
        assert n == len(Aps), f"ZOH layer expects {len(Aps)} intervals, got n = {n}"

        # Softmax-reparametrized timesteps: theta -> dt via theta_2_dt
        theta = torch.nn.Parameter(torch.ones(n, 1))
        optim = torch.optim.Adam([theta], lr=1e-2)

        # Reuse the cost-scaled ZOH layer built above (it is not rebuilt here)
        zoh_layer = layer

        with tqdm(total=n_epochs_zoh) as pbar:
            for epoch in range(n_epochs_zoh):
                pbar.update(1)
                optim.zero_grad(set_to_none=True)

                # Current timesteps from parameters (positive, sum to T)
                dts_torch = theta_2_dt(theta)  # one duration per interval

                # Exact ZOH (Ad_k, Bd_k) and sqrt(dt_k)-scaled cost factors per interval
                Ad_list, Bd_list = zip(*[Ad_Bd_from_dt(dt_k) for dt_k in dts_torch])
                LQs_list, LRs_list = LQs_LRs_from_dt(dts_torch)

                # Solve QP via CVXPYLayer forward pass
                sol_zoh_2[method] = zoh_layer(*Ad_list, *Bd_list, *LQs_list, *LRs_list)

                # Compute task loss (unscaled or time-scaled)
                loss = task_loss_2(sol_zoh_2[method], dts_torch, method=method)
                loss.backward()
                optim.step()

                # Log history
                history_zoh_2.append({
                    "method": method,
                    "epoch": epoch,
                    "loss": float(loss.item()),
                    "dts": dts_torch.detach().cpu().numpy(),  # keep full grid
                })

    with open(f"{out_dir}/sol_zoh_2.pkl", "wb") as f:
        pickle.dump(sol_zoh_2, f)
    with open(f"{out_dir}/history_zoh_2.pkl", "wb") as f:
        pickle.dump(history_zoh_2, f)
else:
    # Cached results written by this notebook (trusted local pickles)
    with open(f"{out_dir}/sol_zoh_2.pkl", "rb") as f:
        sol_zoh_2 = pickle.load(f)
    with open(f"{out_dir}/history_zoh_2.pkl", "rb") as f:
        history_zoh_2 = pickle.load(f)
n = 160
for method in loss_methods_zoh_2:
    plot_training_res(sol_zoh_2[method], [h for h in history_zoh_2 if h['method'] == method], method=2)
    save_training_res(
        method.replace(" ", "_") + "_zoh_2",
        sol_zoh_2[method],
        [h for h in history_zoh_2 if h['method'] == method],
        method=2,
    )

## Sampling Density and Trajectory Change Analysis

In [None]:
def extract_trajectory_data(ms, n):
    """Unpack one method's solution into NumPy arrays.

    Args:
        ms: dict with keys
            'sol': sequence of tensors. The first n entries are states, the
                next n are inputs, and for sol_method 1 another n hold the timesteps.
            'history': list of dicts; the last entry's 'dts' is used when sol_method != 1.
            'sol_method' (optional, default 2): where the timesteps come from.
        n: horizon (number of intervals).

    Returns:
        dict with 's' (stacked states, one row per step), 'u' (flattened inputs),
        'dts' (flattened timesteps) and 'times' (cumulative sum of dts).
    """
    sol = ms['sol']
    history = ms['history']
    sol_method = ms.get('sol_method', 2)

    states = np.array([sol[k].detach().numpy() for k in range(n)])
    inputs = np.array([sol[n + k].detach().numpy() for k in range(n)]).flatten()

    if sol_method != 1:
        step_sizes = np.array(history[-1]['dts']).flatten()
    else:
        # Timesteps are decision variables stored after the states and inputs
        step_sizes = np.concatenate([sol[2 * n + k].detach().numpy() for k in range(n)]).flatten()

    return {
        's': states,
        'u': inputs,
        'dts': step_sizes,
        'times': np.cumsum(step_sizes),
    }


def compute_trajectory_metrics(data, n):
    """Derive per-step metrics that relate the timestep grid to the trajectory.

    Args:
        data: dict from extract_trajectory_data ('s', 'u' and 'dts' are used).
        n: horizon; the reference uniform step is T / n.

    Returns:
        dict with
          'sampling_density': (1 / dt_k) * (T / n), i.e. 1 on the uniform grid
              and > 1 where sampling is finer,
          'abs_u': |u_k|,
          'delta_u': |u_{k+1} - u_k| (one fewer entry than u),
          'norm_s': ||s_k||_2 per row of s,
          'delta_s': ||s_{k+1} - s_k||_2 (one fewer entry than s).
    """
    states = data['s']
    inputs = data['u']
    step_sizes = data['dts']
    uniform_step = T / n

    metrics = {}
    metrics['sampling_density'] = (1.0 / step_sizes) * uniform_step
    metrics['abs_u'] = np.abs(inputs)
    metrics['delta_u'] = np.abs(np.diff(inputs))
    metrics['norm_s'] = np.linalg.norm(states, axis=1)
    metrics['delta_s'] = np.linalg.norm(np.diff(states, axis=0), axis=1)
    return metrics

In [None]:
def plot_density_and_changes(data, metrics, method_name, colors, axes=None):
    """Overlay sampling density, |Delta u| and ||Delta s||_2 against time on one axis.

    Args:
        data: dict from extract_trajectory_data ('times' is used as the x-axis).
        metrics: dict from compute_trajectory_metrics.
        method_name: used as the axis title.
        colors: color sequence; the first three entries are used.
        axes: optional matplotlib Axes to draw into. When omitted, a new
            3.2 x 2.4 figure is created.

    Returns:
        The Axes that was drawn into.
    """
    times = data['times']
    ax = axes if axes is not None else plt.subplots(figsize=(3.2, 2.4))[1]

    ax.plot(times, metrics['sampling_density'], label=r'Sampling density', color=colors[0])
    # Differences have one fewer sample; plot them at the first n-1 time stamps.
    ax.plot(times[:-1], metrics['delta_u'], label=r'$|\Delta u|$', color=colors[1])
    ax.plot(times[:-1], metrics['delta_s'], label=r'$\|\Delta s\|_2$', color=colors[2])
    # Density 1 corresponds to the uniform step T / n
    ax.axhline(1.0, color='gray', linestyle=':', alpha=0.5)
    ax.set(xlabel='Time', ylabel='Value', title=method_name)
    ax.legend(loc='upper right', fontsize=7)
    return ax

In [None]:
def compute_cross_correlation(x, y, max_lag=20):
    """Normalized cross-correlation. Positive lag: x[k] vs y[k+lag].

    Both series are standardized over their full length. The coefficient at each
    lag is the mean product over the overlapping part of the first
    m = min(len(x), len(y)) samples. ``max_lag`` is clamped to m - 1, so the
    returned ``lags`` and ``ccf`` always have the same length.

    Args:
        x, y: 1-D array-likes.
        max_lag: largest absolute lag to evaluate.

    Returns:
        (lags, ccf): integer lags and the matching correlation coefficients.
        Both arrays are empty when either input is empty.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    m = min(len(x), len(y))
    if m == 0:
        return np.arange(0), np.array([])

    x = (x - np.mean(x)) / (np.std(x) + 1e-10)
    y = (y - np.mean(y)) / (np.std(y) + 1e-10)

    # Clamp so that every lag has at least one overlapping sample
    max_lag = max(0, min(max_lag, m - 1))
    lags = np.arange(-max_lag, max_lag + 1)
    ccf = np.array([
        np.mean(x[:m-lag] * y[lag:m]) if lag >= 0 else np.mean(x[-lag:m] * y[:m+lag])
        for lag in lags
    ])
    return lags, ccf


def plot_cross_correlations(data, metrics, method_name, colors, max_lag=30):
    """Plot the lagged cross-correlation of sampling density with trajectory metrics.

    Draws one panel per quantity (|Delta u|, ||Delta s||_2, |u|, ||s||_2 and time t).
    Dotted lines mark +-1.96 / sqrt(N), where N is the number of density samples.
    Each panel is annotated with its largest-magnitude correlation and the
    lag where it occurs.

    Returns:
        The matplotlib Figure.
    """
    density = metrics['sampling_density']
    band = 1.96 / np.sqrt(len(density))
    # Time is added as an extra "metric" so it can be correlated like the others
    lookup = dict(metrics, time=data['times'])

    # (metric key, panel label, density series aligned with that metric)
    panels = [
        ('delta_u', r'$|\Delta u|$', density[:-1]),
        ('delta_s', r'$\|\Delta s\|_2$', density[:-1]),
        ('abs_u', r'$|u|$', density),
        ('norm_s', r'$\|s\|_2$', density),
        ('time', r'$t$', density),
    ]

    fig, axs = plt.subplots(1, len(panels), figsize=(2.8 * len(panels), 2.8))

    for ax, (key, label, density_aligned) in zip(axs, panels):
        values = lookup[key][:len(density_aligned)]
        lags, ccf = compute_cross_correlation(density_aligned, values, max_lag=max_lag)

        ax.plot(lags, ccf, color=colors[0], marker='o', markersize=2)
        ax.axhline(0, color='gray', linestyle='-', alpha=0.3)
        ax.axvline(0, color='gray', linestyle='--', alpha=0.3)
        for bound in (band, -band):
            ax.axhline(bound, color=colors[5], linestyle=':', alpha=0.5)
        ax.set(xlabel='Lag', ylabel='CCF', title=f'SD vs {label}')

        best = np.argmax(np.abs(ccf))
        best_lag = lags[best]
        best_val = ccf[best]
        # Place the label below a positive peak and above a negative one
        ax.annotate(
            f'{best_val:.2f} @ {best_lag}',
            xy=(best_lag, best_val),
            xytext=(0, -25 if best_val > 0 else 15),
            textcoords='offset points',
            fontsize=8,
            ha='center',
            bbox=dict(boxstyle='round,pad=0.2', fc='white', alpha=0.7),
        )

    fig.suptitle(f'{method_name}', fontsize=11)
    return fig

### Collect All Methods

In [None]:
n = 160
method_solutions = {}


def _lookup_global(name):
    """Return the notebook-level variable `name`, raising NameError if it is undefined."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(f"name '{name}' is not defined") from None


# (prefix, methods-list name, solutions-dict name, history-list name, sol_method)
# sol_method 1: timesteps are part of the solution; 2: taken from the history.
# Variables are referenced by name and resolved inside the try block. A section
# that has not been run is then reported and skipped, instead of failing the
# whole cell while this list is being built.
method_configs = [
    ("Aux", "loss_methods", "sol_aux", "history_aux", 1),
    ("Rep", "loss_methods_2", "sol_rep", "history_rep", 2),
    ("HS", "loss_methods_hs", "sol_hs", "history_hs", 2),
    ("ZOH", "loss_methods_zoh", "sol_zoh", "history_zoh", 2),
    ("ZOH2", "loss_methods_zoh_2", "sol_zoh_2", "history_zoh_2", 2),
]

for prefix, methods_name, sol_name, history_name, sol_method in method_configs:
    try:
        methods = _lookup_global(methods_name)
        sol_dict = _lookup_global(sol_name)
        history_list = _lookup_global(history_name)
        for method in methods:
            key = f"{prefix}: {method}"
            method_solutions[key] = {
                'sol': sol_dict[method],
                'history': [h for h in history_list if h['method'] == method],
                'sol_method': sol_method,
            }
    except Exception as e:
        print(f"Could not load {prefix} methods: {e}")

print(f"Loaded {len(method_solutions)} methods: {list(method_solutions.keys())}")

### Sampling Density vs Trajectory Changes

In [None]:
# Grid of per-method density/change plots, two panels per row.
n_methods = len(method_solutions)
if n_methods == 0:
    print("No methods loaded; nothing to plot.")
else:
    n_rows = int(np.ceil(n_methods / 2))
    fig, axs = plt.subplots(n_rows, 2, figsize=(10, 2.5 * n_rows), squeeze=False)

    for i, (key, ms) in enumerate(method_solutions.items()):
        data = extract_trajectory_data(ms, n)
        metrics = compute_trajectory_metrics(data, n)
        plot_density_and_changes(data, metrics, key, colors, axes=axs[i // 2, i % 2])

    # Remove the empty trailing panel when the number of methods is odd. This is
    # based on n_methods rather than the loop variable, which may be stale from
    # another cell.
    for j in range(n_methods, n_rows * 2):
        fig.delaxes(axs[j // 2, j % 2])

### Cross-Correlation (Time-Lagged)

Cross-correlation of the sampling density (SD) with:
- $| \Delta u |$
- $\| \Delta s \|_2$
- $| u |$
- $\| s \|_2$
- $t$ (time)

The y-axis shows the cross-correlation coefficient.

The box marks the largest-magnitude cross-correlation and the lag at which it occurs.

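The dotted lines at $\pm 1.96 / \sqrt{n} \approx \pm 0.155$ (for $n = 160$) bound the approximate 95% confidence interval for zero correlation, assuming uncorrelated (white-noise) series. Values outside these bounds indicate a statistically significant correlation.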

- **lag = 0**: Instantaneous correlation
- **lag > 0**: Does high sampling density *precede* large changes?
- **lag < 0**: Does high sampling density *follow* large changes?

In [None]:
# One cross-correlation figure per loaded method (SD vs each trajectory metric).
for key in method_solutions:
    ms = method_solutions[key]
    data = extract_trajectory_data(ms, n)
    metrics = compute_trajectory_metrics(data, n)
    plot_cross_correlations(data, metrics, method_name=key, colors=colors, max_lag=30)