In [None]:
# Reload edited local modules (e.g. `fista`) automatically before each cell executes,
# so solver changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import time
import pathlib

In [None]:
import torch
if torch.__version__ >= '2.0.0':
    from torch import func as functorch
else:
    import functorch
from fista import QP, FISTA

In [None]:
# Choose the compute device: prefer Apple's MPS backend, then CUDA, then CPU.
# `torch.backends.mps` only exists from torch 1.12 onward, while this notebook's
# import cell still supports pre-2.0 torch via standalone functorch, so probe the
# attribute first instead of raising AttributeError on older releases.
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps:0")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Test problem: a batch of two small QPs (3 variables each), solved together.
#   batch 0: min x1^2 + x2^2 + x3^2 + x1 + 2*x2    s.t. x1 = 0.1, -1 <= x <= 1, mu = 1.0
#   batch 1: min x1^2 + x2^2 + x3^2 + 4*x1 + 2*x2  s.t. x2 = 1,   -3 <= x <= 3, mu = 1.0
# NOTE(review): with Q = I the cost above reads as x^T Q x + q^T x; confirm whether QP
# uses the 0.5 * x^T Q x convention, in which case each x_i^2 coefficient is 1/2.

num_batches = 2
num_vars = 3
num_eqc = 1
num_boxc = 2  # NOTE(review): not referenced in the cells shown here — confirm it is still needed

# Quadratic cost: identity matrix for every batch element, shape (num_batches, num_vars, num_vars).
Q = torch.eye(num_vars).repeat(num_batches, 1, 1)

# Linear cost, shape (num_batches, num_vars); only x1 and x2 carry a linear term.
q = torch.zeros(num_batches, num_vars)
q[:, :2] = torch.tensor([[1., 2.],
                         [4., 2.]])

# Equality constraints A x = b: one row per batch element.
rho = 0.1  # set to zero if no equality constraint is needed
A = torch.zeros(num_batches, num_eqc, num_vars)
A[0, 0, 0] = 1.  # batch 0 constrains x1
A[1, 0, 1] = 1.  # batch 1 constrains x2
b = torch.tensor([[0.1],
                  [1.]])

# Box constraints lb <= x <= ub, symmetric per batch element (|x| <= 1 and |x| <= 3).
ub = torch.tensor([[1.], [3.]]).repeat(1, num_vars)
lb = -ub

# Friction-cone coefficient, one per batch element.
mu = torch.full((num_batches, 1), 1.)

In [None]:
def solve(solver, max_it):
    """Restart ``solver`` and advance it by a fixed number of iterations.

    Args:
        solver: object exposing ``reset()``, ``step()`` and a ``prob`` attribute
            (a ``FISTA`` instance in this notebook).
        max_it: number of ``step()`` calls made after the reset; there is no
            early stopping.

    Returns:
        ``solver.prob`` after the final step.
    """
    solver.reset()
    iterations = range(max_it)
    for _ in iterations:
        solver.step()
    problem = solver.prob
    return problem

In [None]:
# Build the batched QP on `device`, load the test-problem data, and wrap it in a FISTA solver.
# friction_coeff=None: presumably no friction-cone constraint is applied in this solve;
# `mu` is only used by the standalone projection check at the end of the notebook.
# NOTE(review): Q, q, A, b, lb and ub were allocated on the CPU; confirm set_data moves
# them to `device`, otherwise a CUDA/MPS run may hit a device mismatch.
prob = QP(num_batches, num_vars, num_eqc, friction_coeff=None, device=device)
prob.set_data(Q, q, A, b, rho, lb, ub)
solver = FISTA(prob, device=device)

In [None]:
# Optional benchmark: uncomment to time `solve(solver, 100)` with %timeit (100 loops per run).
# %timeit -n 100 solve(solver, 100)

In [None]:
# Reset the solver, run 100 iterations, and display `xk` (presumably the current iterate).
prob = solve(solver, max_it=100)
prob.xk

In [None]:
# Quick check of the friction-cone projection on a hand-built batch of two force rows.
# NOTE(review): each 6-vector presumably holds two 3-D contact forces; confirm the layout
# expected by proj_friction_cone.
forces = torch.tensor([
    [1., 1., 1., 0., 0., -1.],
    [2., 2., 2., 0., 0.,  1.],
])
print(forces)
print(solver.proj_friction_cone(forces, mu))