From 94a486f43535875b42f6d4f35a9ae6dedea144ef Mon Sep 17 00:00:00 2001
From: carlosgmartin
Date: Wed, 24 Apr 2024 18:13:53 -0400
Subject: [PATCH] Add a nesterov flag to radam optimizer.

---
 optax/_src/alias.py     | 13 +++++++++++--
 optax/_src/transform.py | 14 ++++++++++++--
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index de13d31b..10ac849a 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -1241,7 +1241,9 @@ def radam(
     b2: float = 0.999,
     eps: float = 1e-8,
     eps_root: float = 0.0,
-    threshold: float = 5.0
+    threshold: float = 5.0,
+    *,
+    nesterov: bool = False,
 ) -> base.GradientTransformation:
   """The Rectified Adam optimizer.
 
@@ -1285,13 +1287,20 @@ def radam(
       in RMSProp), to avoid dividing by zero when rescaling. This is needed for
       instance when computing (meta-)gradients through Adam.
     threshold: Threshold for variance tractability.
+    nesterov: Whether to use Nesterov momentum.
 
   Returns:
     The corresponding `GradientTransformation`.
   """
   return combine.chain(
       transform.scale_by_radam(
-          b1=b1, b2=b2, eps=eps, eps_root=eps_root, threshold=threshold),
+          b1=b1,
+          b2=b2,
+          eps=eps,
+          eps_root=eps_root,
+          threshold=threshold,
+          nesterov=nesterov,
+      ),
       transform.scale_by_learning_rate(learning_rate),
   )
 
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 0143c84d..9dda6786 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -708,7 +708,9 @@ def scale_by_radam(
     b2: float = 0.999,
     eps: float = 1e-8,
     eps_root: float = 0.0,
-    threshold: float = 5.0
+    threshold: float = 5.0,
+    *,
+    nesterov: bool = False,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Rectified Adam algorithm.
 
@@ -722,6 +724,7 @@ def scale_by_radam(
     eps_root: Term added to the denominator inside the square-root to improve
       numerical stability when backpropagating gradients through the rescaling.
     threshold: Threshold for variance tractability.
+    nesterov: Whether to use Nesterov momentum.
 
   Returns:
     A `GradientTransformation` object.
@@ -749,7 +752,14 @@ def update_fn(updates, state, params=None):
     count_inc = numerics.safe_int32_increment(state.count)
     b2t = b2**count_inc
     ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)
-    mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
+    if nesterov:
+      mu_hat = jtu.tree_map(
+          lambda m, g: b1 * m + (1 - b1) * g,
+          otu.tree_bias_correction(
+              mu, b1, numerics.safe_int32_increment(count_inc)),
+          otu.tree_bias_correction(updates, b1, count_inc))
+    else:
+      mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
     nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
     updates = jax.lax.cond(
         ro >= threshold, _radam_update, lambda _: mu_hat,