-
Notifications
You must be signed in to change notification settings - Fork 400
/
event.py
233 lines (184 loc) · 9.12 KB
/
event.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Training Loop Events."""
from composer.utils.string_enum import StringEnum
__all__ = ['Event']
class Event(StringEnum):
    """Enum to represent training loop events.

    Events mark specific points in the training loop where an :class:`~.core.Algorithm` and :class:`~.core.Callback`
    can run.

    The following pseudocode shows where each event fires in the training loop:

    .. code-block:: python

        # <INIT>
        # <FIT_START>
        for epoch in range(NUM_EPOCHS):
            # <EPOCH_START>
            while True:
                # <BEFORE_DATALOADER>
                batch = next(dataloader)
                if batch is None:
                    break
                # <AFTER_DATALOADER>

                # <BATCH_START>

                # <BEFORE_TRAIN_BATCH>

                for microbatch in batch.split(grad_accum):

                    # <BEFORE_FORWARD>
                    outputs = model(batch)
                    # <AFTER_FORWARD>

                    # <BEFORE_LOSS>
                    loss = model.loss(outputs, batch)
                    # <AFTER_LOSS>

                    # <BEFORE_BACKWARD>
                    loss.backward()
                    # <AFTER_BACKWARD>

                # Un-scale and clip gradients

                # <AFTER_TRAIN_BATCH>
                optimizer.step()

                # <BATCH_END>

                if should_eval(batch=True):
                    for eval_dataloader in eval_dataloaders:
                        # <EVAL_START>
                        for batch in eval_dataloader:
                            # <EVAL_BATCH_START>
                            # <EVAL_BEFORE_FORWARD>
                            outputs, targets = model(batch)
                            # <EVAL_AFTER_FORWARD>
                            metrics.update(outputs, targets)
                            # <EVAL_BATCH_END>
                        # <EVAL_END>

                # <BATCH_CHECKPOINT>
            # <EPOCH_END>

            if should_eval(batch=False):
                for eval_dataloader in eval_dataloaders:
                    # <EVAL_START>
                    for batch in eval_dataloader:
                        # <EVAL_BATCH_START>
                        # <EVAL_BEFORE_FORWARD>
                        outputs, targets = model(batch)
                        # <EVAL_AFTER_FORWARD>
                        metrics.update(outputs, targets)
                        # <EVAL_BATCH_END>
                    # <EVAL_END>

            # <EPOCH_CHECKPOINT>
        # <FIT_END>

    Attributes:
        INIT: Invoked in the constructor of :class:`~.trainer.Trainer`. Model surgery (see
            :mod:`~composer.utils.module_surgery`) typically occurs here.
        FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
            occur here.
        EPOCH_START: Start of an epoch.
        BEFORE_DATALOADER: Immediately before the dataloader is called.
        AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
        BATCH_START: Start of a batch.
        BEFORE_TRAIN_BATCH: Before the forward-loss-backward computation for a training batch. When using gradient
            accumulation, this is still called only once.
        BEFORE_FORWARD: Before the call to ``model.forward()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_FORWARD: After the call to ``model.forward()``.
            This is called multiple times per batch when using gradient accumulation.
        BEFORE_LOSS: Before the call to ``model.loss()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_LOSS: After the call to ``model.loss()``.
            This is called multiple times per batch when using gradient accumulation.
        BEFORE_BACKWARD: Before the call to ``loss.backward()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_BACKWARD: After the call to ``loss.backward()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_TRAIN_BATCH: After the forward-loss-backward computation for a training batch. When using gradient
            accumulation, this event still fires only once.
        BATCH_END: End of a batch, which occurs after the optimizer step and any gradient scaling.
        BATCH_CHECKPOINT: After :attr:`.Event.BATCH_END` and any batch-wise evaluation. Saving checkpoints at this
            event allows the checkpoint saver to use the results from any batch-wise evaluation to determine whether
            a checkpoint should be saved.
        EPOCH_END: End of an epoch.
        EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation. Saving checkpoints at this
            event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether
            a checkpoint should be saved.
        FIT_END: Invoked at the end of each call to :meth:`.Trainer.fit`. This event exists primarily for logging
            information and flushing callbacks. Algorithms should not transform the training state on this event,
            as any changes will not be preserved in checkpoints.
        EVAL_START: Start of evaluation through the validation dataset.
        EVAL_BATCH_START: Before the call to ``model.validate(batch)``
        EVAL_BEFORE_FORWARD: Before the call to ``model.validate(batch)``
        EVAL_AFTER_FORWARD: After the call to ``model.validate(batch)``
        EVAL_BATCH_END: After the call to ``model.validate(batch)``
        EVAL_END: End of evaluation through the validation dataset.
        PREDICT_START: Start of a prediction run through a dataloader.
        PREDICT_BATCH_START: Start of a prediction batch.
        PREDICT_BEFORE_FORWARD: Before the prediction forward pass for a batch.
        PREDICT_AFTER_FORWARD: After the prediction forward pass for a batch.
        PREDICT_BATCH_END: End of a prediction batch.
        PREDICT_END: End of a prediction run.
    """

    INIT = 'init'
    FIT_START = 'fit_start'

    EPOCH_START = 'epoch_start'

    BEFORE_DATALOADER = 'before_dataloader'
    AFTER_DATALOADER = 'after_dataloader'

    BATCH_START = 'batch_start'
    BEFORE_TRAIN_BATCH = 'before_train_batch'
    BEFORE_FORWARD = 'before_forward'
    AFTER_FORWARD = 'after_forward'
    BEFORE_LOSS = 'before_loss'
    AFTER_LOSS = 'after_loss'
    BEFORE_BACKWARD = 'before_backward'
    AFTER_BACKWARD = 'after_backward'
    AFTER_TRAIN_BATCH = 'after_train_batch'
    BATCH_END = 'batch_end'
    BATCH_CHECKPOINT = 'batch_checkpoint'

    EPOCH_END = 'epoch_end'
    EPOCH_CHECKPOINT = 'epoch_checkpoint'

    FIT_END = 'fit_end'

    EVAL_START = 'eval_start'
    EVAL_BATCH_START = 'eval_batch_start'
    EVAL_BEFORE_FORWARD = 'eval_before_forward'
    EVAL_AFTER_FORWARD = 'eval_after_forward'
    EVAL_BATCH_END = 'eval_batch_end'
    EVAL_END = 'eval_end'

    PREDICT_START = 'predict_start'
    PREDICT_BATCH_START = 'predict_batch_start'
    PREDICT_BEFORE_FORWARD = 'predict_before_forward'
    PREDICT_AFTER_FORWARD = 'predict_after_forward'
    PREDICT_BATCH_END = 'predict_batch_end'
    PREDICT_END = 'predict_end'

    @property
    def is_before_event(self) -> bool:
        """Whether the event is a "before" event.

        A "before" event (e.g., :attr:`~Event.BEFORE_LOSS`) has a corresponding "after" event
        (e.g., :attr:`~Event.AFTER_LOSS`).
        """
        return self in _BEFORE_EVENTS

    @property
    def is_after_event(self) -> bool:
        """Whether the event is an "after" event.

        An "after" event (e.g., :attr:`~Event.AFTER_LOSS`) has a corresponding "before" event
        (e.g., :attr:`~Event.BEFORE_LOSS`).
        """
        return self in _AFTER_EVENTS

    @property
    def canonical_name(self) -> str:
        """The name of the event, without before/after markers.

        Events that have a corresponding "before" or "after" event share the same canonical name.

        Example:
            >>> Event.EPOCH_START.canonical_name
            'epoch'
            >>> Event.EPOCH_END.canonical_name
            'epoch'

        Returns:
            str: The canonical name of the event.
        """
        name: str = self.value
        # Strip the paired-marker prefixes/suffixes; what remains is the shared stem.
        name = name.replace('before_', '')
        name = name.replace('after_', '')
        name = name.replace('_start', '')
        name = name.replace('_end', '')
        return name

    @property
    def is_predict(self) -> bool:
        """Whether the event is during the predict loop."""
        return self.value.startswith('predict')

    @property
    def is_eval(self) -> bool:
        """Whether the event is during the eval loop."""
        return self.value.startswith('eval')
# Membership tables backing ``Event.is_before_event`` / ``Event.is_after_event``.
# Each "before"/"start" event here has a matching entry in the "after"/"end" table.
_BEFORE_EVENTS = (
    Event.FIT_START,
    Event.EPOCH_START,
    Event.BEFORE_DATALOADER,
    Event.BATCH_START,
    Event.BEFORE_TRAIN_BATCH,
    Event.BEFORE_FORWARD,
    Event.BEFORE_LOSS,
    Event.BEFORE_BACKWARD,
    Event.EVAL_START,
    Event.EVAL_BATCH_START,
    Event.EVAL_BEFORE_FORWARD,
    Event.PREDICT_START,
    Event.PREDICT_BATCH_START,
    Event.PREDICT_BEFORE_FORWARD,
)
_AFTER_EVENTS = (
    Event.EPOCH_END,
    Event.BATCH_END,
    Event.AFTER_DATALOADER,
    Event.AFTER_TRAIN_BATCH,
    Event.AFTER_FORWARD,
    Event.AFTER_LOSS,
    Event.AFTER_BACKWARD,
    Event.EVAL_END,
    Event.EVAL_BATCH_END,
    Event.EVAL_AFTER_FORWARD,
    Event.FIT_END,
    Event.PREDICT_END,
    Event.PREDICT_BATCH_END,
    Event.PREDICT_AFTER_FORWARD,
)