-
Notifications
You must be signed in to change notification settings - Fork 62
Closed
Description
Code
# Enable float64 support in JAX (it defaults to 32-bit); this must run at
# startup, before any JAX arrays are created below.
# `from jax import config` replaces the deprecated `jax.config` submodule import
# and keeps the same `config` name.
from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import cyipopt
def basic_function(x: jnp.ndarray, data: jnp.ndarray):
    """Return the mean of the affine map ``x[0] * data + x[1]``.

    Only ``x[0]`` (scale) and ``x[1]`` (offset) are used. ``data`` is an
    extra positional argument, so any optimizer driving this objective must
    forward it to the function and to its derivatives (jac and hess).
    """
    return (x[0] * data + x[1]).mean()
# Reproducer: the objective's extra argument `data` is supplied through
# `args`. cyipopt forwards `args` to `fun` and `jac`, but (per the traceback
# below) its Hessian callback calls `hess(x)` without them, so this fails with
# "basic_function() missing 1 required positional argument: 'data'".
cyipopt.minimize_ipopt(
    basic_function,
    jnp.array([1.0, 2.0]),
    args=(jnp.array([1.0, 2, 5, 1, 5, 2, 8, 2, 0, 4, 9]),),
    jac=jax.jacfwd(basic_function),
    hess=jax.hessian(basic_function),
)
# Error
# Message from JAX trimmed
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/cyipopt/scipy_interface.py", line 314, in minimize_ipopt
x, info = nlp.solve(_x0)
File "cyipopt/cython/ipopt_wrapper.pyx", line 642, in ipopt_wrapper.Problem.solve
File "cyipopt/cython/ipopt_wrapper.pyx", line 895, in ipopt_wrapper.hessian_cb
File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/cyipopt/scipy_interface.py", line 164, in hessian
H = obj_factor * self.obj_hess(x) # type: ignore
File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/jax/_src/api.py", line 1286, in jacfun
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
TypeError: basic_function() missing 1 required positional argument: 'data'
I think the problem is in this code (`cyipopt/cyipopt/scipy_interface.py`, lines 163–164 at commit 8c68027):

```python
def hessian(self, x, lagrange, obj_factor):
    H = obj_factor * self.obj_hess(x) # type: ignore
```
`self.obj_hess` should be called with `self.args` and `self.kwargs`, just as `self.fun` and `self.jac` are (`cyipopt/cyipopt/scipy_interface.py`, line 149 at commit 8c68027):

```python
return self.jac(x, *self.args, **self.kwargs) # .T
```
- cyipopt 1.2.0
Metadata
Metadata
Assignees
Labels
No labels