diff --git a/optax/schedules/_schedule.py b/optax/schedules/_schedule.py index 7c388847..034fd8cf 100644 --- a/optax/schedules/_schedule.py +++ b/optax/schedules/_schedule.py @@ -241,24 +241,33 @@ def cosine_decay_schedule( alpha: float = 0.0, exponent: float = 1.0, ) -> base.Schedule: - """Returns a function which implements cosine learning rate decay. + r"""Returns a function which implements cosine learning rate decay. - The schedule does not restart when ``decay_steps`` has been reached. Instead, - the learning rate remains constant afterwards. For a cosine schedule with - restarts, :func:`optax.join_schedules` can be used to join several - cosine decay schedules. + This schedule smoothly decreases the learning rate over a specified number of + steps (``decay_steps``). The decay follows a cosine function, with an optional + exponent to modify the decay curve. A minimum value (``alpha``) ensures the + learning rate does not drop entirely to zero. + + More precisely, the learning rate at iteration :math:`t` is given by: + + .. math:: + + \gamma_0 (1 - \alpha) \frac{1}{2}(1+\cos(\pi\,\frac{t}{T})^p) + \alpha\,, + + where :math:`T` is the number of decay steps (``decay_steps``), :math:`p` is + the ``exponent`` and :math:`\gamma_0` is the initial value (``init_value``). References: Loshchilov et al., `SGDR: Stochastic Gradient Descent with Warm Restarts `_, 2017 Args: - init_value: An initial value ``init_v``. + init_value: An initial value for the learning rate. decay_steps: Positive integer - the number of steps for which to apply the decay for. - alpha: Float. The minimum value of the multiplier used to adjust the - learning rate. - exponent: Float. The default decay is ``0.5 * (1 + cos(pi * t/T))``, where + alpha: The minimum value of the multiplier used to adjust the + learning rate. Defaults to 0.0. + exponent: The default decay is ``0.5 * (1 + cos(pi * t/T))``, where ``t`` is the current timestep and ``T`` is the ``decay_steps``. The exponent modifies this to be ``(0.5 * (1 + cos(pi * t/T))) ** exponent``. Defaults to 1.0.