From 52a4cf89503f32828de6e1513d708762e8db4558 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Tue, 4 Jul 2023 10:39:24 +0100 Subject: [PATCH 1/2] set up correct signature, add _grad_fun --- jaxopt/_src/lbfgsb.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/jaxopt/_src/lbfgsb.py b/jaxopt/_src/lbfgsb.py index 316093d4..1029be74 100644 --- a/jaxopt/_src/lbfgsb.py +++ b/jaxopt/_src/lbfgsb.py @@ -21,6 +21,7 @@ # [2] J. Nocedal and S. Wright. Numerical Optimization, second edition. import dataclasses +import inspect import warnings from typing import Any, Callable, NamedTuple, Optional, Union @@ -558,6 +559,9 @@ def _value_and_grad_fun(self, params, *args, **kwargs): params = params.params (value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs) return value, grad + + def _grad_fun(self, params, *args, **kwargs): + return self._value_and_grad_fun(params, *args, **kwargs)[1] def __post_init__(self): _, _, self._value_and_grad_with_aux = base._make_funs_with_aux( @@ -565,8 +569,16 @@ def __post_init__(self): value_and_grad=self.value_and_grad, has_aux=self.has_aux, ) + + # Sets up reference signature. + fun = getattr(self.fun, "subfun", self.fun) + signature = inspect.signature(fun) + parameters = list(signature.parameters.values()) + new_param = inspect.Parameter(name="bounds", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD) + parameters.insert(1, new_param) + self.reference_signature = inspect.Signature(parameters) - self.reference_signature = self.fun jit, unroll = self._get_loop_options() linesearch_solver = _setup_linesearch( From a5aada2815b954b0e933bb9630c9ec05f6e4e9ab Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Tue, 4 Jul 2023 10:45:07 +0100 Subject: [PATCH 2/2] add test --- tests/lbfgsb_test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/lbfgsb_test.py b/tests/lbfgsb_test.py index 1e68a1ab..75df578c 100644 --- a/tests/lbfgsb_test.py +++ b/tests/lbfgsb_test.py @@ -248,6 +248,21 @@ def fun(x): self.assertEqual(N_CALLS, n_iter + 1) + def test_grad_with_bounds(self): + # Test that the gradient is correct when bounds are specified by keyword. + # Pertinent to issue #463. + def pipeline(x, init_pars, bounds, data): + def fit_objective(pars, data, x): + return -jax.scipy.stats.norm.logpdf(pars, loc=data*x, scale=1.0) + solver = LBFGSB(fun=fit_objective, implicit_diff=True, maxiter=500, tol=1e-6) + return solver.run(init_pars, bounds=bounds, data=data, x=x)[0] + + grad_fn = jax.grad(pipeline) + data = jnp.array(1.5) + res = grad_fn(0.5, jnp.array(0.0), (jnp.array(0.0), jnp.array(10.0)), data) + self.assertEqual(res, data) + + if __name__ == "__main__": absltest.main()