From fbd8f5df082588854dc1a5bf5e4242622e0280db Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Wed, 24 Apr 2024 13:58:58 -0700 Subject: [PATCH] Let momo accept additional extra arguments. PiperOrigin-RevId: 627842385 --- docs/api/contrib.rst | 11 +++++++++++ optax/contrib/_momo.py | 6 ++++++ 2 files changed, 17 insertions(+) diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index f0e2d1dd..ba6fad2b 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -16,6 +16,10 @@ Experimental features and algorithms that don't meet the dpsgd mechanize MechanicState + momo + MomoState + momo_adam + MomoAdamState prodigy ProdigyState sam @@ -63,6 +67,13 @@ Mechanize .. autoclass:: MechanicState :members: +Momo +~~~~ +.. autofunction:: momo +.. autoclass:: MomoState +.. autofunction:: momo_adam +.. autoclass:: MomoAdamState + Prodigy ~~~~~~~ .. autofunction:: prodigy diff --git a/optax/contrib/_momo.py b/optax/contrib/_momo.py index d90a7584..e203aa66 100644 --- a/optax/contrib/_momo.py +++ b/optax/contrib/_momo.py @@ -96,8 +96,11 @@ def update_fn( updates: base.Updates, state: MomoState, params: Optional[base.Params], + *, value: Optional[Array] = None, + **extra_args, ) -> tuple[base.Updates, MomoState]: + del extra_args if params is None: raise ValueError(base.NO_PARAMS_MSG) if value is None: @@ -227,8 +230,11 @@ def update_fn( updates: base.Updates, state: MomoAdamState, params: Optional[base.Params], + *, value: Optional[Array], + **extra_args, ) -> tuple[base.Updates, MomoAdamState]: + del extra_args if params is None: raise ValueError(base.NO_PARAMS_MSG) if value is None: