Document the extra args of the update function in docstring
PiperOrigin-RevId: 615443065
fabianp authored and OptaxDev committed Mar 13, 2024
1 parent 269c9dc commit 0f9ea47
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 3 additions & 1 deletion optax/_src/alias.py
@@ -1771,7 +1771,9 @@ def polyak_sgd(
     eps: a value to add in the denominator of the update (defaults to 0).
 
   Returns:
-    A :class:`GradientTransformationExtraArgs`.
+    A :class:`GradientTransformationExtraArgs`, where the ``update`` function
+    takes an additional keyword argument ``value`` containing the current
+    value of the objective function.
   """
   return combine.chain(
       sgd(learning_rate=scaling),
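As a usage sketch (not part of this commit): the documented ``value`` keyword is typically obtained from ``jax.value_and_grad`` and forwarded to ``update``. The quadratic objective and parameter values below are illustrative only.

```python
import jax
import jax.numpy as jnp
import optax


def loss_fn(params):
  # Illustrative quadratic objective.
  return jnp.sum(params ** 2)


params = jnp.array([1.0, 2.0, 3.0])
opt = optax.polyak_sgd()
opt_state = opt.init(params)

for _ in range(5):
  value, grad = jax.value_and_grad(loss_fn)(params)
  # `value` is the extra keyword argument documented in the diff above.
  updates, opt_state = opt.update(grad, opt_state, params, value=value)
  params = optax.apply_updates(params, updates)
```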
10 changes: 9 additions & 1 deletion optax/_src/linesearch.py
@@ -208,7 +208,15 @@ def scale_by_backtracking_linesearch(
       :func:`optax.value_and_grad_from_state`. See the example above.
 
   Returns:
-    The corresponding GradientTransformationExtraArgs.
+    A :class:`GradientTransformationExtraArgs`, where the ``update`` function
+    takes the following additional keyword arguments:
+
+    * value: value of the function at the current params.
+    * grad: gradient of the function at the current params.
+    * value_fn: function returning the value of the function we seek to
+      optimize.
+    * **extra_args: additional keyword arguments; if the function needs
+      additional arguments such as input data, they should be put there
+      (see the example in this docstring).
   """
 
   def init_fn(params: base.Params) -> ScaleByBacktrackingLinesearchState:
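Similarly, a sketch (not part of this commit) of how the extra ``value``, ``grad`` and ``value_fn`` keywords documented above are passed when using the backtracking line search; the least-squares objective and the chained SGD learning rate are illustrative assumptions.

```python
import jax.numpy as jnp
import optax


def loss_fn(params):
  # Illustrative least-squares objective.
  return jnp.sum((params - jnp.arange(3.0)) ** 2)


params = jnp.zeros(3)
opt = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
opt_state = opt.init(params)

# Reuses the value/grad stored in the linesearch state when possible,
# avoiding a redundant evaluation of loss_fn.
value_and_grad = optax.value_and_grad_from_state(loss_fn)

for _ in range(5):
  value, grad = value_and_grad(params, state=opt_state)
  # `value`, `grad` and `value_fn` are the extra keyword arguments
  # documented in the diff above.
  updates, opt_state = opt.update(
      grad, opt_state, params, value=value, grad=grad, value_fn=loss_fn)
  params = optax.apply_updates(params, updates)
```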
