Clarify formula for cosine learning rate.
I think it's useful to have the precise
formula of the learning rate in the docstring.
Otherwise I often find myself going to the
source code to understand the learning rate.

PiperOrigin-RevId: 610120777
fabianp authored and OptaxDev committed Feb 25, 2024
1 parent d559415 commit 7aa5e79
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions optax/schedules/_schedule.py
@@ -241,24 +241,33 @@ def cosine_decay_schedule(
  alpha: float = 0.0,
  exponent: float = 1.0,
  ) -> base.Schedule:
- """Returns a function which implements cosine learning rate decay.
+ r"""Returns a function which implements cosine learning rate decay.
  The schedule does not restart when ``decay_steps`` has been reached. Instead,
  the learning rate remains constant afterwards. For a cosine schedule with
  restarts, :func:`optax.join_schedules` can be used to join several
  cosine decay schedules.
+ This schedule smoothly decreases the learning rate over a specified number of
+ steps (``decay_steps``). The decay follows a cosine function, with an optional
+ exponent to modify the decay curve. A minimum value (``alpha``) ensures the
+ learning rate does not drop entirely to zero.
+ More precisely, the learning rate at iteration :math:`t` is given by:
+ .. math::
+ \gamma_0 (1 - \alpha) \frac{1}{2}(1+\cos(\pi\,\frac{t}{T})^p) + \alpha\,,
+ where :math:`T` is the number of decay steps (``decay_steps``), :math:`p` is
+ the ``exponent`` and :math:`\gamma_0` is the initial value (``init_value``).
+ References:
+ Loshchilov et al., `SGDR: Stochastic Gradient Descent with Warm Restarts
+ <https://arxiv.org/abs/1608.03983>`_, 2017
  Args:
- init_value: An initial value ``init_v``.
+ init_value: An initial value for the learning rate.
  decay_steps: Positive integer - the number of steps for which to apply
  the decay for.
- alpha: Float. The minimum value of the multiplier used to adjust the
- learning rate.
- exponent: Float. The default decay is ``0.5 * (1 + cos(pi * t/T))``, where
+ alpha: The minimum value of the multiplier used to adjust the
+ learning rate. Defaults to 0.0.
+ exponent: The default decay is ``0.5 * (1 + cos(pi * t/T))``, where
  ``t`` is the current timestep and ``T`` is the ``decay_steps``. The
  exponent modifies this to be ``(0.5 * (1 + cos(pi * t/T))) ** exponent``.
+ Defaults to 1.0.
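
To make the documented formula concrete, here is a minimal pure-Python sketch of the schedule as the docstring describes it. It is not the library's implementation: the name `cosine_decay_sketch` is hypothetical, and two readings are assumptions taken from the Args text rather than from the math block. First, the exponent is applied to the whole factor ``0.5 * (1 + cos(pi * t/T))``. Second, ``alpha`` is treated as a minimum on the multiplier, so it is scaled by ``init_value``. The step is clamped at ``decay_steps`` because the docstring says the learning rate remains constant afterwards.

```python
import math


def cosine_decay_sketch(init_value, decay_steps, alpha=0.0, exponent=1.0):
    """Hypothetical sketch of the cosine decay described in the docstring."""
    if decay_steps <= 0:
        raise ValueError("decay_steps must be a positive integer")

    def schedule(step):
        # Clamp the step so the value stays constant after decay_steps.
        t = min(step, decay_steps)
        # Default decay factor 0.5 * (1 + cos(pi * t / T)), which lies in [0, 1].
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
        # Assumption: the exponent applies to the whole factor, and alpha is
        # the floor of the multiplier (both as stated in the Args text).
        multiplier = (1.0 - alpha) * cosine_decay ** exponent + alpha
        return init_value * multiplier

    return schedule


schedule = cosine_decay_sketch(init_value=0.1, decay_steps=100, alpha=0.1)
start, middle, end, after = (schedule(s) for s in (0, 50, 100, 500))
```

Under these assumptions the value starts at ``init_value``, passes through roughly ``init_value * (0.5 * (1 - alpha) + alpha)`` at the midpoint, and settles at ``init_value * alpha`` from ``decay_steps`` onwards.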