Skip to content

Commit

Permalink
Backtracking linesearch.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605625086
  • Loading branch information
vroulet authored and OptaxDev committed Feb 14, 2024
1 parent 3358c34 commit ec46301
Show file tree
Hide file tree
Showing 7 changed files with 812 additions and 22 deletions.
6 changes: 6 additions & 0 deletions docs/api/transformations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Transformations
ScaleByAdamState
scale_by_amsgrad
ScaleByAmsgradState
scale_by_backtracking_linesearch
ScaleByBacktrackingLinesearchState
scale_by_belief
ScaleByBeliefState
scale_by_factored_rms
Expand Down Expand Up @@ -177,6 +179,10 @@ Transformations and states
.. autoclass:: ScaleByAmsgradState
:members:

.. autofunction:: scale_by_backtracking_linesearch
.. autoclass:: ScaleByBacktrackingLinesearchState
:members:

.. autofunction:: scale_by_belief
.. autoclass:: ScaleByBeliefState
:members:
Expand Down
15 changes: 15 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,26 @@ General

.. autosummary::
scale_gradient
recycled_value_and_grad
extract_from_state
split_kwargs

Scale gradient
~~~~~~~~~~~~~~
.. autofunction:: scale_gradient

Recycle value and grad from state
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: recycled_value_and_grad

Extract from state
~~~~~~~~~~~~~~~~~~
.. autofunction:: extract_from_state

Split kwargs
~~~~~~~~~~~~~~~~~~
.. autofunction:: split_kwargs


Numerical Stability
-------------------
Expand Down
10 changes: 10 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@
from optax._src.linear_algebra import global_norm
from optax._src.linear_algebra import matrix_inverse_pth_root
from optax._src.linear_algebra import power_iteration
from optax._src.linesearch import scale_by_backtracking_linesearch
from optax._src.linesearch import ScaleByBacktrackingLinesearchState
from optax._src.lookahead import lookahead
from optax._src.lookahead import LookaheadParams
from optax._src.lookahead import LookaheadState
Expand Down Expand Up @@ -144,8 +146,11 @@
from optax._src.update import apply_updates
from optax._src.update import incremental_update
from optax._src.update import periodic_update
from optax._src.utils import extract_from_state
from optax._src.utils import multi_normal
from optax._src.utils import recycled_value_and_grad
from optax._src.utils import scale_gradient
from optax._src.utils import split_kwargs
from optax._src.wrappers import apply_if_finite
from optax._src.wrappers import ApplyIfFiniteState
from optax._src.wrappers import flatten
Expand Down Expand Up @@ -255,6 +260,7 @@
"EmaState",
"EmptyState",
"exponential_decay",
"extract_from_state",
"FactoredState",
"flatten",
"fromage",
Expand Down Expand Up @@ -306,6 +312,7 @@
"polynomial_schedule",
"power_iteration",
"radam",
"recycled_value_and_grad",
"rmsprop",
"rprop",
"safe_int32_increment",
Expand All @@ -316,6 +323,7 @@
"scale_by_adam",
"scale_by_adamax",
"scale_by_amsgrad",
"scale_by_backtracking_linesearch",
"scale_by_belief",
"scale_by_lion",
"scale_by_factored_rms",
Expand All @@ -336,6 +344,7 @@
"ScaleByAdaDeltaState",
"ScaleByAdamState",
"ScaleByAmsgradState",
"ScaleByBacktrackingLinesearchState",
"ScaleByBeliefState",
"ScaleByLionState",
"ScaleByNovogradState",
Expand All @@ -359,6 +368,7 @@
"smooth_labels",
"softmax_cross_entropy",
"softmax_cross_entropy_with_integer_labels",
"split_kwargs",
"stateless",
"stateless_with_tree_map",
"trace",
Expand Down

0 comments on commit ec46301

Please sign in to comment.