In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import osqp
from scipy import sparse
from mysqp import NLPBuilder

jax.config.update("jax_enable_x64", False)

In [2]:
nlp = NLPBuilder(dim=4)
nlp.set_f(
    lambda x:x[0] * x[3] * jnp.sum(x[:3]) + x[2]
)
nlp.add_eq_const(lambda x:jnp.sum(x**2), 40)
nlp.add_ineq_const_lb(lambda x:jnp.prod(x), 25)
nlp.set_state_bound(jnp.full(4, 1), jnp.full(4, 5))


In [2]:
nlp = NLPBuilder(dim=2)
nlp.set_f(
    lambda x: (1-x[0])**2 + 100 *(x[1]-x[0]**2)**2   #x[0] * x[3] * jnp.sum(x[:3]) + x[2]
)
#nlp.add_eq_const(lambda x:jnp.sum(x**2), 40)
#nlp.add_ineq_const_lb(lambda x:jnp.prod(x), 25)
nlp.set_state_bound(jnp.full(2, -2), jnp.full(2, 2))

In [3]:
class SQP:
    def __init__(
            self,
            dim,
            const_dim,
            convexify_fn,
            merit_fn,
            armijo_merit_fn,
            max_viol_fn,
            verbose = True
        ):
        self.qpsolver = osqp.OSQP()
        self.dim = dim
        self.const_dim = const_dim
        self.convexify_fn = convexify_fn
        self.merit_fn = merit_fn
        self.max_viol_fn = max_viol_fn
        self.armijo_merit_fn = armijo_merit_fn
        self.is_qp_init = False
        self.verbose = verbose
        self.prebuild()

    def prebuild(self):
        x = jnp.zeros(self.dim)
        m = jnp.zeros(self.const_dim)
        self.convexify_fn = jax.jit(self.convexify_fn).lower(x,m).compile()
        self.merit_fn = jax.jit(self.merit_fn).lower(x, 1.).compile()
        self.armijo_merit_fn = jax.jit(self.armijo_merit_fn).lower(x, x, 1.).compile()
        self.max_viol_fn = jax.jit(self.max_viol_fn).lower(x,m).compile()
        
    def solve(self, x0, max_iter=100, const_viol_tol=0.001):
        sigma = 0.
        x = x0
        m = jnp.zeros(self.const_dim) # lambda

        print("sqp start")
        for i in range(max_iter):
            max_viol = self.max_viol_fn(x, m)
            if max_viol < const_viol_tol:
                return x
            
            P, q, A, l, u = self.convexify_fn(x, m)
            P = sparse.csc_matrix(P)
            A = sparse.csc_matrix(A)
            q = np.asarray(q)
            l = np.asarray(l)
            u = np.asarray(u)

            if self.is_qp_init == False:
                opts = {"verbose":False}
                self.qpsolver.setup(P, q, A, l, u, **opts)
                self.is_qp_init = True
            else:
                self.qpsolver.update(
                    Px=P.data, q=q, Ax=A.data, l=l, u=u
                )
            res = self.qpsolver.solve()
            if res.info.status != "solved":
                raise NotImplementedError("QP infeasible!")
            
            direction = jnp.asarray(res.x)
            alpha = self.backtrack(x, direction, sigma)
            x += alpha * direction
            m = (1-alpha)*m + alpha * res.y
            sigma = jnp.linalg.norm(jnp.hstack([1.01*res.y, sigma]))
            if self.verbose:
                print(f"{i}: x:{x}, dir:{direction}, alpha:{alpha}, max_viol:{max_viol}")
    
    def backtrack(
        self, x, direction, 
        sigma=0., alpha=1., beta=0.5, gamma=0.1, max_iter=30
    ):
        for i in range(max_iter):
            curr_merit = self.merit_fn(x, sigma)
            next_merit = self.merit_fn(x + alpha * direction, sigma)
            armijo = gamma * alpha * self.armijo_merit_fn(x, direction, sigma)
            if next_merit < curr_merit + armijo:
                break
            alpha *= beta
        return alpha
    
    @classmethod
    def from_nlp_builder(cls, prob:NLPBuilder):
        return cls(prob.dim,
                    prob.const_dim,
                    prob.get_convexify_fn(),
                    prob.get_merit_fn(),
                    prob.get_armijo_merit(),
                    prob.get_const_viol_fn())

In [4]:
sqp = SQP.from_nlp_builder(nlp)
sqp.verbose = True

In [9]:
x0 = jnp.zeros(2)
xsol = sqp.solve(x0)

In [10]:
sqp.max_viol_fn(x0, 0.)

Array(0., dtype=float32)

In [49]:
x0 = jnp.array([1.0, 5.0, 5.0, 1.0]) # TODO: make initiaal guess feasible
xsol = sqp.solve(x0)

0: x:[0.99985886 4.99719    3.7528415  1.2499943 ], dir:[-1.4112849e-04 -2.8102221e-03 -1.2471585e+00  2.4999431e-01], alpha:1.0, max_viol:12.0
1: x:[1.0000635 4.567744  4.082198  1.3306595], dir:[ 2.0472695e-04 -4.2944622e-01  3.2935661e-01  8.0665261e-02], alpha:1.0, max_viol:1.6179313659667969
2: x:[1.0000005 4.6973457 3.8890197 1.3659519], dir:[-6.3077670e-05  1.2960210e-01 -1.9317833e-01  3.5292350e-02], alpha:1.0, max_viol:0.299407958984375
3: x:[1.0000219 4.728714  3.8403277 1.376433 ], dir:[ 2.1410697e-05  3.1368159e-02 -4.8692092e-02  1.0481072e-02], alpha:1.0, max_viol:0.05535888671875
4: x:[0.99999595 4.7435794  3.8204956  1.3794568 ], dir:[-2.6011288e-05  1.4865272e-02 -1.9832024e-02  3.0237255e-03], alpha:1.0, max_viol:0.003688812255859375


In [46]:
%timeit sqp.solve(x0)

11 ms ± 318 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
sqp.solve(xsol)

Array([1.        , 4.73975686, 3.82544392, 1.37878578], dtype=float64)

In [3]:
# init
x0 = jnp.array([1.0, 5.0, 5.0, 1.0]) # TODO: make initial guess feasible
lmbda0 = jnp.zeros(nlp.g_dim+nlp.h_dim)

PqGhAb = nlp.get_PqGhAb_fn()
merit_fn = nlp.get_merit_fn()
backtrack = nlp.get_backtrack_fn()
fdiff_fn = lambda x1, x2: jnp.abs(nlp.f(x2) - nlp.f(x1))
gh_fn = nlp.get_gh()

PqGhAb = jax.jit(PqGhAb).lower(x0, lmbda0).compile()
merit_fn = jax.jit(merit_fn).lower(x0, 1.).compile()
fdiff_fn = jax.jit(fdiff_fn).lower(x0, x0).compile()


grad_and_const_viol = nlp.get_grad_and_const_viol()
is_qp_setup = False

In [4]:
x = x0.copy()
lmbda = lmbda0.copy()
sigma = 0.
it = 0

In [5]:
import time
tic = time.time()
for i in range(100):
    P, q, G, h, A, b = PqGhAb(x, lmbda)
    
    P_ = sparse.csc_matrix(P)
    q_ = np.asarray(q)
    A_ = sparse.csc_matrix(np.vstack([A, G]))
    u_ = np.hstack([-b, -h])
    l_ = np.hstack([-b, np.full(len(h), -np.inf)])
    
    if is_qp_setup == False:
        opts = {"verbose":False, "polish":True}
        qpsolver.setup(P_, q_, A_, l_, u_, **opts)
        is_qp_setup = True
    else:
        qpsolver.update(Px=P_.data, q=q_, Ax=A_.data, l=l_, u=u_)
    res = qpsolver.solve()
    if res.info.status != "solved":
        print("QP infeasible!")
        raise ValueError()
    else:
        print("QP solved")
    direction = jnp.asarray(res.x)

    alpha = backtrack(x, direction, merit_fn, sigma=sigma)
    xdiff = alpha * direction
    xdiff_max = jnp.linalg.norm(xdiff, jnp.inf)
    fdiff = fdiff_fn(x, x + xdiff)
    max_viol, max_grad = grad_and_const_viol(x, lmbda)

    #update
    x = x + xdiff
    lmbda = (1-alpha)*lmbda + alpha * res.y
    sigma = jnp.max(jnp.abs(jnp.hstack([1.01*res.y, sigma])))
    
    print(i)
    print(f"xdiff:{xdiff_max}, fdiff:{fdiff}, max_viol:{max_viol}")
    print(f"x:{x}, alpha:{alpha}")
    if max_viol < 0.01:
        break
print(f"elapsed:{time.time() - tic}")

QP solved
0
xdiff:1.1641532182693481e-09, fdiff:3.147988536511548e-10, max_viol:12.0
x:[1. 5. 5. 1.], alpha:9.313225746154785e-10
QP solved
1
xdiff:1.249999999166853, fdiff:0.062499999478845325, max_viol:11.999999988824129
x:[1.   5.   3.75 1.25], alpha:1.0
QP solved
2
xdiff:0.4343480808709526, fdiff:0.9846040134799985, max_viol:1.6249999979227638
x:[1.         4.56565192 4.08580199 1.32998636], alpha:1.0
QP solved
3
xdiff:0.19569909302324912, fdiff:0.06072529788424319, max_viol:0.3078190485582297
x:[1.         4.69666698 3.8901029  1.3657078 ], alpha:1.0
QP solved
4
xdiff:0.04940514744966446, fdiff:0.029136839675199866, max_viol:0.056739104003284524
x:[1.         4.72842845 3.84069775 1.37643377], alpha:1.0
QP solved
5
xdiff:0.015253824613792275, fdiff:0.0018406184953647653, max_viol:0.003564705535183066
x:[1.         4.73975686 3.82544392 1.37878578], alpha:1.0
elapsed:2.1857571601867676


In [84]:
sigma

Array(1.0267321, dtype=float32)

In [83]:
direction

Array([ 0.08796829, -0.00585553, -1.2441235 ,  0.16194648], dtype=float32)

In [179]:
ctol = 0.01
cl = jnp.hstack([jnp.zeros(nlp.g_dim), jnp.full(nlp.h_dim, -jnp.inf)]) - ctol
cu = jnp.zeros(nlp.g_dim+nlp.h_dim) + ctol

((cl < gh_fn(x)) & (gh_fn(x) < cu)).all()

Array(True, dtype=bool)

In [153]:
jnp.max(jnp.hstack([cl - gh_fn(x), gh_fn(x) - cu]))

Array(0.04976868, dtype=float32)

In [134]:
gh_fn(x) - cu

Array([ 0.05768418,  0.04433655, -0.01000727, -3.7475054 , -2.8496766 ,
       -0.38134754, -4.009993  , -0.27249455, -1.1703234 , -3.6386526 ],      dtype=float32)

In [49]:
gh_fn(x)

Array([ 3.0654144e-01,  3.0376816e-01,  4.4703484e-05, -3.7556133e+00,
       -2.8572013e+00, -3.4639120e-01, -4.0000448e+00, -2.4438667e-01,
       -1.1427987e+00, -3.6536088e+00], dtype=float32)

In [48]:
gh_fn(x) < cu

Array([False, False,  True,  True,  True,  True,  True,  True,  True,
        True], dtype=bool)

In [9]:
%timeit PqAul(x0, lmbda0)

158 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
class NLPProb:
    def __init__(self, dim):
        self.dim = dim
        self.f = None
        self._g = [] # eq = 0.
        self._h = [] # ineq <= 0.
    
    @staticmethod
    def get_eq_const(fn, val):
        result = lambda x: fn(x) - val
        return result
    @staticmethod
    def get_ineq_const_lb(fn, lb):
        result = lambda x: -fn(x) + lb
        return result
    @staticmethod
    def get_ineq_const_ub(fn, ub):
        result = lambda x: fn(x) - ub
        return result
    
    def get_ineq_const_b(self, fn, lb, ub):
        lb_fn = self.get_ineq_const_lb(fn, lb)
        ub_fn = self.get_ineq_const_ub(fn, ub)
        def result(x):
            return jnp.hstack([lb_fn(x), ub_fn(x)])
        return result
    
    @staticmethod
    def get_state_bound_fn(lb, ub):
        def state_bound_fn(x):
            return jnp.hstack([x - ub, -x + lb])
        return state_bound_fn
    
    def set_f(self, fn):
        self.f = fn
    
    def add_eq_const(self, fn, val):
        self._g.append(self.get_eq_const(fn, val))

    def add_ineq_const_lb(self, fn, lb):
        self._h.append(self.get_ineq_const_lb(fn, lb))
    
    def add_ineq_const_ub(self, fn, ub):
        self._h.append(self.get_ineq_const_ub(fn, ub))

    def add_ineq_const_b(self, fn, lb, ub):
        self._h.append(self.get_ineq_const_lb(fn, lb))
        self._h.append(self.get_ineq_const_ub(fn, ub))
    
    def set_state_bound(self, xlb, xub):
        self._h.append(self.get_state_bound_fn(xlb, xub))
    
    def get_g(self):
        return lambda x: jnp.hstack([fn(x) for fn in self._g])
    
    def get_h(self):
        return lambda x: jnp.hstack([fn(x) for fn in self._h])

    def get_gh(self):
        return lambda x: jnp.hstack([fn(x) for fn in self._g+self._h])
    
    def get_lagrangian(self):
        gh = self.get_gh()
        return lambda x, lmbda: self.f(x) + gh(x) @ lmbda
    
    def get_g_dim(self):
        g = self.get_g()
        return len(g(jnp.zeros(self.dim)))
    
    def get_h_dim(self):
        h = self.get_h()
        return len(h(jnp.zeros(self.dim)))
    
    def get_lag_mult_dim(self):
        gh = self.get_gh()
        return len(gh(jnp.zeros(self.dim)))
    
    def get_merit_fn(self):
        sigma = 1.0
        f = self.f
        g = self.get_g()
        h = self.get_h()
        def merit_fn(x):
            eq_norm = jnp.linalg.norm(g(x), 1)
            ineq_norm = jnp.linalg.norm(jnp.clip(h(x), a_min=0), 1)
            return f(x) + sigma * (eq_norm + ineq_norm)
        return merit_fn

In [22]:
def backtrack(
    x_k,
    direction,
    merit_fn,
    alpha=1.,
    beta=0.5,
    max_iter=30,
):
    curr_merit = merit_fn(x_k)
    next_merit = merit_fn(x_k + alpha * direction)

    n_iter = 0
    while (next_merit >= curr_merit) and (n_iter < max_iter):
        alpha *= beta
        next_merit = merit_fn(x_k + alpha * direction)
        n_iter += 1
    if n_iter == max_iter:
        print(f'Backtracking failed to find alpha after {max_iter} iterations!')

    return alpha

In [23]:
prob = NLPProb(4)
prob.set_f(
    lambda x:x[0] * x[3] * jnp.sum(x[:3]) + x[2]
)
prob.add_eq_const(lambda x:jnp.sum(x**2), 40)
prob.add_ineq_const_lb(lambda x:jnp.prod(x), 25)
prob.set_state_bound(jnp.full(4, 1), jnp.full(4, 5))

In [24]:
lag_hess = jax.hessian(prob.get_lagrangian())
f_grad = jax.grad(prob.f)
gh = prob.get_gh()
gh_grad = jax.jacrev(prob.get_gh())
merit_fn = prob.get_merit_fn()

In [25]:
# init
x0 = jnp.array([1.0, 5.0, 5.0, 1.0])
lmbda0 = jnp.zeros(prob.get_lag_mult_dim())
x = x0.copy()
lmbda = lmbda0.copy()

### iter

In [356]:
eq_dim = prob.get_g_dim()
ineq_dim = prob.get_h_dim()
P = lag_hess(x, jnp.zeros(10))
q = f_grad(x)
A = gh_grad(x)
u = -gh(x)
l = jnp.hstack([u[:eq_dim], jnp.full(ineq_dim, -jnp.inf)])

In [357]:
#regularization
if jnp.isnan(jnp.linalg.cholesky(P)).any():
    eigs, vecs = jnp.linalg.eigh(P)
    delta = 1e-6
    eigs_modified = jnp.where(eigs < delta, delta, eigs)
    P = vecs @ jnp.diag(eigs_modified) @ vecs.T

P = sparse.csc_matrix(P)
q = np.asarray(q)
A = sparse.csc_matrix(A)
l = np.asarray(l)
u = np.asarray(u)

In [366]:
solver = osqp.OSQP()
solver.setup(P, q, A, l, u, verbose=False)
res = solver.solve()
if res.info.status != "solved":
    print("QP infeasible!")
else:
    print("QP solved")

QP solved


In [359]:
direction = jnp.asarray(res.x)
alpha = backtrack(x, direction, merit_fn)
print(f"curr_merit:{merit_fn(x)}, next_merit:{merit_fn(x+direction*alpha)}")
print(f"x:{x}, alpha:{alpha}")

curr_merit:17.231197357177734, next_merit:17.224365234375
x:[1.0000093 4.8189845 3.7490077 1.3751227], alpha:0.0625


In [360]:
x += alpha * direction
lmbda = (1-alpha)*lmbda + alpha * res.y

In [115]:
eq_consts = [eq_const]  # = 0.
ineq_consts = [ineq_const, state_bound] # <= 0.
consts_fns = [] + eq_consts + ineq_consts
_g = lambda x: jnp.hstack([c_fn(x) for c_fn in consts_fns])
n_eq = len(eq_consts)
n_ineq = len(ineq_consts)
lbg, ubg = jnp.zeros(n_eq), jnp.zeros(n_eq)
lbh, ubh = jnp.zeros(n_ineq), jnp.full(n_ineq, jnp.inf)

In [116]:
#lagrangian
lag = lambda x, lmbda: f(x) + _g(x) @ lmbda

In [117]:
lag_hess = jax.hessian(lag)
f_grad = jax.grad(f)
g_grad = jax.jacrev(_g)

In [71]:
B = lag_hess(x0, lmbda0)
grad_f = np.asarray(f_grad(x0))
grad_g = g_grad(x0)
val_g = _g(x0)
lbg = np.array([0., 25])
ubg = np.array([40., jnp.inf])

In [93]:

prob = osqp.OSQP()

In [94]:
prob.setup(sparse.csc_matrix(B), grad_f, sparse.csc_matrix(grad_g), lbg, ubg)

-----------------------------------------------------------------
           OSQP v0.6.2  -  Operator Splitting QP Solver
              (c) Bartolomeo Stellato,  Goran Banjac
        University of Oxford  -  Stanford University 2021
-----------------------------------------------------------------
problem:  variables n = 4, constraints m = 2
          nnz(P) + nnz(A) = 18
settings: linear system solver = qdldl,
          eps_abs = 1.0e-03, eps_rel = 1.0e-03,
          eps_prim_inf = 1.0e-04, eps_dual_inf = 1.0e-04,
          rho = 1.00e-01 (adaptive),
          sigma = 1.00e-06, alpha = 1.60, max_iter = 4000
          check_termination: on (interval 25),
          scaling: on, scaled_termination: off
          warm start: on, polish: off, time_limit: off



In [95]:
res = prob.solve()

iter   objective    pri res    dua res    rho        time
   1  -3.0971e+04   4.00e+01   9.60e+03   1.00e-01   1.22e-04s
  25  -1.0000e+30   1.23e-02   1.01e+00   1.00e-01   5.41e-04s

status:               dual infeasible
number of iterations: 25
run time:             7.19e-04s
optimal rho estimate: 3.77e-03



In [96]:
n_dim = 4
n_eq, n_ineq = 1, 1
n_s = 2* n_eq + n_g
new_B = jnp.block([
    B, jnp.zeros((n_dim, ))
])

array([None, None, None, None], dtype=object)