Skip to content

Commit

Permalink
Let momo accept additional extra arguments.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627842385
  • Loading branch information
vroulet authored and OptaxDev committed Apr 24, 2024
1 parent d9c2bbe commit fbd8f5d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
11 changes: 11 additions & 0 deletions docs/api/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Experimental features and algorithms that don't meet the
dpsgd
mechanize
MechanicState
momo
MomoState
momo_adam
MomoAdamState
prodigy
ProdigyState
sam
Expand Down Expand Up @@ -63,6 +67,13 @@ Mechanize
.. autoclass:: MechanicState
:members:

Momo
~~~~
.. autofunction:: momo
.. autoclass:: MomoState
.. autofunction:: momo_adam
.. autoclass:: MomoAdamState

Prodigy
~~~~~~~
.. autofunction:: prodigy
Expand Down
6 changes: 6 additions & 0 deletions optax/contrib/_momo.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ def update_fn(
updates: base.Updates,
state: MomoState,
params: Optional[base.Params],
*,
value: Optional[Array] = None,
**extra_args,
) -> tuple[base.Updates, MomoState]:
del extra_args
if params is None:
raise ValueError(base.NO_PARAMS_MSG)
if value is None:
Expand Down Expand Up @@ -227,8 +230,11 @@ def update_fn(
updates: base.Updates,
state: MomoAdamState,
params: Optional[base.Params],
*,
value: Optional[Array],
**extra_args,
) -> tuple[base.Updates, MomoAdamState]:
del extra_args
if params is None:
raise ValueError(base.NO_PARAMS_MSG)
if value is None:
Expand Down

0 comments on commit fbd8f5d

Please sign in to comment.