LBFGS part 3: combine lbfgs and zoom linesearch
PiperOrigin-RevId: 644798494
vroulet authored and OptaxDev committed Jun 19, 2024
1 parent 524f1cb commit 99967d2
Showing 4 changed files with 78 additions and 73 deletions.
14 changes: 6 additions & 8 deletions optax/_src/alias.py
@@ -1895,9 +1895,7 @@ def lbfgs(
    scale_init_precond: bool = True,
    linesearch: Optional[
        base.GradientTransformationExtraArgs
-    ] = _linesearch.scale_by_backtracking_linesearch(
-        max_backtracking_steps=30, store_grad=True
-    ),
+    ] = _linesearch.scale_by_zoom_linesearch(max_linesearch_steps=15),
) -> base.GradientTransformationExtraArgs:
  r"""L-BFGS optimizer.
@@ -1967,9 +1965,9 @@ def lbfgs(
    ...     )
    ...     params = optax.apply_updates(params, updates)
    ...     print('Objective function: ', f(params))
-    Objective function:  5.040001
-    Objective function:  7.460699e-14
-    Objective function:  1.0602291e-27
+    Objective function:  0.0
+    Objective function:  0.0
+    Objective function:  0.0
    Objective function:  0.0
    Objective function:  0.0
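For reference, a minimal sketch of the full docstring example that produces this output, reconstructed from the visible fragment (the objective f, the initial params, and the loop bounds are assumptions, not quoted from this diff):

import jax.numpy as jnp
import optax

def f(x):
  return jnp.sum(x ** 2)  # simple quadratic objective; minimum 0.0 at the origin

solver = optax.lbfgs()
params = jnp.array([1., 2., 3.])
opt_state = solver.init(params)
# value_and_grad_from_state recycles the objective value and gradient
# already computed inside the linesearch, saving function evaluations.
value_and_grad = optax.value_and_grad_from_state(f)
for _ in range(5):
  value, grad = value_and_grad(params, state=opt_state)
  updates, opt_state = solver.update(
      grad, opt_state, params, value=value, grad=grad, value_fn=f
  )
  params = optax.apply_updates(params, updates)
  print('Objective function: ', f(params))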
@@ -1979,7 +1977,7 @@ def lbfgs(
  .. warning::
    This optimizer works best with a linesearch (current default is a
-    backtracking linesearch). See example above for best use in a non-stochastic
+    zoom linesearch). See example above for best use in a non-stochastic
    setting, where we can recycle gradients computed by the linesearch using
    :func:`optax.value_and_grad_from_state`.
@@ -2002,7 +2000,7 @@ def lbfgs(
    scale_init_precond: whether to use a scaled identity as the initial
      preconditioner, see formula above.
    linesearch: an instance of :class:`optax.GradientTransformationExtraArgs`
-      such as :func:`optax.scale_by_backtracking_linesearch` that computes a
+      such as :func:`optax.scale_by_zoom_linesearch` that computes a
      learning rate, a.k.a. stepsize, to satisfy some criterion such as a
      sufficient decrease of the objective by additional calls to the objective.
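In short, optax.lbfgs() now defaults to the zoom linesearch. A minimal sketch of how a user could opt back into the previous backtracking behavior (parameter values copied from the removed default; the variable names are illustrative):

import optax

# New default: zoom linesearch capped at 15 linesearch steps.
opt = optax.lbfgs()

# Previous default, restored explicitly through the linesearch argument.
opt_backtracking = optax.lbfgs(
    linesearch=optax.scale_by_backtracking_linesearch(
        max_backtracking_steps=30, store_grad=True
    )
)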
52 changes: 30 additions & 22 deletions optax/_src/alias_test.py
@@ -29,6 +29,7 @@

from optax._src import alias
from optax._src import base
+from optax._src import linesearch as _linesearch
from optax._src import numerics
from optax._src import transform
from optax._src import update
@@ -632,7 +633,7 @@ def fun(x):
    sol_arr, _ = _run_lbfgs_solver(opt, fun, init_array, maxiter=3)
    sol_tree, _ = _run_lbfgs_solver(opt, fun, init_tree, maxiter=3)
    sol_tree = jnp.stack((sol_tree[0], sol_tree[1]))
-    chex.assert_trees_all_close(sol_arr, sol_tree)
+    chex.assert_trees_all_close(sol_arr, sol_tree, rtol=5*1e-5, atol=5*1e-5)

@parameterized.product(scale_init_precond=[True, False])
def test_multiclass_logreg(self, scale_init_precond):
@@ -695,15 +696,13 @@ def fun(weights):
)
chex.assert_trees_all_close(sol, sol_skl, atol=5e-2)

-  # TODO(vroulet): test eggholder and zakharov with zoom linesearch
-  # once implemented.
  @parameterized.product(
      problem_name=[
          'rosenbrock',
          'himmelblau',
          'matyas',
-          # 'eggholder',
-          # 'zakharov',
+          'eggholder',
+          'zakharov',
      ],
  )
  def test_against_scipy(self, problem_name: str):
def test_against_scipy(self, problem_name: str):
Expand All @@ -714,29 +713,38 @@ def test_against_scipy(self, problem_name: str):
init_params = problem['init']
jnp_fun, np_fun = problem['fun'], problem['numpy_fun']

-    opt = alias.lbfgs()
+    if problem_name == 'zakharov':
+      opt = alias.lbfgs(
+          linesearch=_linesearch.scale_by_zoom_linesearch(
+              max_linesearch_steps=30
+          )
+      )
+    else:
+      opt = alias.lbfgs()
optax_sol, _ = _run_lbfgs_solver(
opt, jnp_fun, init_params, maxiter=500, tol=tol
)
scipy_sol = scipy_optimize.minimize(np_fun, init_params, method='BFGS').x

-    # 1. Check minimizer obtained against known minimizer or scipy minimizer
-    if problem_name not in ['matyas', 'zakharov']:
-      chex.assert_trees_all_close(
-          optax_sol, problem['minimizer'], atol=tol, rtol=tol
-      )
-    else:
-      chex.assert_trees_all_close(optax_sol, scipy_sol, atol=tol, rtol=tol)
-
-    # 2. Check if minimum is reached or equal to scipy's found value
-    if problem_name == 'eggholder':
-      chex.assert_trees_all_close(
-          jnp_fun(optax_sol), np_fun(scipy_sol), atol=tol, rtol=tol
-      )
-    else:
-      chex.assert_trees_all_close(
-          jnp_fun(optax_sol), problem['minimum'], atol=tol, rtol=tol
-      )
+    with self.subTest('Check minimizer'):
+      if problem_name in ['matyas', 'zakharov']:
+        chex.assert_trees_all_close(
+            optax_sol, problem['minimizer'], atol=tol, rtol=tol
+        )
+      else:
+        chex.assert_trees_all_close(optax_sol, scipy_sol, atol=tol, rtol=tol)
+
+    with self.subTest('Check minimum'):
+      # 2. Check if minimum is reached or equal to scipy's found value
+      if problem_name == 'eggholder':
+        chex.assert_trees_all_close(
+            jnp_fun(optax_sol), np_fun(scipy_sol), atol=tol, rtol=tol
+        )
+      else:
+        chex.assert_trees_all_close(
+            jnp_fun(optax_sol), problem['minimum'], atol=tol, rtol=tol
+        )

def test_minimize_bad_initialization(self):
# This test runs deliberately "bad" initial values to test that handling
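The tests above call a private helper _run_lbfgs_solver that is not shown in this diff. A hypothetical reconstruction of such a loop, for readers who want to reproduce the tests outside the test file (the helper name, stopping rule, and return values are assumptions, not the repository's actual code):

import optax
import optax.tree_utils as otu

def run_lbfgs_solver(opt, fun, init_params, maxiter, tol=1e-3):
  # Runs `opt` on `fun` until the gradient norm falls below `tol`
  # or `maxiter` updates have been applied.
  value_and_grad = optax.value_and_grad_from_state(fun)
  params, state = init_params, opt.init(init_params)
  for _ in range(maxiter):
    value, grad = value_and_grad(params, state=state)
    if otu.tree_l2_norm(grad) < tol:
      break
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=fun
    )
    params = optax.apply_updates(params, updates)
  return params, state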
83 changes: 41 additions & 42 deletions optax/_src/linesearch.py
@@ -375,36 +375,6 @@ def body_fn(
return base.GradientTransformationExtraArgs(init_fn, update_fn)


-# Flags to print errors, used for debugging, tested
-WARNING_PREAMBLE = "WARNING: jaxopt.ZoomLineSearch: "
-FLAG_NAN_INF_VALUES = (
-    WARNING_PREAMBLE + "NaN or Inf values encountered in function values."
-)
-FLAG_INTERVAL_NOT_FOUND = (
-    WARNING_PREAMBLE
-    + "No interval satisfying curvature condition."
-    "Consider increasing maximal possible stepsize of the linesearch."
-)
-FLAG_INTERVAL_TOO_SMALL = (
-    WARNING_PREAMBLE
-    + "Length of searched interval has been reduced below threshold."
-)
-FLAG_CURVATURE_COND_NOT_SATSIFIED = (
-    WARNING_PREAMBLE
-    + "Returning stepsize with sufficient decrease "
-    "but curvature condition not satisfied."
-)
-FLAG_NO_STEPSIZE_FOUND = (
-    WARNING_PREAMBLE
-    + "Linesearch failed, no stepsize satisfying sufficient decrease found."
-)
-FLAG_NOT_A_DESCENT_DIRECTION = (
-    WARNING_PREAMBLE
-    + "The linesearch failed because the provided direction "
-    "is not a descent direction. "
-)
-
-
def _cond_print(condition, message, **kwargs):
"""Prints message if condition is true."""
jax.lax.cond(
@@ -1239,7 +1209,7 @@ class ZoomLinesearchInfo(NamedTuple):
A positive value in the sufficient curvature error is more problematic as it
means that the algorithm may not be guaranteed to produce monotonically
decreasing values.
-  Consider using `verbose=True` in :func:`scale_by_zoom_linesearch` for
+  Consider using ``verbose=True`` in :func:`scale_by_zoom_linesearch` for
Consider using ``verbose=True`` in :func:`scale_by_zoom_linesearch` for
additional failure diagnostics if the linesearch fails.
Attributes:
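A minimal sketch of enabling these diagnostics (both parameters appear elsewhere in this commit; the wiring into lbfgs is illustrative):

import optax

# Print zoom-linesearch failure diagnostics while optimizing.
linesearch = optax.scale_by_zoom_linesearch(
    max_linesearch_steps=15, verbose=True
)
opt = optax.lbfgs(linesearch=linesearch)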
@@ -1259,7 +1229,6 @@ class ZoomLinesearchInfo(NamedTuple):
class ScaleByZoomLinesearchState(NamedTuple):
"""State for scale_by_zoom_linesearch.
-  Attributes:
Attributes:
learning_rate: learning rate computed at the end of a round of line-search,
used to scale the update.
@@ -1306,16 +1275,16 @@ def scale_by_zoom_linesearch(
where
-    - :math:`f` is the function to minimize,
-    - :math:`w` are the current parameters,
-    - :math:`\eta` is the learning rate to find,
-    - :math:`u` is the update direction,
-    - :math:`c_1` is a coefficient (``slope_rtol``) measuring the relative
-      decrease of the function in terms of the slope (scalar product between
-      the gradient and the updates),
-    - :math:`c_2` is a coefficient (``curv_rtol``) measuring the relative
-      decrease of curvature.
-    - :math:`\epsilon` is an absolute tolerance (``tol``).
+  - :math:`f` is the function to minimize,
+  - :math:`w` are the current parameters,
+  - :math:`\eta` is the learning rate to find,
+  - :math:`u` is the update direction,
+  - :math:`c_1` is a coefficient (``slope_rtol``) measuring the relative
+    decrease of the function in terms of the slope (scalar product between
+    the gradient and the updates),
+  - :math:`c_2` is a coefficient (``curv_rtol``) measuring the relative
+    decrease of curvature.
+  - :math:`\epsilon` is an absolute tolerance (``tol``).
To deal with very flat functions, this linesearch switches from the sufficient
decrease criterion presented above to an approximate sufficient decrease
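To make these criteria concrete, a small sketch of the two checks the docstring describes (standard strong-Wolfe conditions; the function below and its default coefficients are illustrative, mirroring the roles of slope_rtol and curv_rtol):

import jax.numpy as jnp

def wolfe_conditions_hold(f0, slope0, f_eta, slope_eta, eta,
                          c1=1e-4, c2=0.9, tol=0.0):
  # Sufficient decrease: f(w + eta*u) <= f(w) + c1 * eta * <grad f(w), u> + tol.
  sufficient_decrease = f_eta <= f0 + c1 * eta * slope0 + tol
  # Curvature: |<grad f(w + eta*u), u>| <= c2 * |<grad f(w), u>|.
  small_curvature = jnp.abs(slope_eta) <= c2 * jnp.abs(slope0)
  return sufficient_decrease & small_curvature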
@@ -1552,3 +1521,33 @@ def update_fn(
)

return base.GradientTransformationExtraArgs(init_fn, update_fn)


+# Flags to print errors, used for debugging, tested
+WARNING_PREAMBLE = "WARNING: jaxopt.ZoomLineSearch: "
+FLAG_NAN_INF_VALUES = (
+    WARNING_PREAMBLE + "NaN or Inf values encountered in function values."
+)
+FLAG_INTERVAL_NOT_FOUND = (
+    WARNING_PREAMBLE
+    + "No interval satisfying curvature condition."
+    "Consider increasing maximal possible stepsize of the linesearch."
+)
+FLAG_INTERVAL_TOO_SMALL = (
+    WARNING_PREAMBLE
+    + "Length of searched interval has been reduced below threshold."
+)
+FLAG_CURVATURE_COND_NOT_SATSIFIED = (
+    WARNING_PREAMBLE
+    + "Returning stepsize with sufficient decrease "
+    "but curvature condition not satisfied."
+)
+FLAG_NO_STEPSIZE_FOUND = (
+    WARNING_PREAMBLE
+    + "Linesearch failed, no stepsize satisfying sufficient decrease found."
+)
+FLAG_NOT_A_DESCENT_DIRECTION = (
+    WARNING_PREAMBLE
+    + "The linesearch failed because the provided direction "
+    "is not a descent direction. "
+)
2 changes: 1 addition & 1 deletion optax/_src/transform.py
@@ -1327,7 +1327,7 @@ def left_product(vec, idx_alpha):

def scale_by_lbfgs(
memory_size: int = 10,
-    scale_init_precond: bool = False,
+    scale_init_precond: bool = True,
) -> base.GradientTransformation:
r"""Scales updates by L-BFGS.
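For reference, scale_init_precond=True refers to the classical L-BFGS initialization of the inverse-Hessian approximation as a scaled identity gamma * I (Nocedal & Wright, eq. 7.20). A minimal sketch of that scaling under standard conventions; the helper is illustrative, not the library's code:

import jax.numpy as jnp

def initial_precond_scale(s, y):
  # s: parameter difference w_{k+1} - w_k.
  # y: gradient difference grad f(w_{k+1}) - grad f(w_k).
  # gamma = <s, y> / <y, y> approximates the curvature along the last
  # step and scales the identity that seeds the two-loop recursion.
  return jnp.vdot(s, y) / jnp.vdot(y, y)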
