From 5ab2e8f796e689ba162d0de1046dd6515d9c463c Mon Sep 17 00:00:00 2001
From: Nasy
Date: Tue, 4 Oct 2022 11:25:03 -0500
Subject: [PATCH 1/2] Expose adamax and adamaxw

---
 optax/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/optax/__init__.py b/optax/__init__.py
index 5b9bf666..1b84a7ad 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -20,6 +20,8 @@
 from optax._src.alias import adagrad
 from optax._src.alias import adam
 from optax._src.alias import adamw
+from optax._src.alias import adamax
+from optax._src.alias import adamaxw
 from optax._src.alias import dpsgd
 from optax._src.alias import fromage
 from optax._src.alias import lamb
@@ -176,6 +178,8 @@
     "adagrad",
     "adam",
     "adamw",
+    "adamax",
+    "adamaxw",
     "adaptive_grad_clip",
     "AdaptiveGradClipState",
     "add_decayed_weights",

From 71e74558cf8b9c1ba5593452f702ffe7bf8058b3 Mon Sep 17 00:00:00 2001
From: Nasy
Date: Tue, 4 Oct 2022 11:46:55 -0500
Subject: [PATCH 2/2] Add adamax and adamaxw to API doc.

---
 docs/api.rst | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/docs/api.rst b/docs/api.rst
index ad6f537f..6e408bc7 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -10,6 +10,8 @@ Common Optimizers
     adagrad
     adam
     adamw
+    adamax
+    adamaxw
     fromage
     lamb
     lars
@@ -48,6 +50,16 @@ AdamW

 .. autofunction:: adamw

+Adamax
+~~~~~~
+
+.. autofunction:: adamax
+
+AdamaxW
+~~~~~~~
+
+.. autofunction:: adamaxw
+
 Fromage
 ~~~~~~~

@@ -147,6 +159,7 @@ Gradient Transforms
     Params
     scale
     scale_by_adam
+    scale_by_adamax
     scale_by_belief
     scale_by_factored_rms
     scale_by_optimistic_gradient
@@ -257,6 +270,7 @@ Optax Transforms and States

 .. autofunction:: scale
 .. autofunction:: scale_by_adam
+.. autofunction:: scale_by_adamax
 .. autofunction:: scale_by_belief
 .. autofunction:: scale_by_factored_rms
 .. autofunction:: scale_by_param_block_norm
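For reviewers, a minimal usage sketch of the newly exposed optimizers once these patches land. The toy parameters, data, and loss below are illustrative assumptions; only optax.adamax / optax.adamaxw and the standard optax update loop come from the library.

import jax
import jax.numpy as jnp
import optax

# Toy linear-model parameters (hypothetical example data, not from the patch).
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

def loss_fn(p, x, y):
  # Mean squared error of a linear prediction.
  pred = x @ p["w"] + p["b"]
  return jnp.mean((pred - y) ** 2)

# adamax follows the usual optax GradientTransformation interface;
# adamaxw is the variant with decoupled (AdamW-style) weight decay.
optimizer = optax.adamax(learning_rate=1e-3)
opt_state = optimizer.init(params)

x = jnp.ones((8, 3))
y = jnp.zeros((8,))
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Swapping in optax.adamaxw(learning_rate=1e-3, weight_decay=1e-4) works the same way, since both aliases return a GradientTransformation.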