-
Notifications
You must be signed in to change notification settings - Fork 400
/
early_stopper.py
154 lines (129 loc) · 6.65 KB
/
early_stopper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Early stopping callback."""
from __future__ import annotations
import logging
from typing import Any, Callable, Optional, Union
import torch
from composer.core import State, Time
from composer.core.callback import Callback
from composer.core.time import TimeUnit
from composer.loggers import Logger
log = logging.getLogger(__name__)
__all__ = ['EarlyStopper']
class EarlyStopper(Callback):
    """Track a metric and halt training if it does not improve within a given interval.

    Example:
        .. doctest::

            >>> from composer import Evaluator, Trainer
            >>> from composer.callbacks.early_stopper import EarlyStopper
            >>> from torchmetrics.classification.accuracy import Accuracy
            >>> # constructing trainer object with this callback
            >>> early_stopper = EarlyStopper("Accuracy", "my_evaluator", patience=1)
            >>> evaluator = Evaluator(
            ...     dataloader = eval_dataloader,
            ...     label = 'my_evaluator',
            ...     metrics = Accuracy()
            ... )
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=evaluator,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[early_stopper],
            ... )

    Args:
        monitor (str): The name of the metric to monitor.
        dataloader_label (str): The label of the dataloader or evaluator associated with the tracked metric.
            If ``monitor`` is in an :class:`.Evaluator`, the ``dataloader_label`` field should be set to the label of the
            :class:`.Evaluator`.
            If monitor is a training metric or an ordinary evaluation metric not in an :class:`.Evaluator`,
            the ``dataloader_label`` should be set to the dataloader label, which defaults to ``'train'`` or
            ``'eval'``, respectively.
        comp (str | (Any, Any) -> Any, optional): A comparison operator to measure change of the monitored metric.
            The comparison operator will be called ``comp(current_value, prev_best)``. For metrics where the optimal value is low
            (error, loss, perplexity), use a less than operator, and for metrics like accuracy where the optimal value
            is higher, use a greater than operator. Defaults to :func:`torch.less` if loss, error, or perplexity are substrings
            of the monitored metric, otherwise defaults to :func:`torch.greater`.
        min_delta (float, optional): An optional float that requires a new value to exceed the best value by at least that amount.
            Default: ``0.0``.
        patience (Time | int | str, optional): The interval of time the monitored metric can not improve without stopping
            training. Default: 1 epoch. If patience is an integer, it is interpreted as the number of epochs.

    Raises:
        ValueError: If ``comp`` is an unrecognized string or is not a callable, a string, or ``None``, or if
            ``patience`` is a :class:`.Time` whose unit is not ``EPOCH`` or ``BATCH``.
    """

    def __init__(
        self,
        monitor: str,
        dataloader_label: str,
        comp: Optional[Union[str, Callable[[
            Any,
            Any,
        ], Any]]] = None,
        min_delta: float = 0.0,
        patience: Union[int, str, Time] = 1,
    ):
        self.monitor = monitor
        self.dataloader_label = dataloader_label
        # A negative min_delta would make every change count as an improvement; use its magnitude.
        self.min_delta = abs(min_delta)

        # Resolve the comparison operator. Using an exhaustive if/elif/else chain so an
        # unsupported `comp` (e.g. an int) fails fast here instead of surfacing later as an
        # AttributeError when the first evaluation tries to call `self.comp_func`.
        if callable(comp):
            self.comp_func = comp
        elif isinstance(comp, str):
            if comp.lower() in ('greater', 'gt'):
                self.comp_func = torch.greater
            elif comp.lower() in ('less', 'lt'):
                self.comp_func = torch.less
            else:
                raise ValueError(
                    "Unrecognized comp string. Use the strings 'gt', 'greater', 'lt' or 'less' or a callable comparison operator"
                )
        elif comp is None:
            # Heuristic default: lower-is-better for loss-like metric names, otherwise higher-is-better.
            if any(substr in monitor.lower() for substr in ['loss', 'error', 'perplexity']):
                self.comp_func = torch.less
            else:
                self.comp_func = torch.greater
        else:
            raise ValueError(
                "Unrecognized comp. Use the strings 'gt', 'greater', 'lt' or 'less' or a callable comparison operator"
            )

        # Best metric value seen so far and the timestamp at which it occurred.
        self.best: Optional[torch.Tensor] = None
        self.best_occurred: Optional[Any] = None

        # Normalize `patience` to a Time object; bare ints are interpreted as epochs.
        if isinstance(patience, str):
            self.patience = Time.from_timestring(patience)
        elif isinstance(patience, int):
            self.patience = Time(patience, TimeUnit.EPOCH)
        else:
            self.patience = patience
            if self.patience.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH):
                raise ValueError('If `patience` is an instance of Time, it must have units of EPOCH or BATCH.')

    def _get_monitored_metric(self, state: State):
        """Look up the monitored metric in ``state.current_metrics``; raise if it is absent."""
        if self.dataloader_label in state.current_metrics:
            if self.monitor in state.current_metrics[self.dataloader_label]:
                return state.current_metrics[self.dataloader_label][self.monitor]
        raise ValueError(f"Couldn't find the metric {self.monitor} with the dataloader label {self.dataloader_label}. "
                         "Check that the dataloader_label is set to 'eval', 'train' or the evaluator name.")

    def _update_stopper_state(self, state: State):
        """Update the best-seen metric and stop training once `patience` has elapsed without improvement.

        Training is stopped by setting ``state.max_duration`` to the current batch, so the
        trainer exits at the end of this batch.
        """
        metric_val = self._get_monitored_metric(state)
        if not torch.is_tensor(metric_val):
            metric_val = torch.tensor(metric_val)

        if self.best is None:
            # First observation: seed the best value.
            self.best = metric_val
            self.best_occurred = state.timestamp
        elif self.comp_func(metric_val, self.best) and torch.abs(metric_val - self.best) > self.min_delta:
            # Improvement must both satisfy the comparator and exceed min_delta in magnitude.
            self.best = metric_val
            self.best_occurred = state.timestamp
        assert self.best_occurred is not None

        if self.patience.unit == TimeUnit.EPOCH:
            if state.timestamp.epoch - self.best_occurred.epoch > self.patience:
                state.max_duration = state.timestamp.batch
        elif self.patience.unit == TimeUnit.BATCH:
            if state.timestamp.batch - self.best_occurred.batch > self.patience:
                state.max_duration = state.timestamp.batch
        else:
            # Unreachable when constructed through __init__, which validates the unit.
            raise ValueError('The units of `patience` should be EPOCH or BATCH.')

    def eval_end(self, state: State, logger: Logger) -> None:
        if self.dataloader_label == state.dataloader_label:
            # if the monitored metric is an eval metric or in an evaluator
            self._update_stopper_state(state)

    def epoch_end(self, state: State, logger: Logger) -> None:
        if self.dataloader_label == state.dataloader_label:
            # if the monitored metric is not an eval metric, the right logic is run on EPOCH_END
            self._update_stopper_state(state)

    def batch_end(self, state: State, logger: Logger) -> None:
        # Batch-granularity patience also needs a check at the end of every batch.
        if self.patience.unit == TimeUnit.BATCH and self.dataloader_label == state.dataloader_label:
            self._update_stopper_state(state)