-
Notifications
You must be signed in to change notification settings - Fork 400
/
speed_monitor.py
152 lines (124 loc) · 6.81 KB
/
speed_monitor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Monitor throughput during training."""
from __future__ import annotations
from collections import deque
from typing import Any, Deque, Dict
from composer.core import State
from composer.core.callback import Callback
from composer.loggers import Logger
__all__ = ['SpeedMonitor']
class SpeedMonitor(Callback):
    """Logs the training throughput.
    The training throughput in terms of number of samples per second is logged on the
    :attr:`.Event.BATCH_END` event if we have reached the ``window_size`` threshold.
    The wall clock train time is logged on every :attr:`.Event.BATCH_END` event.
    The average throughput over an epoch is logged on the :attr:`.Event.EPOCH_END` event.
    Example:
    .. doctest::
    >>> from composer import Trainer
    >>> from composer.callbacks import SpeedMonitor
    >>> # constructing trainer object with this callback
    >>> trainer = Trainer(
    ...     model=model,
    ...     train_dataloader=train_dataloader,
    ...     eval_dataloader=eval_dataloader,
    ...     optimizers=optimizer,
    ...     max_duration="1ep",
    ...     callbacks=[SpeedMonitor(window_size=100)],
    ... )
    The training throughput is logged by the :class:`~composer.loggers.logger.Logger` to the following keys as
    described below.
    +-----------------------+-------------------------------------------------------------+
    | Key                   | Logged data                                                 |
    +=======================+=============================================================+
    |                       | Rolling average (over ``window_size`` most recent           |
    | ``samples/step``      | batches) of the number of samples processed per second      |
    |                       |                                                             |
    +-----------------------+-------------------------------------------------------------+
    |                       | Number of samples processed per second (averaged over       |
    | ``samples/epoch``     | an entire epoch)                                            |
    +-----------------------+-------------------------------------------------------------+
    | ``wall_clock/train``  | Total elapsed training time                                 |
    +-----------------------+-------------------------------------------------------------+
    | ``wall_clock/val``    | Total elapsed validation time                               |
    +-----------------------+-------------------------------------------------------------+
    | ``wall_clock/total``  | Total elapsed time (wall_clock/train + wall_clock/val)      |
    +-----------------------+-------------------------------------------------------------+
    Args:
        window_size (int, optional): Number of batches to use for a rolling average of throughput.
            Defaults to 100.
    """

    def __init__(self, window_size: int = 100):
        # Track the epoch num samples and wct to compute throughput over the entire epoch
        self.epoch_start_num_samples = 0
        self.epoch_start_wct = 0.0
        # Track the batch num samples and wct to compute throughput over a window of batches
        self.batch_start_num_samples = 0
        self.batch_start_wct = 0.0
        # maxlen makes the deques self-evicting, so only the most recent
        # ``window_size`` batches contribute to the rolling average.
        self.batch_wct_buffer: Deque[float] = deque(maxlen=window_size)
        self.batch_num_samples_buffer: Deque[int] = deque(maxlen=window_size)
        self.window_size = window_size
        # Keep track of time spent evaluating
        self.total_eval_wct = 0.0

    def state_dict(self) -> Dict[str, Any]:
        """Return the serializable state so throughput tracking survives checkpoint/resume."""
        return {
            'epoch_start_num_samples': self.epoch_start_num_samples,
            'epoch_start_wct': self.epoch_start_wct,
            'batch_start_num_samples': self.batch_start_num_samples,
            'batch_start_wct': self.batch_start_wct,
            'batch_wct_buffer': self.batch_wct_buffer,
            'batch_num_samples_buffer': self.batch_num_samples_buffer,
            'total_eval_wct': self.total_eval_wct,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore state, rebuilding the window buffers with this instance's ``window_size``."""
        self.epoch_start_num_samples = state['epoch_start_num_samples']
        self.epoch_start_wct = state['epoch_start_wct']
        self.batch_start_num_samples = state['batch_start_num_samples']
        self.batch_start_wct = state['batch_start_wct']
        # deque() accepts any iterable directly; re-apply maxlen in case the
        # checkpoint was saved with a different window size.
        self.batch_wct_buffer = deque(state['batch_wct_buffer'], maxlen=self.window_size)
        self.batch_num_samples_buffer = deque(state['batch_num_samples_buffer'], maxlen=self.window_size)
        self.total_eval_wct = state['total_eval_wct']

    def epoch_start(self, state: State, logger: Logger):
        """Snapshot wall-clock time and sample count at the start of the epoch."""
        del logger  # unused
        self.epoch_start_wct = state.timestamp.total_wct.total_seconds()
        self.epoch_start_num_samples = int(state.timestamp.sample)

    def batch_start(self, state: State, logger: Logger) -> None:
        """Snapshot wall-clock time and sample count at the start of the batch."""
        del logger  # unused
        self.batch_start_wct = state.timestamp.total_wct.total_seconds()
        self.batch_start_num_samples = int(state.timestamp.sample)

    def batch_end(self, state: State, logger: Logger):
        """Record this batch's duration/samples and log throughput and wall-clock times."""
        batch_num_samples = int(state.timestamp.sample) - self.batch_start_num_samples
        batch_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_wct
        # Add the new element; maxlen evicts the oldest entry automatically
        self.batch_wct_buffer.append(batch_wct)
        self.batch_num_samples_buffer.append(batch_num_samples)
        # Log the throughput only once a full window has accumulated, so early
        # partial windows do not produce noisy estimates
        if len(self.batch_num_samples_buffer) == self.window_size:
            throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer)
            logger.data_batch({'samples/step': throughput})
        # Log the time
        # `state.timestamp` excludes any time spent in evaluation
        logger.data_batch({
            'wall_clock/train': state.timestamp.total_wct.total_seconds(),
            'wall_clock/val': self.total_eval_wct,
            'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct),
        })

    def eval_end(self, state: State, logger: Logger):
        """Accumulate time spent in evaluation so it can be reported separately."""
        del logger  # unused
        self.total_eval_wct += state.eval_timestamp.total_wct.total_seconds()

    def epoch_end(self, state: State, logger: Logger):
        """Log the average throughput over the whole epoch."""
        # `state.timestamp` excludes any time spent in evaluation
        epoch_time_in_train = state.timestamp.total_wct.total_seconds() - self.epoch_start_wct
        train_examples_per_epoch = int(state.timestamp.sample) - self.epoch_start_num_samples
        # NOTE(review): raises ZeroDivisionError if the epoch took zero wall-clock
        # time — presumably impossible in practice; confirm before guarding.
        logger.data_epoch({
            'samples/epoch': train_examples_per_epoch / epoch_time_in_train,
        })