In [None]:
import time

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
import matplotlib.pyplot as plt
import numpy as np
import torch

from dimp.utils import init_matplotlib


# Project helper imported from dimp.utils above; presumably applies shared
# matplotlib styling (confirm in dimp.utils). Called once, before any figure
# in this notebook is created.
init_matplotlib()

In [None]:
# System matrices of x' = A x + B u (3 states, 1 input); the problems below
# discretize this with explicit Euler steps: x_{k+1} = x_k + (A x_k + B u_k) dt.
A = np.array([[-0.1, 0.0, 0.0],
              [0.0, -2.0, -6.25],
              [0.0, 4.0, 0.0]])
B = np.array([0.25, 2.0, 0.0]).reshape(3, 1)

# Fixed initial state x_0.
x0 = np.array([1.344, -4.585, 5.674])

# Horizon length T split into N uniform steps of length dt.
T, N = 10.0, 1000
dt = T / N

# Quadratic cost weights on the state and on the input.
Q = np.eye(3) * 1.0
R = np.eye(1) * 0.1

In [None]:
# Decision variables shared by the problems built below: x[0] is the fixed
# initial state (a numpy constant); x[1..N] and u[0..N-1] are cvxpy variables.
x = [x0] + [cp.Variable(3, name=f"x_{i}") for i in range(N)]
u = [cp.Variable(1, name=f"u_{i}") for i in range(N)]


def create_clqr(n: int = N, dt: float = dt, u_max: float = 1.0):
    """Build the input-constrained LQR problem over the first ``n`` steps.

    Dynamics are discretized with explicit Euler steps of length ``dt``:
    x_{i+1} = x_i + (A x_i + B u_i) dt. The objective is
    sum_i (x_{i+1}^T Q x_{i+1} + u_i^T R u_i) * dt, subject to |u_i| <= u_max.

    Args:
        n: number of steps; must satisfy 0 <= n <= len(u).
        dt: step length. The default is the value of ``dt`` at the time this
            cell runs (Python binds defaults at definition time).
        u_max: elementwise bound on the input magnitude.

    Returns:
        The unsolved ``cp.Problem``.

    Raises:
        ValueError: if ``n`` is outside the range of pre-allocated variables.
    """
    # Fail with a clear message instead of an IndexError deep in the list
    # comprehensions below.
    if not 0 <= n <= len(u):
        raise ValueError(f"n must be between 0 and {len(u)}, got {n}")

    objective = cp.Minimize(
        cp.sum([cp.quad_form(x[i + 1], Q) for i in range(n)]) * dt
        + cp.sum([cp.quad_form(u[i], R) for i in range(n)]) * dt
    )

    dynamics_constraints = [
        x[i + 1] == x[i] + (A @ x[i] + B @ u[i]) * dt for i in range(n)
    ]

    input_limits = [cp.abs(u[i]) <= u_max for i in range(n)]

    return cp.Problem(objective, dynamics_constraints + input_limits)


big_clqr = create_clqr()
assert big_clqr.is_dpp()

In [None]:
# Solve for horizons n = interval, 2*interval, ... up to the largest multiple
# of `interval` not exceeding N (the horizon T stays fixed, so the step shrinks).
interval = 80

for n in range(interval, N + 1, interval):
    # Loop-local step length: do not overwrite the notebook-level `dt`, which
    # create_clqr binds as its default whenever its cell is (re-)run.
    step = T / n
    big_clqr = create_clqr(n=n, dt=step)

    start_time = time.perf_counter()
    big_clqr.solve()
    solve_time = time.perf_counter() - start_time

    # Some solvers do not report a solve time (None), which would break `:.4f`.
    solver_time = big_clqr.solver_stats.solve_time
    solver_time_str = f"{solver_time:.4f} s" if solver_time is not None else "n/a"

    print(f"Optimal cost for n={n}: {big_clqr.objective.value:.4f}")
    print(f"Solve time meas for n={n}: {solve_time} s")
    print(f"Solve time for n={n}: {solver_time_str}")
    print()

    # Only plot the shorter horizons to keep the output readable.
    if n > 5 * interval:
        continue

    # x[i+1] is the state at t = (i+1)*step; u[i] is applied starting at t = i*step.
    state_times = np.arange(1, n + 1) * step
    input_times = np.arange(n) * step
    x_vec = np.array([xi.value for xi in x[1:n + 1]])
    u_vec = np.array([ui.value for ui in u[:n]])

    fig, axs = plt.subplots(2, 1)
    fig.suptitle(f"Constrained LQR, n={n}")

    axs[0].plot(state_times, x_vec, label=['x', 'y', 'z'])
    axs[0].set(xlabel='Time', ylabel='State')
    axs[0].legend()

    axs[1].plot(input_times, u_vec)
    axs[1].set(xlabel='Time', ylabel='Control Input')

    # Render now so each figure appears next to its printed statistics.
    plt.show()

In [None]:
# One nonnegative scalar step-length parameter per interval, so every step of
# the horizon can be given its own length at solve time.
dts = [cp.Parameter(nonneg=True, name=f'dt_{k}') for k in range(N)]

def create_clqr(n: int = N):
    """Build the input-constrained LQR whose k-th step has length ``dts[k]``.

    Each stage cost and each Euler update x_{k+1} = x_k + (A x_k + B u_k) h_k
    is scaled by the parameter h_k = dts[k]; inputs satisfy |u_k| <= 1.

    Args:
        n: number of steps (defaults to N).

    Returns:
        The unsolved ``cp.Problem``.
    """
    state_costs, input_costs = [], []
    dynamics_constraints, input_limits = [], []

    for k in range(n):
        h = dts[k]
        state_costs.append(cp.quad_form(x[k + 1], Q) * h)
        input_costs.append(cp.quad_form(u[k], R) * h)
        dynamics_constraints.append(x[k + 1] == x[k] + (A @ x[k] + B @ u[k]) * h)
        input_limits.append(cp.abs(u[k]) <= 1.0)

    objective = cp.Minimize(cp.sum(state_costs) + cp.sum(input_costs))
    return cp.Problem(objective, dynamics_constraints + input_limits)

big_clqr = create_clqr()
assert big_clqr.is_dpp()

In [None]:
# Larger random test system: nx states and nu inputs, seeded for reproducibility.
nx, nu = 100, 20

rng = np.random.default_rng(42)

# Draw order matters for reproducibility: A first, then B, then x0.
A = rng.uniform(low=-1, high=1, size=(nx, nx))
B = rng.uniform(low=-1, high=1, size=(nx, nu))
x0 = rng.uniform(low=-1, high=1, size=nx)

# Horizon length T split into N uniform steps of length dt.
T, N = 0.5, 100
dt = T / N

# Quadratic cost weights on the state and on the input.
Q = np.eye(nx) * 1.0
R = np.eye(nu) * 0.1

In [None]:
# Decision variables for the large system: x[0] is the fixed initial state
# (a numpy constant); x[1..N] and u[0..N-1] are cvxpy variables.
x = [x0] + [cp.Variable(nx, name=f"x_{i}") for i in range(N)]
u = [cp.Variable(nu, name=f"u_{i}") for i in range(N)]


def create_big_clqr(n: int = N, dt: float = dt, u_max: float = 10.0):
    """Build the input-constrained LQR for the large random system.

    Explicit Euler dynamics x_{i+1} = x_i + (A x_i + B u_i) dt, objective
    sum_i (x_{i+1}^T Q x_{i+1} + u_i^T R u_i) * dt, and |u_i| <= u_max.

    Args:
        n: number of steps; must satisfy 0 <= n <= len(u).
        dt: step length. The default is the value of ``dt`` at the time this
            cell runs (Python binds defaults at definition time).
        u_max: elementwise bound on the input magnitude.

    Returns:
        The unsolved ``cp.Problem``.

    Raises:
        ValueError: if ``n`` is outside the range of pre-allocated variables.
    """
    if not 0 <= n <= len(u):
        raise ValueError(f"n must be between 0 and {len(u)}, got {n}")

    objective = cp.Minimize(
        cp.sum([cp.quad_form(x[i + 1], Q) for i in range(n)]) * dt
        + cp.sum([cp.quad_form(u[i], R) for i in range(n)]) * dt
    )

    dynamics_constraints = [
        x[i + 1] == x[i] + (A @ x[i] + B @ u[i]) * dt for i in range(n)
    ]

    input_limits = [cp.abs(u[i]) <= u_max for i in range(n)]

    return cp.Problem(objective, dynamics_constraints + input_limits)


big_clqr = create_big_clqr()
assert big_clqr.is_dpp()

In [None]:
n = N
dt = T / n
big_clqr = create_big_clqr(n=n, dt=dt)

start_time = time.perf_counter()
big_clqr.solve()
solve_time = time.perf_counter() - start_time

# Guard against solvers that report no objective value / solve time (None).
cost = big_clqr.objective.value
solver_time = big_clqr.solver_stats.solve_time
cost_str = f"{cost:.4f}" if cost is not None else "n/a"
solver_time_str = f"{solver_time:.4f} s" if solver_time is not None else "n/a"

print(f"Optimal cost for n={n}: {cost_str}")
print(f"Solve time meas for n={n}: {solve_time} s")
print(f"Solve time for n={n}: {solver_time_str}")
print()

# x[i+1] is the state at t = (i+1)*dt; u[i] is applied starting at t = i*dt.
state_times = np.arange(1, n + 1) * dt
input_times = np.arange(n) * dt
x_vec = np.array([xi.value for xi in x[1:n + 1]])
u_vec = np.array([ui.value for ui in u[:n]])

fig, axs = plt.subplots(2, 1)
fig.suptitle(f"Random system ({nx} states, {nu} inputs), n={n}")

axs[0].plot(state_times, x_vec)
axs[0].set(xlabel='Time', ylabel='State')

axs[1].plot(input_times, u_vec)
axs[1].set(xlabel='Time', ylabel='Control Input');

In [None]:
n = N // 2
dt = T / n

# Step-length variables (deltas) and parameters (dts), one per interval.
# The problem below imposes deltas == dts and sum(deltas) == T, so it is only
# feasible when the supplied dts sum to the horizon T.
deltas = cp.Variable(n, name='deltas')
dts = cp.Parameter(n, nonneg=True, name='dts')

def create_clqr(n: "int | None" = None):
    """Build the input-constrained LQR with per-step lengths ``dts``.

    Each stage cost and Euler update x_{i+1} = x_i + (A x_i + B u_i) dts[i]
    is scaled by its own parameter ``dts[i]``. Inputs satisfy |u_i| <= 10, and
    the timestep constraints tie ``deltas`` to ``dts`` with sum(deltas) == T.

    Args:
        n: number of steps; defaults to ``dts.shape[0]``. It must not exceed
            the number of step parameters or pre-allocated variables.

    Returns:
        The unsolved ``cp.Problem``.

    Raises:
        ValueError: if ``n`` is out of range.
    """
    # The previous default (N) exceeded dts.shape[0] and raised IndexError.
    if n is None:
        n = dts.shape[0]
    max_n = min(dts.shape[0], len(u))
    if not 0 <= n <= max_n:
        raise ValueError(f"n must be between 0 and {max_n}, got {n}")

    objective = cp.Minimize(
        cp.sum([cp.quad_form(x[i + 1], Q) * dts[i] for i in range(n)])
        + cp.sum([cp.quad_form(u[i], R) * dts[i] for i in range(n)])
    )

    dynamics_constraints = [
        x[i + 1] == x[i] + (A @ x[i] + B @ u[i]) * dts[i] for i in range(n)
    ]

    input_limits = [cp.abs(u[i]) <= 10.0 for i in range(n)]

    timestep_constraints = [deltas[i] == dts[i] for i in range(n)] + [
        cp.sum(deltas) == T
    ]

    return cp.Problem(
        objective, dynamics_constraints + input_limits + timestep_constraints
    )

big_clqr = create_clqr(n=n)
assert big_clqr.is_dpp()

In [None]:
# Differentiable layer over the time-step problem. The only parameter is the
# vector of step lengths `dts`; outputs come back in the order
# x_1..x_n, u_0..u_{n-1}, deltas.
cvxpylayer = CvxpyLayer(
    big_clqr,
    parameters=[dts],
    variables=[*x[1:n + 1], *u[:n], deltas],
)

# Learnable step lengths, initialised to the uniform grid dt = T / n.
dts_torch = torch.nn.Parameter(torch.full((n,), dt))

optim = torch.optim.Adam(params=[dts_torch], lr=1e-3)

# One forward solve at the initial step lengths.
sol = cvxpylayer(dts_torch)

In [None]:
def task_loss(sol, n_steps=None, Q_cost=None, R_cost=None):
    """Unweighted quadratic cost of a CvxpyLayer solution.

    Computes sum_i x_i^T Q x_i + sum_i u_i^T R u_i over the states and inputs
    returned by the layer, differentiably w.r.t. the layer parameters.

    Args:
        sol: sequence of tensors laid out as x_1..x_n, u_0..u_{n-1}, followed
            by any extra outputs (e.g. ``deltas``), which are ignored.
        n_steps: horizon length n; defaults to the notebook-level ``n``.
        Q_cost: state weight matrix; defaults to the notebook-level ``Q``.
        R_cost: input weight matrix; defaults to the notebook-level ``R``.

    Returns:
        0-d tensor with the same dtype and device as the solution tensors.
    """
    if n_steps is None:
        n_steps = n
    if Q_cost is None:
        Q_cost = Q
    if R_cost is None:
        R_cost = R

    states = torch.stack(tuple(sol[:n_steps]))             # (n, nx)
    inputs = torch.stack(tuple(sol[n_steps:2 * n_steps]))  # (n, nu)

    # Match the solution's dtype/device instead of hard-coding float32, so a
    # float64 layer output does not cause a dtype mismatch in the products.
    Q_th = torch.as_tensor(Q_cost, dtype=states.dtype, device=states.device)
    R_th = torch.as_tensor(R_cost, dtype=inputs.dtype, device=inputs.device)

    state_cost = torch.einsum('ti,ij,tj->', states, Q_th, states)
    input_cost = torch.einsum('ti,ij,tj->', inputs, R_th, inputs)
    return state_cost + input_cost


In [None]:
n_epochs = 10
history = []

# Smallest step length allowed after an update (dts is declared nonneg).
min_dt = 1e-6

# Reset both the step lengths and the optimizer so re-running this cell
# starts from scratch (Adam keeps per-parameter moment estimates otherwise).
with torch.no_grad():
    dts_torch.copy_(torch.ones(n) * dt)
optim = torch.optim.Adam([dts_torch], lr=optim.defaults['lr'])

for epoch in range(n_epochs):
    optim.zero_grad()

    sol = cvxpylayer(dts_torch)

    loss = task_loss(sol)
    loss.backward()

    # Record the step lengths that produced this loss, i.e. *before* the update.
    # clone(): .numpy() shares memory with the tensor, so without a copy every
    # entry would be rewritten by later in-place optimizer steps.
    history.append({
        'loss': loss.item(),
        'dts': dts_torch.detach().clone().numpy(),
    })

    optim.step()

    # The layer enforces deltas == dts and sum(deltas) == T, so any update that
    # changes sum(dts) makes the next solve infeasible. Restore feasibility by
    # clamping to positive values and rescaling to the horizon T.
    with torch.no_grad():
        dts_torch.clamp_(min=min_dt)
        dts_torch.mul_(T / dts_torch.sum())