In [1]:
from jax.config import config
# Enable 64 bit floating point precision
config.update("jax_enable_x64", True)

# We use the CPU instead of GPU und mute all warnings if no GPU/TPU is found.
config.update('jax_platform_name', 'cpu')

import jax.numpy as np
from jax import jit, grad, jacfwd, jacrev

from cyipopt import minimize_ipopt

In [127]:
def objective(x):
    return x[0]**2 + x[1]**2

def eq_constraints(x):
    return np.array([x[0] + 10])

def test(x):
    return -x

def ineq_constraints(x):
    return np.append(np.array([ test(x[0]) -1,
                    test(x[1]*x[1])+6 ]),np.zeros(6))


In [128]:
# jit the functions
obj_jit = jit(objective)
con_eq_jit = jit(eq_constraints)
con_ineq_jit = jit(ineq_constraints)

# build the derivatives and jit them
obj_grad = jit(grad(obj_jit))  # objective gradient
obj_hess = jit(jacrev(jacfwd(obj_jit))) # objective hessian
con_eq_jac = jit(jacfwd(con_eq_jit))  # jacobian
con_ineq_jac = jit(jacfwd(con_ineq_jit))  # jacobian
con_eq_hess = jacrev(jacfwd(con_eq_jit)) # hessian
con_eq_hessvp = jit(lambda x, v: con_eq_hess(x) * v[0]) # hessian vector-product
con_ineq_hess = jacrev(jacfwd(con_ineq_jit))  # hessian
con_ineq_hessvp = jit(lambda x, v: con_ineq_hess(x) * v[0]) # hessian vector-product

In [129]:
# constraints
cons = [
    {'type': 'eq', 'fun': con_eq_jit, 'jac': con_eq_jac, 'hess': con_eq_hessvp},
    {'type': 'ineq', 'fun': con_ineq_jit, 'jac': con_ineq_jac, 'hess': con_ineq_hessvp}
 ]

# starting point
x0 = np.array([0.0, -1.0])

# variable bounds: 1 <= x[i] <= 5
bnds = [(-100, 100) for _ in range(x0.size)]

# executing the solver
res = minimize_ipopt(obj_jit, jac=obj_grad, hess=obj_hess, x0=x0, bounds=bnds,
                  constraints=cons, options={'disp': 5})

This is Ipopt version 3.14.11, running with linear solver ma27.

Number of nonzeros in equality constraint Jacobian...:        2
Number of nonzeros in inequality constraint Jacobian.:        4
Number of nonzeros in Lagrangian Hessian.............:        3

Total number of variables............................:        2
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        2
                     variables with only upper bounds:        0
Total number of equality constraints.................:        1
Total number of inequality constraints...............:        2
        inequality constraints with only lower bounds:        2
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  1.0000000e+00 1.00e+01 1.60e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00   0
   

In [130]:
print(res)

     fun: 100.0
    info: {'x': array([-1.00000000e+01, -7.00291389e-24]), 'g': array([0., 9., 6.]), 'obj_val': 100.0, 'mult_g': array([ 2.00000000e+01, -1.11101111e-12, -1.66656696e-12]), 'mult_x_L': array([1.43679923e-13, 1.31065611e-13]), 'mult_x_U': array([1.18271586e-13, 1.27697816e-13]), 'status': 0, 'status_msg': b'Algorithm terminated successfully at a locally optimal point, satisfying the convergence tolerances (can be specified by options).'}
 message: b'Algorithm terminated successfully at a locally optimal point, satisfying the convergence tolerances (can be specified by options).'
    nfev: 7
     nit: 6
    njev: 8
  status: 0
 success: True
       x: array([-1.00000000e+01, -7.00291389e-24])


In [131]:
res.x[0]

-10.0

In [132]:
res.x[1]

-7.002913893151417e-24

In [101]:
aa = np.array([1,2,3,4])

In [102]:
aa

Array([1, 2, 3, 4], dtype=int64)

In [103]:
aa[0:2] = np.array([6,7])

TypeError: '<class 'jaxlib.xla_extension.Array'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [105]:
np.append(aa,np.array([5]))

Array([1, 2, 3, 4, 5], dtype=int64)

In [106]:
aa = np.append(aa,np.array([5]))

In [107]:
aa

Array([1, 2, 3, 4, 5], dtype=int64)

In [108]:
aa = np.append(aa,np.array([5]))

In [109]:
aa

Array([1, 2, 3, 4, 5, 5], dtype=int64)

In [111]:
aa[0] = 8

TypeError: '<class 'jaxlib.xla_extension.Array'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [114]:
np.append(aa,np.array([5,8]))

Array([1, 2, 3, 4, 5, 5, 5, 8], dtype=int64)

In [115]:
a = np.array([5])

In [116]:
a = a + 4

In [117]:
a

Array([9], dtype=int64)