In [55]:
import nlopt
import numpy as np
import jax.numpy as jnp
import jax

In [66]:
# Starting point; the problem set up below appears to be Hock–Schittkowski #71.
x0 = jnp.array([1.0, 5.0, 5.0, 1.0])
# Outer augmented-Lagrangian optimizer over the 4 decision variables (x0.size == 4).
opt = nlopt.opt(nlopt.LD_AUGLAG, x0.size)

In [67]:
# Objective
def objective(x):
    """Return x[0] * x[3] * (x[0] + x[1] + x[2]) + x[2].

    Written with ``jnp`` so it is traceable by ``jax.value_and_grad`` and
    ``jax.jit`` (``np.sum`` on a tracer only works by accident, via NumPy
    delegating to the tracer's ``.sum`` method).
    """
    return x[0] * x[3] * jnp.sum(x[:3]) + x[2]


# Without jit, every call re-runs `objective` op by op; nlopt calls `f` many
# times with the same input shape, so compile the value+gradient once.
objective_val_grad = jax.jit(jax.value_and_grad(objective))


def f(x, grad):
    """nlopt objective callback.

    Args:
        x: current point supplied by nlopt.
        grad: output buffer; when ``grad.size > 0`` nlopt wants the gradient,
            which is written into it in place.

    Returns:
        The objective value at ``x`` as a Python float.
    """
    val, grads = objective_val_grad(x)
    if grad.size > 0:
        grad[:] = grads
    return val.item()


opt.set_min_objective(f)

In [68]:
# Box constraints: 1 <= x_i <= 5 for each of the 4 variables.
lb = np.full(4, 1.0)
opt.set_lower_bounds(lb)
ub = np.full(4, 5.0)
opt.set_upper_bounds(ub)

In [69]:
# Constraints
def ceq_fn(x):
    """Equality constraint h(x) = sum(x**2) - 40 (nlopt drives h(x) to 0)."""
    return jnp.sum(x**2) - 40


def cineq_fn(x):
    """Inequality constraint g(x) = 25 - prod(x) (nlopt enforces g(x) <= 0)."""
    return 25 - jnp.prod(x)


# Ahead-of-time compile value+gradient for the shape/dtype of x0.
ceq_fn_val_grad = jax.jit(jax.value_and_grad(ceq_fn)).lower(x0).compile()
cineq_fn_val_grad = jax.jit(jax.value_and_grad(cineq_fn)).lower(x0).compile()


def _make_mconstraint(val_grad_fn):
    """Wrap a (value, gradient) function as an nlopt mconstraint callback.

    The returned callback has nlopt's vector-constraint signature
    ``(result, x, grad)``: it writes the constraint value into ``result`` and,
    when a gradient is requested (``grad.size > 0``), the gradient into
    ``grad``. The scalar constraint's length-4 gradient broadcasts into a
    ``grad`` of shape (4,) or (1, 4).
    """
    def constraint(result, x, grad):
        val, grads = val_grad_fn(x)
        if grad.size > 0:
            grad[:] = grads
        result[:] = val
    return constraint


ceq = _make_mconstraint(ceq_fn_val_grad)
cineq = _make_mconstraint(cineq_fn_val_grad)

# One scalar constraint each, with a feasibility tolerance of 0.01.
opt.add_inequality_mconstraint(cineq, [0.01])
opt.add_equality_mconstraint(ceq, [0.01])


In [70]:
xtol = 1e-4

# Local SLSQP solver used by AUGLAG for its subproblems, at the tighter tolerance.
# Configure it before handing it to `opt`.
opt2 = nlopt.opt(nlopt.LD_SLSQP, 4)
opt2.set_xtol_rel(xtol)
opt.set_local_optimizer(opt2)

# Outer AUGLAG loop stops at a 10x looser relative x tolerance.
opt.set_xtol_rel(xtol * 10)

In [88]:
%time
xopt = opt.optimize(xopt)

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 8.34 µs


In [89]:
# Optimum returned by LD_AUGLAG (with LD_SLSQP as the local optimizer).
xopt

array([1.        , 4.74322359, 3.82087298, 1.3794318 ])

In [119]:
# NOTE(review): duplicate of the earlier `xopt` display cell. The values differ
# slightly, presumably because the optimize cell was re-run warm-started from
# the previous xopt. Keep only one display cell.
xopt

array([1.        , 4.74319474, 3.82094769, 1.37939895])

In [87]:
# Spot-check the inequality-constraint wrapper at x0 = [1, 5, 5, 1]:
# 25 - prod(x0) = 0, and d/dx_i (25 - prod(x)) = -prod(x) / x_i for nonzero x_i.
result = np.zeros(1)
grad = np.zeros(4)
cineq(result, x0, grad)
np.testing.assert_allclose(result, [0.0])
np.testing.assert_allclose(grad, [-25.0, -5.0, -5.0, -25.0])
grad

[-25.  -5.  -5. -25.]


In [81]:
# Evaluate the objective callback at x0 with a fresh gradient buffer, so this
# cell does not depend on whatever `grad` another cell left behind.
# f returns the objective value and overwrites `grad` in place.
grad = np.zeros(4)
f(x0, grad)

16.0

In [82]:
# Gradient written into `grad` by the preceding `f(x0, grad)` call.
grad

array([12.,  1.,  2., 11.])

In [51]:
# Manually evaluate the equality constraint and its gradient at x0 (not jitted).
val, grads = jax.value_and_grad(ceq_fn)(x0)

In [54]:
# Mirror what the `ceq` callback does: copy the scalar constraint value into a
# length-1 buffer. Allocate it here so the cell neither depends on nor reuses
# a `result` left over from another cell.
result = np.zeros(1)
result[:] = val

In [57]:
# ceq_fn(x0) = (1 + 25 + 25 + 1) - 40 = 12.
result

array([12.])

In [48]:
# Call the inequality wrapper with a size-0 gradient buffer, exercising the
# value-only path (its `grad.size > 0` check skips the gradient write).
result = np.zeros((1,))
cineq(result, x0, np.empty(0))

In [50]:
# Confirm the constraint result buffer is a length-1 array.
result.shape

(1,)

In [14]:
# Fresh float64 gradient buffer for the 4 decision variables.
grad = np.zeros(shape=(4,), dtype=np.float64)

In [32]:
# NOTE(review): repeats the earlier f(x0, grad) evaluation; f returns the
# objective value and writes the gradient into `grad` in place.
f(x0, grad)

16.0

In [33]:
# Gradient written into `grad` by the preceding `f(x0, grad)` call.
grad

array([12.,  1.,  2., 11.])

In [21]:
# Check the dtype of the gradient buffer (np.zeros allocates float64 by default).
grad.dtype

dtype('float64')