-
Notifications
You must be signed in to change notification settings - Fork 400
/
profiler_schedule.py
64 lines (52 loc) · 2.75 KB
/
profiler_schedule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Profiler Schedules."""
from typing import Callable
from composer.core.state import State
from composer.profiler.profiler_action import ProfilerAction
__all__ = ['cyclic_schedule']
def cyclic_schedule(
    skip_first: int = 0,
    wait: int = 0,
    warmup: int = 1,
    active: int = 4,
    repeat: int = 1,
) -> Callable[[State], ProfilerAction]:
    """Profiler schedule function for a cyclic profiling window.

    This function returns a schedule function that uses a cyclic profiling window. The resulting function can be
    passed as the ``prof_schedule`` argument to the :class:`.Trainer`.

    The cyclic window skips the first ``skip_first`` batches in every epoch. Then, it performs a cycle of
    skipping ``wait`` batches, warming up for ``warmup`` batches, and recording ``active`` batches.
    It repeats this cycle up to ``repeat`` times per epoch (or for the entire epoch, if ``repeat`` is 0).
    This logic repeats every epoch.

    The last ``active`` batch of each cycle is reported as ``ProfilerAction.ACTIVE_AND_SAVE``. So is an ``active``
    batch that is the final batch of the epoch, when the dataloader length is known. If ``wait + warmup + active``
    is ``0``, every batch is skipped.

    Args:
        skip_first (int, optional): Number of batches to skip profiling at epoch start. Defaults to ``0``.
        wait (int, optional): For each profiling cycle, number of batches to skip at the beginning of the cycle.
            Defaults to ``0``.
        warmup (int, optional): For each profiling cycle, number of batches to be in the warmup state after skipping
            ``wait`` batches. Defaults to ``1``.
        active (int, optional): For each profiling cycle, number of batches to record after warming up. Defaults to ``4``.
        repeat (int, optional): Number of profiling cycles to perform per epoch. Set to ``0`` to record the entire epoch.
            Defaults to ``1``.

    Returns:
        (State -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.

    Raises:
        ValueError: If any of ``skip_first``, ``wait``, ``warmup``, ``active`` or ``repeat`` is negative.
    """
    # Reject negative counts up front; they would otherwise yield silently nonsensical windows
    # (e.g. ``warmup=-1`` makes every batch after ``skip_first`` ACTIVE).
    for arg_name, arg_value in (
        ('skip_first', skip_first),
        ('wait', wait),
        ('warmup', warmup),
        ('active', active),
        ('repeat', repeat),
    ):
        if arg_value < 0:
            raise ValueError(f'{arg_name} must be non-negative; got {arg_value}.')

    # Length of one wait -> warmup -> active cycle. It never changes, so compute it once
    # instead of on every batch.
    cycle_len = wait + warmup + active

    def schedule(state: State) -> ProfilerAction:
        # Within each epoch: skip ``skip_first`` batches, then wait, then warmup, then active,
        # up to ``repeat`` times per epoch (unbounded when ``repeat`` is 0).
        batch_idx = int(state.timestamp.batch_in_epoch)
        if batch_idx < skip_first:
            return ProfilerAction.SKIP
        if cycle_len == 0:
            # An empty cycle has nothing to warm up or record. Without this guard the modulo
            # below would raise ZeroDivisionError when ``repeat`` is 0.
            return ProfilerAction.SKIP
        if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first:
            # Exhausted the ``repeat`` cycles allowed for this epoch.
            return ProfilerAction.SKIP
        position_in_cycle = (batch_idx - skip_first) % cycle_len
        if position_in_cycle < wait:
            return ProfilerAction.SKIP
        if position_in_cycle < wait + warmup:
            return ProfilerAction.WARMUP
        # Save at the end of each active window. Also save if the epoch's final batch lands
        # inside an active window, so that a window cut short by the epoch end is still saved.
        is_last_batch_in_epoch = (
            state.dataloader_len is not None and batch_idx == int(state.dataloader_len) - 1
        )
        if position_in_cycle == cycle_len - 1 or is_last_batch_in_epoch:
            return ProfilerAction.ACTIVE_AND_SAVE
        return ProfilerAction.ACTIVE

    return schedule