
Objective function not passed extra arguments when Hessian is provided #175

@ForceBru

Description


Code

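import jax
import jax.numpy as jnp
import cyipopt

# Enable 64-bit floats in JAX (Ipopt works in double precision).
# jax.config.update replaces the older `from jax.config import config` import.
jax.config.update("jax_enable_x64", True)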

def basic_function(x: jnp.ndarray, data: jnp.ndarray):
    return (x[0] * data + x[1]).mean()

cyipopt.minimize_ipopt(
    basic_function,
    jnp.array([1.0, 2.0]),  # x0
    args=(jnp.array([1.0, 2, 5, 1, 5, 2, 8, 2, 0, 4, 9]),),  # forwarded as `data`
    jac=jax.jacfwd(basic_function),
    hess=jax.hessian(basic_function),
)

Error

# Message from JAX trimmed
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/cyipopt/scipy_interface.py", line 314, in minimize_ipopt
    x, info = nlp.solve(_x0)
  File "cyipopt/cython/ipopt_wrapper.pyx", line 642, in ipopt_wrapper.Problem.solve
  File "cyipopt/cython/ipopt_wrapper.pyx", line 895, in ipopt_wrapper.hessian_cb
  File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/cyipopt/scipy_interface.py", line 164, in hessian
    H = obj_factor * self.obj_hess(x)  # type: ignore
  File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/jax/_src/api.py", line 1286, in jacfun
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
TypeError: basic_function() missing 1 required positional argument: 'data'

I think the problem is in this method of cyipopt/scipy_interface.py:

def hessian(self, x, lagrange, obj_factor):
    H = obj_factor * self.obj_hess(x)  # type: ignore

...the call to self.obj_hess should forward self.args and self.kwargs, just as the calls to self.fun and self.jac already do:

return self.jac(x, *self.args, **self.kwargs) # .T
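A minimal sketch of the proposed change, assuming a simplified, unconstrained wrapper. HessianForwardingWrapper, scaled_sq_norm and its derivatives are hypothetical names for illustration only, not cyipopt's actual class or API; only the attribute names (fun, jac, obj_hess, args, kwargs) come from the code quoted above. The real method also has to add the constraint terms of the Lagrangian Hessian (that is what lagrange is for), which this sketch leaves out.

```python
class HessianForwardingWrapper:
    """Hypothetical, unconstrained stand-in for cyipopt's SciPy wrapper."""

    def __init__(self, fun, jac, hess, args=(), kwargs=None):
        self.fun = fun
        self.jac = jac
        self.obj_hess = hess
        self.args = args
        self.kwargs = kwargs or {}

    def objective(self, x):
        return self.fun(x, *self.args, **self.kwargs)

    def gradient(self, x):
        return self.jac(x, *self.args, **self.kwargs)

    def hessian(self, x, lagrange, obj_factor):
        # Proposed fix: forward args/kwargs to obj_hess exactly as
        # objective() and gradient() do. `lagrange` is unused here
        # because this sketch has no constraints.
        H = self.obj_hess(x, *self.args, **self.kwargs)
        return [[obj_factor * h for h in row] for row in H]


# Hypothetical objective that, like basic_function, needs an extra argument.
def scaled_sq_norm(x, scale):
    return scale * (x[0] ** 2 + x[1] ** 2)


def scaled_sq_norm_grad(x, scale):
    return [2.0 * scale * x[0], 2.0 * scale * x[1]]


def scaled_sq_norm_hess(x, scale):
    return [[2.0 * scale, 0.0], [0.0, 2.0 * scale]]


wrapper = HessianForwardingWrapper(
    scaled_sq_norm, scaled_sq_norm_grad, scaled_sq_norm_hess, args=(3.0,)
)
# Without the forwarding, this call would raise the same kind of TypeError
# as in the traceback above (missing positional argument 'scale').
print(wrapper.hessian([1.0, 1.0], [], 0.5))
```

The same forwarding covers keyword arguments, so constructing the wrapper with kwargs={"scale": 3.0} instead of args=(3.0,) gives the same Hessian.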

  • cyipopt 1.2.0
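Until the Hessian receives the extra arguments, a possible workaround (my assumption, not something cyipopt documents) is to bind data before handing the callables to the solver, so there is nothing left to forward. To stay self-contained, the sketch below uses plain Python lists and hand-derived derivatives instead of JAX; basic_gradient and basic_hessian are hypothetical stand-ins for jax.jacfwd(basic_function) and jax.hessian(basic_function).

```python
import functools


def basic_function(x, data):
    # Same objective as in the report, on plain lists: mean(x[0] * d + x[1])
    return sum(x[0] * d + x[1] for d in data) / len(data)


def basic_gradient(x, data):
    # Hand-derived gradient: [mean(data), 1]
    return [sum(data) / len(data), 1.0]


def basic_hessian(x, data):
    # The objective is linear in x, so its Hessian is the 2x2 zero matrix
    return [[0.0, 0.0], [0.0, 0.0]]


data = [1.0, 2, 5, 1, 5, 2, 8, 2, 0, 4, 9]

# Bind `data` once so every callable takes only `x`; the solver then
# never needs to pass `args` to any of them.
fun = functools.partial(basic_function, data=data)
jac = functools.partial(basic_gradient, data=data)
hess = functools.partial(basic_hessian, data=data)

print(fun([1.0, 2.0]), jac([1.0, 2.0]), hess([1.0, 2.0]))
```

With JAX the same idea would be fun = functools.partial(basic_function, data=data) followed by jac=jax.jacfwd(fun) and hess=jax.hessian(fun), passed to cyipopt.minimize_ipopt without args.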
