# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Experimental learning rate schedulers used for training LLMs."""

import textwrap
import warnings
from typing import Union

from composer.core import State, Time, TimeUnit
from composer.optim import ComposerScheduler, LinearScheduler
from composer.optim.scheduler import _convert_time

__all__ = ['InverseSquareRootWithWarmupScheduler']


def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
                               name: str) -> None:
    if isinstance(time, str):
        time = Time.from_timestring(time)
    if isinstance(t_max, str):
        t_max = Time.from_timestring(t_max)

    assert not isinstance(time, str) and not isinstance(t_max, str)
    if time.unit != t_max.unit:
        raise ValueError(
            f'{name} (unit {time.unit}) must match the unit of max_duration '
            f'({t_max.unit}).')


def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
    if isinstance(time, str):
        time = Time.from_timestring(time)
    assert not isinstance(time, str)
    if time.unit == TimeUnit('dur'):
        raise ValueError(f'{name} cannot be in units of "dur".')
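

# For illustration only: a minimal plain-float sketch of the piecewise schedule
# implemented by the class below. Everything here (the function name, treating
# all times as floats in one shared unit) is a simplification for exposition
# and is not part of this module's API; the real scheduler operates on
# composer `Time` objects and the trainer `State`.
def _inverse_sqrt_with_warmup_sketch(t: float,
                                     t_warmup: float,
                                     t_scale: float,
                                     t_cooldown: float,
                                     t_max: float,
                                     alpha_f_decay: float = 0.0,
                                     alpha_f_cooldown: float = 0.0) -> float:
    if t < t_warmup:
        # Linear warmup from 0 to 1.
        return t / t_warmup
    t_cooldown_start = max(t_max - t_cooldown, t_warmup)
    if t < t_cooldown_start:
        # Inverse square root decay toward alpha_f_decay.
        tau_d = (t - t_warmup + t_scale) / t_scale
        return alpha_f_decay + (1.0 - alpha_f_decay) / tau_d**0.5
    # Linear cooldown from alpha_i (the decay value at t_cooldown_start)
    # to alpha_f_cooldown.
    tau_d = (t_cooldown_start - t_warmup + t_scale) / t_scale
    alpha_i = alpha_f_decay + (1.0 - alpha_f_decay) / tau_d**0.5
    if t_cooldown == 0:
        return alpha_i
    tau_c = min(1.0, (t - t_cooldown_start) / t_cooldown)
    return alpha_i + tau_c * (alpha_f_cooldown - alpha_i)


# Sanity check of the closed form above with illustrative values: with
# t_warmup=100, t_scale=100, t_cooldown=0, and the default alphas, the
# multiplier at t=300 is 1 / sqrt((300 - 100 + 100) / 100) = 1 / sqrt(3)
# ~= 0.577, and it equals exactly 1.0 at t=100 (the end of warmup).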


class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
    r"""Inverse square root LR decay with warmup and optional linear cooldown.

    Specifically, the learning rate multiplier :math:`\alpha(t)` can be expressed as:

    .. math::
        \alpha(t) = \begin{cases}
            t / t_{warmup}, & \text{if } t < t_{warmup} \\
            \alpha_{f,decay} + \frac{1 - \alpha_{f,decay}}{\sqrt{\tau_d}}, & \text{if } t_{warmup} \leq t < t_{max} - t_{cooldown} \\
            \alpha_i + (\alpha_{f,cooldown} - \alpha_i) \times \tau_c, & \text{otherwise}
        \end{cases}

    Given :math:`\tau_d`, the time elapsed during the inverse square root decay (normalized by :math:`t_{scale}`), defined as:

    .. math::
        \tau_d = (t - t_{warmup} + t_{scale}) / t_{scale}

    :math:`\alpha_i` is the value of the learning rate multiplier when :math:`\tau_d` is evaluated at :math:`t = t_{max} - t_{cooldown}`,
    and :math:`\tau_c`, the fraction of linear cooldown time elapsed (clipped to the interval :math:`[0, 1]`), is defined as:

    .. math::
        \tau_c = (t - t_{max} + t_{cooldown}) / t_{cooldown}

    Here, :math:`t_{warmup}` represents the warmup time, :math:`t_{scale}` represents the time scale,
    :math:`t_{cooldown}` represents the cooldown time, :math:`t_{max}` represents the duration of this scheduler,
    :math:`\alpha_{f,decay}` represents the learning rate multiplier that the inverse square root decay approaches at infinite time,
    and :math:`\alpha_{f,cooldown}` represents the learning rate multiplier that the linear cooldown decays to.

    Note that :math:`\alpha_{f,decay} \geq \alpha_{f,cooldown}` is required so that the learning rate is monotonically decreasing after warmup.

    Also note that ``t_warmup``, ``t_scale``, and ``t_cooldown`` cannot be specified in units of duration; because this schedule is designed for continual learning,
    ``max_duration`` is expected to change. Instead, these parameters must be specified in the same units as the ``max_duration`` passed to the trainer.

    Args:
        t_warmup (str | Time): The warmup time.
        t_scale (str | Time): The time scale.
        t_cooldown (str | Time): The cooldown time.
        t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
        alpha_f_decay (float): The learning rate multiplier that the inverse square root decay approaches at infinite time. Default = ``0.0``.
        alpha_f_cooldown (float): The learning rate multiplier that the linear cooldown decays to. Default = ``0.0``.
    """

    def __init__(self,
                 t_warmup: Union[str, Time],
                 t_scale: Union[str, Time],
                 t_cooldown: Union[str, Time],
                 t_max: Union[str, Time] = '1dur',
                 alpha_f_decay: float = 0.0,
                 alpha_f_cooldown: float = 0.0) -> None:
        if alpha_f_decay < alpha_f_cooldown:
            raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. '
                              f'Current: alpha_f_decay={alpha_f_decay}, '
                              f'alpha_f_cooldown={alpha_f_cooldown}.'))
        _raise_if_units_dur(t_warmup, 't_warmup')
        _raise_if_units_dur(t_scale, 't_scale')
        _raise_if_units_dur(t_cooldown, 't_cooldown')
        self.t_warmup = t_warmup
        self.t_scale = t_scale
        self.t_cooldown = t_cooldown
        self.t_max = t_max
        self.alpha_f_decay = alpha_f_decay
        self.alpha_f_cooldown = alpha_f_cooldown
        self.warmup_scheduler = LinearScheduler(alpha_i=0.0,
                                                alpha_f=1.0,
                                                t_max=t_warmup)

    def __call__(self, state: State, ssr: float = 1.0) -> float:
        assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'

        _raise_if_units_dont_match(self.t_warmup, state.max_duration,
                                   't_warmup')
        _raise_if_units_dont_match(self.t_scale, state.max_duration, 't_scale')
        _raise_if_units_dont_match(self.t_cooldown, state.max_duration,
                                   't_cooldown')

        t_warmup = _convert_time(self.t_warmup, state)
        if t_warmup.value == 0:
            warnings.warn(
                textwrap.dedent("""\
                The warmup duration is 0. If warmup was specified as a fraction of the total
                training duration, the warmup duration is calculated in the
                same unit as the trainer's max_duration parameter."""))

        if state.timestamp < t_warmup:
            return self.warmup_scheduler(state)

        t_scale = _convert_time(self.t_scale, state, ssr=ssr)
        t_cooldown = _convert_time(self.t_cooldown, state, ssr=ssr)
        t_max = _convert_time(self.t_max, state, ssr=ssr)
        current_time = state.timestamp.get(t_scale.unit)

        t_shift = t_scale - t_warmup
        # t_cooldown_start is max(t_warmup, t_max - t_cooldown)
        t_cooldown_start = t_max - t_cooldown
        if t_cooldown_start < t_warmup:
            t_cooldown_start = t_warmup

        if state.timestamp < t_cooldown_start:
            # Rescale the LR by a coefficient equal to the inverse square root of
            # the time elapsed after warmup, rescaled by the time scale, such that,
            # at infinite time, the LR decays to alpha_f_decay.
            coeff = 1 / ((current_time + t_shift) / t_scale).value**0.5
            current_factor = (self.alpha_f_decay + coeff *
                              (1.0 - self.alpha_f_decay))
            return current_factor
        else:
            coeff = 1 / ((t_cooldown_start + t_shift) / t_scale).value**0.5
            alpha_i = self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay)

            if t_cooldown.value == 0:
                return alpha_i

            # Linearly decay the LR from its value at the step at which cooldown
            # started to alpha_f_cooldown over t_cooldown time.
            frac_of_cooldown = ((current_time - t_cooldown_start) /
                                t_cooldown).value
            frac_of_cooldown = min(1.0, frac_of_cooldown)
            current_factor = (alpha_i + frac_of_cooldown *
                              (self.alpha_f_cooldown - alpha_i))
            return current_factor
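

# A usage sketch, kept as comments so the module stays import-clean: it shows
# constructing the scheduler and handing it to a composer Trainer. The `model`,
# `optimizer`, and `train_dataloader` names are hypothetical placeholders, and
# the time values are illustrative only. Note that t_warmup, t_scale, and
# t_cooldown use batch units ('ba') to match a max_duration also given in
# batches, as the class docstring requires.
#
#     from composer import Trainer
#
#     scheduler = InverseSquareRootWithWarmupScheduler(
#         t_warmup='100ba',
#         t_scale='100ba',
#         t_cooldown='200ba',
#         alpha_f_decay=0.1,
#         alpha_f_cooldown=0.0,
#     )
#     trainer = Trainer(
#         model=model,
#         train_dataloader=train_dataloader,
#         optimizers=optimizer,
#         schedulers=scheduler,
#         max_duration='1000ba',
#     )
#     trainer.fit()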