In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import osqp
from scipy import sparse
from mysqp import NLPBuilder, value_and_jacrev

jax.config.update("jax_enable_x64", True)

In [3]:
nlp = NLPBuilder(4)

In [4]:
nlp.set_f(
    lambda x:x[0] * x[3] * jnp.sum(x[:3]) + x[2]
)
nlp.add_eq_const(lambda x:jnp.sum(x**2), 40)
nlp.add_ineq_const_lb(lambda x:jnp.prod(x), 25)
nlp.set_state_bound(jnp.full(4, 1), jnp.full(4, 5))

In [5]:
from functools import partial

# init
x0 = jnp.array([1.0, 5.0, 5.0, 1.0]) # TODO: make initial guess feasible
lmbda0 = jnp.zeros(nlp.g_dim+nlp.h_dim)

lag_grad_fn = jax.grad(nlp.get_lagrangian_fn())
f_grad_fn = jax.grad(nlp.f)
const_and_jac_fn = partial(value_and_jacrev, f=nlp.get_gh())
eq_dim = nlp.g_dim
ineq_dim = nlp.h_dim

merit_fn = nlp.get_merit_fn()
backtrack = nlp.get_backtrack_fn()
fdiff_fn = lambda x1, x2: jnp.abs(nlp.f(x2) - nlp.f(x1))
gh_fn = nlp.get_gh()
#PqAlu = jax.jit(PqAlu).lower(x0, lmbda0).compile()
# merit_fn = jax.jit(merit_fn).lower(x0, 1.).compile()
# fdiff_fn = jax.jit(fdiff_fn).lower(x0, x0).compile()
qpsolver = osqp.OSQP()

is_qp_setup = False

In [6]:
def BFGS(B_prev, lag_grad_curr, lag_grad_prev, x_curr, x_prev):
    def constant_zero(B_prev, incr1, incr2):
        return B_prev
    def normal(B_prev, incr1, incr2):
        return B_prev + incr1 - incr2
    y = x_curr - x_prev
    s = lag_grad_curr - lag_grad_prev
    c1 = jnp.inner(y, s)
    c2 = s @ B_prev @ s
    incr1 = jnp.outer(y, y) / c1
    incr2 = B_prev @ jnp.outer(s, s) @ B_prev / c2
    cond = jnp.isclose(c1, 0.) | jnp.isclose(c2, 0.)
    return jax.lax.cond(cond, constant_zero, normal, B_prev, incr1, incr2)

In [7]:
x = x0.copy()
lmbda = lmbda0.copy()
sigma = 0.
it = 0
P = None

In [51]:
#PqAlu
lag_grad_curr = lag_grad_fn(x, lmbda)
if P is None:
    P = jnp.eye(nlp.dim)
else:
    P = BFGS(P_prev, lag_grad_curr, lag_grad_prev, x, x_prev)
    print("BFGS!")
q = f_grad_fn(x)
const, const_jac = const_and_jac_fn(x)
A, u = const_jac, -const
l = jnp.hstack([u[:eq_dim], jnp.full(ineq_dim, -jnp.inf)])

#QP
Ps = sparse.csc_matrix(P)
q = np.asarray(q)
As = sparse.csc_matrix(A)
l = np.asarray(l)
u = np.asarray(u)
if is_qp_setup == False:
    qpsolver.setup(Ps, q, As, l, u, verbose=False)
    is_qp_setup = True
else:
    qpsolver.update(Px=Ps.data, q=q, Ax=As.data, l=l, u=u)

res = qpsolver.solve()
if res.info.status != "solved":
    print("QP infeasible!")
else:
    print("QP solved")
direction = jnp.asarray(res.x)


alpha = backtrack(x, direction, merit_fn, beta=1/3, sigma=sigma)
xdiff = alpha * direction
fdiff = fdiff_fn(x, x + xdiff)

#update
P_prev = P
x_prev = x
lag_grad_prev = lag_grad_curr

x = x + xdiff
lmbda = (1-alpha)*lmbda + alpha * res.y
sigma = jnp.max(jnp.abs(jnp.hstack([1.01*res.y, sigma])))

print(f"xdiff:{jnp.linalg.norm(xdiff, jnp.inf)}, fdiff:{fdiff}")
print(f"x:{x}")

BFGS!
QP solved
xdiff:4.132376019904824e-12, fdiff:9.947598300641403e-14
x:[0.99992547 4.7370438  3.83289496 1.37701724]


In [35]:
x_prev, x

(Array([0.99992547, 4.7370438 , 3.83289496, 1.37701724], dtype=float64),
 Array([0.99992547, 4.7370438 , 3.83289496, 1.37701724], dtype=float64))

In [36]:
ctol = 0.01
cl = jnp.hstack([jnp.zeros(nlp.g_dim), jnp.full(nlp.h_dim, -jnp.inf)]) - ctol
cu = jnp.zeros(nlp.g_dim+nlp.h_dim) + ctol

((cl < gh_fn(x)) & (gh_fn(x) < cu)).all()

Array(False, dtype=bool)

In [37]:
jnp.max(jnp.hstack([cl - gh_fn(x), gh_fn(x) - cu]))

Array(0.01669514, dtype=float64)

In [134]:
gh_fn(x) - cu

Array([ 0.05768418,  0.04433655, -0.01000727, -3.7475054 , -2.8496766 ,
       -0.38134754, -4.009993  , -0.27249455, -1.1703234 , -3.6386526 ],      dtype=float32)

In [49]:
gh_fn(x)

Array([ 3.0654144e-01,  3.0376816e-01,  4.4703484e-05, -3.7556133e+00,
       -2.8572013e+00, -3.4639120e-01, -4.0000448e+00, -2.4438667e-01,
       -1.1427987e+00, -3.6536088e+00], dtype=float32)

In [48]:
gh_fn(x) < cu

Array([False, False,  True,  True,  True,  True,  True,  True,  True,
        True], dtype=bool)

In [9]:
%timeit PqAul(x0, lmbda0)

158 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
class NLPProb:
    def __init__(self, dim):
        self.dim = dim
        self.f = None
        self._g = [] # eq = 0.
        self._h = [] # ineq <= 0.
    
    @staticmethod
    def get_eq_const(fn, val):
        result = lambda x: fn(x) - val
        return result
    @staticmethod
    def get_ineq_const_lb(fn, lb):
        result = lambda x: -fn(x) + lb
        return result
    @staticmethod
    def get_ineq_const_ub(fn, ub):
        result = lambda x: fn(x) - ub
        return result
    
    def get_ineq_const_b(self, fn, lb, ub):
        lb_fn = self.get_ineq_const_lb(fn, lb)
        ub_fn = self.get_ineq_const_ub(fn, ub)
        def result(x):
            return jnp.hstack([lb_fn(x), ub_fn(x)])
        return result
    
    @staticmethod
    def get_state_bound_fn(lb, ub):
        def state_bound_fn(x):
            return jnp.hstack([x - ub, -x + lb])
        return state_bound_fn
    
    def set_f(self, fn):
        self.f = fn
    
    def add_eq_const(self, fn, val):
        self._g.append(self.get_eq_const(fn, val))

    def add_ineq_const_lb(self, fn, lb):
        self._h.append(self.get_ineq_const_lb(fn, lb))
    
    def add_ineq_const_ub(self, fn, ub):
        self._h.append(self.get_ineq_const_ub(fn, ub))

    def add_ineq_const_b(self, fn, lb, ub):
        self._h.append(self.get_ineq_const_lb(fn, lb))
        self._h.append(self.get_ineq_const_ub(fn, ub))
    
    def set_state_bound(self, xlb, xub):
        self._h.append(self.get_state_bound_fn(xlb, xub))
    
    def get_g(self):
        return lambda x: jnp.hstack([fn(x) for fn in self._g])
    
    def get_h(self):
        return lambda x: jnp.hstack([fn(x) for fn in self._h])

    def get_gh(self):
        return lambda x: jnp.hstack([fn(x) for fn in self._g+self._h])
    
    def get_lagrangian(self):
        gh = self.get_gh()
        return lambda x, lmbda: self.f(x) + gh(x) @ lmbda
    
    def get_g_dim(self):
        g = self.get_g()
        return len(g(jnp.zeros(self.dim)))
    
    def get_h_dim(self):
        h = self.get_h()
        return len(h(jnp.zeros(self.dim)))
    
    def get_lag_mult_dim(self):
        gh = self.get_gh()
        return len(gh(jnp.zeros(self.dim)))
    
    def get_merit_fn(self):
        sigma = 1.0
        f = self.f
        g = self.get_g()
        h = self.get_h()
        def merit_fn(x):
            eq_norm = jnp.linalg.norm(g(x), 1)
            ineq_norm = jnp.linalg.norm(jnp.clip(h(x), a_min=0), 1)
            return f(x) + sigma * (eq_norm + ineq_norm)
        return merit_fn

In [22]:
def backtrack(
    x_k,
    direction,
    merit_fn,
    alpha=1.,
    beta=0.5,
    max_iter=30,
):
    curr_merit = merit_fn(x_k)
    next_merit = merit_fn(x_k + alpha * direction)

    n_iter = 0
    while (next_merit >= curr_merit) and (n_iter < max_iter):
        alpha *= beta
        next_merit = merit_fn(x_k + alpha * direction)
        n_iter += 1
    if n_iter == max_iter:
        print(f'Backtracking failed to find alpha after {max_iter} iterations!')

    return alpha

In [23]:
prob = NLPProb(4)
prob.set_f(
    lambda x:x[0] * x[3] * jnp.sum(x[:3]) + x[2]
)
prob.add_eq_const(lambda x:jnp.sum(x**2), 40)
prob.add_ineq_const_lb(lambda x:jnp.prod(x), 25)
prob.set_state_bound(jnp.full(4, 1), jnp.full(4, 5))

In [24]:
lag_hess = jax.hessian(prob.get_lagrangian())
f_grad = jax.grad(prob.f)
gh = prob.get_gh()
gh_grad = jax.jacrev(prob.get_gh())
merit_fn = prob.get_merit_fn()

In [25]:
# init
x0 = jnp.array([1.0, 5.0, 5.0, 1.0])
lmbda0 = jnp.zeros(prob.get_lag_mult_dim())
x = x0.copy()
lmbda = lmbda0.copy()

### iter

In [356]:
eq_dim = prob.get_g_dim()
ineq_dim = prob.get_h_dim()
P = lag_hess(x, jnp.zeros(10))
q = f_grad(x)
A = gh_grad(x)
u = -gh(x)
l = jnp.hstack([u[:eq_dim], jnp.full(ineq_dim, -jnp.inf)])

In [357]:
#regularization
if jnp.isnan(jnp.linalg.cholesky(P)).any():
    eigs, vecs = jnp.linalg.eigh(P)
    delta = 1e-6
    eigs_modified = jnp.where(eigs < delta, delta, eigs)
    P = vecs @ jnp.diag(eigs_modified) @ vecs.T

P = sparse.csc_matrix(P)
q = np.asarray(q)
A = sparse.csc_matrix(A)
l = np.asarray(l)
u = np.asarray(u)

In [366]:
solver = osqp.OSQP()
solver.setup(P, q, A, l, u, verbose=False)
res = solver.solve()
if res.info.status != "solved":
    print("QP infeasible!")
else:
    print("QP solved")

QP solved


In [359]:
direction = jnp.asarray(res.x)
alpha = backtrack(x, direction, merit_fn)
print(f"curr_merit:{merit_fn(x)}, next_merit:{merit_fn(x+direction*alpha)}")
print(f"x:{x}, alpha:{alpha}")

curr_merit:17.231197357177734, next_merit:17.224365234375
x:[1.0000093 4.8189845 3.7490077 1.3751227], alpha:0.0625


In [360]:
x += alpha * direction
lmbda = (1-alpha)*lmbda + alpha * res.y

In [115]:
eq_consts = [eq_const]  # = 0.
ineq_consts = [ineq_const, state_bound] # <= 0.
consts_fns = [] + eq_consts + ineq_consts
_g = lambda x: jnp.hstack([c_fn(x) for c_fn in consts_fns])
n_eq = len(eq_consts)
n_ineq = len(ineq_consts)
lbg, ubg = jnp.zeros(n_eq), jnp.zeros(n_eq)
lbh, ubh = jnp.zeros(n_ineq), jnp.full(n_ineq, jnp.inf)

In [116]:
#lagrangian
lag = lambda x, lmbda: f(x) + _g(x) @ lmbda

In [117]:
lag_hess = jax.hessian(lag)
f_grad = jax.grad(f)
g_grad = jax.jacrev(_g)

In [71]:
B = lag_hess(x0, lmbda0)
grad_f = np.asarray(f_grad(x0))
grad_g = g_grad(x0)
val_g = _g(x0)
lbg = np.array([0., 25])
ubg = np.array([40., jnp.inf])

In [93]:

prob = osqp.OSQP()

In [94]:
prob.setup(sparse.csc_matrix(B), grad_f, sparse.csc_matrix(grad_g), lbg, ubg)

-----------------------------------------------------------------
           OSQP v0.6.2  -  Operator Splitting QP Solver
              (c) Bartolomeo Stellato,  Goran Banjac
        University of Oxford  -  Stanford University 2021
-----------------------------------------------------------------
problem:  variables n = 4, constraints m = 2
          nnz(P) + nnz(A) = 18
settings: linear system solver = qdldl,
          eps_abs = 1.0e-03, eps_rel = 1.0e-03,
          eps_prim_inf = 1.0e-04, eps_dual_inf = 1.0e-04,
          rho = 1.00e-01 (adaptive),
          sigma = 1.00e-06, alpha = 1.60, max_iter = 4000
          check_termination: on (interval 25),
          scaling: on, scaled_termination: off
          warm start: on, polish: off, time_limit: off



In [95]:
res = prob.solve()

iter   objective    pri res    dua res    rho        time
   1  -3.0971e+04   4.00e+01   9.60e+03   1.00e-01   1.22e-04s
  25  -1.0000e+30   1.23e-02   1.01e+00   1.00e-01   5.41e-04s

status:               dual infeasible
number of iterations: 25
run time:             7.19e-04s
optimal rho estimate: 3.77e-03



In [96]:
n_dim = 4
n_eq, n_ineq = 1, 1
n_s = 2* n_eq + n_g
new_B = jnp.block([
    B, jnp.zeros((n_dim, ))
])

array([None, None, None, None], dtype=object)