Skip to content

Commit

Permalink
Add a nesterov flag to radam optimizer.
Browse files: browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Apr 24, 2024
1 parent fbd8f5d commit 1571de5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
13 changes: 11 additions & 2 deletions optax/_src/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,9 @@ def radam(
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
threshold: float = 5.0
threshold: float = 5.0,
*,
nesterov: bool = False,
) -> base.GradientTransformation:
"""The Rectified Adam optimizer.
Expand Down Expand Up @@ -1285,13 +1287,20 @@ def radam(
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
threshold: Threshold for variance tractability.
nesterov: Whether to use Nesterov momentum.
Returns:
The corresponding `GradientTransformation`.
"""
return combine.chain(
transform.scale_by_radam(
b1=b1, b2=b2, eps=eps, eps_root=eps_root, threshold=threshold),
b1=b1,
b2=b2,
eps=eps,
eps_root=eps_root,
threshold=threshold,
nesterov=nesterov,
),
transform.scale_by_learning_rate(learning_rate),
)

Expand Down
13 changes: 11 additions & 2 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,9 @@ def scale_by_radam(
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
threshold: float = 5.0
threshold: float = 5.0,
*,
nesterov: bool = False,
) -> base.GradientTransformation:
"""Rescale updates according to the Rectified Adam algorithm.
Expand Down Expand Up @@ -749,7 +751,14 @@ def update_fn(updates, state, params=None):
count_inc = numerics.safe_int32_increment(state.count)
b2t = b2**count_inc
ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)
mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
if nesterov:
mu_hat = jtu.tree_map(
lambda m, g: b1 * m + (1 - b1) * g,
otu.tree_bias_correction(
mu, b1, numerics.safe_int32_increment(count_inc)),
otu.tree_bias_correction(updates, b1, count_inc))
else:
mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jax.lax.cond(
ro >= threshold, _radam_update, lambda _: mu_hat,
Expand Down

0 comments on commit 1571de5

Please sign in to comment.