Skip to content

Commit

Permalink
Backtracking linesearch.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605625086
  • Loading branch information
vroulet authored and OptaxDev committed Feb 19, 2024
1 parent f4dd313 commit f7e4e08
Show file tree
Hide file tree
Showing 8 changed files with 1,000 additions and 36 deletions.
5 changes: 5 additions & 0 deletions docs/api/transformations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Transformations
ScaleByAdamState
scale_by_amsgrad
ScaleByAmsgradState
scale_by_backtracking_linesearch
ScaleByBacktrackingLinesearchState
scale_by_belief
ScaleByBeliefState
scale_by_factored_rms
Expand Down Expand Up @@ -177,6 +179,9 @@ Transformations and states
.. autoclass:: ScaleByAmsgradState
:members:

.. autofunction:: scale_by_backtracking_linesearch
.. autoclass:: ScaleByBacktrackingLinesearchState

.. autofunction:: scale_by_belief
.. autoclass:: ScaleByBeliefState
:members:
Expand Down
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@ General

.. autosummary::
scale_gradient
value_and_grad_from_state

Scale gradient
~~~~~~~~~~~~~~
.. autofunction:: scale_gradient

Value and grad from state
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: value_and_grad_from_state


Numerical Stability
-------------------
Expand Down
6 changes: 6 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@
from optax._src.linear_algebra import global_norm
from optax._src.linear_algebra import matrix_inverse_pth_root
from optax._src.linear_algebra import power_iteration
from optax._src.linesearch import scale_by_backtracking_linesearch
from optax._src.linesearch import ScaleByBacktrackingLinesearchState
from optax._src.lookahead import lookahead
from optax._src.lookahead import LookaheadParams
from optax._src.lookahead import LookaheadState
Expand Down Expand Up @@ -146,6 +148,7 @@
from optax._src.update import periodic_update
from optax._src.utils import multi_normal
from optax._src.utils import scale_gradient
from optax._src.utils import value_and_grad_from_state
from optax._src.wrappers import apply_if_finite
from optax._src.wrappers import ApplyIfFiniteState
from optax._src.wrappers import flatten
Expand Down Expand Up @@ -316,6 +319,7 @@
"scale_by_adam",
"scale_by_adamax",
"scale_by_amsgrad",
"scale_by_backtracking_linesearch",
"scale_by_belief",
"scale_by_lion",
"scale_by_factored_rms",
Expand All @@ -336,6 +340,7 @@
"ScaleByAdaDeltaState",
"ScaleByAdamState",
"ScaleByAmsgradState",
"ScaleByBacktrackingLinesearchState",
"ScaleByBeliefState",
"ScaleByLionState",
"ScaleByNovogradState",
Expand Down Expand Up @@ -367,6 +372,7 @@
"TransformUpdateFn",
"TransformUpdateExtraArgsFn",
"Updates",
"value_and_grad_from_state",
"warmup_cosine_decay_schedule",
"warmup_exponential_decay_schedule",
"yogi",
Expand Down

0 comments on commit f7e4e08

Please sign in to comment.