
[Feature] Support EarlyStoppingHook #739

Merged
merged 31 commits into main from feature/earlystop on Mar 6, 2023

Commits (31)
753bf2c  [Feature] EarlyStoppingHook (nijkah, Nov 17, 2022)
833d5fc  delete redundant line (nijkah, Nov 17, 2022)
8e2bd76  Assert stop_training and rename tests (nijkah, Nov 18, 2022)
3e07e10  Fix UT (nijkah, Nov 18, 2022)
b20112b  rename `metric` to `monitor` (nijkah, Nov 18, 2022)
67dc666  Fix UT (nijkah, Nov 18, 2022)
1808b79  Fix UT (nijkah, Nov 19, 2022)
1107993  edit docstring on patience (nijkah, Nov 21, 2022)
3e35b41  Draft for new code (nijkah, Nov 27, 2022)
de0ae9f  fix ut (nijkah, Nov 27, 2022)
c991f35  add test case (nijkah, Nov 27, 2022)
e4a20be  add test case (nijkah, Nov 27, 2022)
7caec6f  fix ut (nijkah, Nov 27, 2022)
4251d31  Apply suggestions from code review (nijkah, Nov 28, 2022)
ea87f73  Apply suggestions from code review (nijkah, Nov 30, 2022)
79869c2  Append hook (nijkah, Nov 30, 2022)
52edf9d  Append hook (nijkah, Nov 30, 2022)
caa7187  Apply suggestions (nijkah, Nov 30, 2022)
e1e812c  Merge branch 'feature/earlystop' of https://github.com/nijkah/mmengin… (nijkah, Nov 30, 2022)
fa03d57  Merge remote-tracking branch 'origin/main' into feature/earlystop (nijkah, Feb 2, 2023)
bbd482c  Update suggestions (nijkah, Feb 2, 2023)
17a824c  Merge branch 'main' into feature/earlystop (zhouzaida, Feb 22, 2023)
34a4f41  Update mmengine/hooks/__init__.py (zhouzaida, Feb 22, 2023)
b84bbce  fix min_delta (zhouzaida, Feb 22, 2023)
6482ce6  Apply suggestions from code review (zhouzaida, Feb 23, 2023)
bccd43c  lint (nijkah, Feb 23, 2023)
bb6f31a  Apply suggestions from code review (nijkah, Feb 23, 2023)
4b40655  delete save_last (nijkah, Feb 28, 2023)
1ade543  infer rule more robust (zhouzaida, Mar 5, 2023)
946a8a4  refine unit test (HAOCHENYE, Mar 6, 2023)
046d7be  Update mmengine/hooks/early_stopping_hook.py (zhouzaida, Mar 6, 2023)
1 change: 1 addition & 0 deletions docs/en/api/hooks.rst
@@ -25,3 +25,4 @@ mmengine.hooks
 ProfilerHook
 NPUProfilerHook
 PrepareTTAHook
+EarlyStoppingHook
1 change: 1 addition & 0 deletions docs/zh_cn/api/hooks.rst
@@ -25,3 +25,4 @@ mmengine.hooks
 ProfilerHook
 NPUProfilerHook
 PrepareTTAHook
+EarlyStoppingHook
3 changes: 2 additions & 1 deletion mmengine/hooks/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .checkpoint_hook import CheckpointHook
+from .early_stopping_hook import EarlyStoppingHook
 from .ema_hook import EMAHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
@@ -17,5 +18,5 @@
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
     'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
     'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook',
-    'NPUProfilerHook', 'PrepareTTAHook'
+    'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook'
 ]
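With the export in place, the hook can also be built from the HOOKS registry via a config dict. A minimal sketch (it assumes the registration decorator shown in the new file below):

```python
import mmengine.hooks  # importing the package registers the hook classes
from mmengine.registry import HOOKS

# 'loss' is one of the default 'less' keys, so the comparison rule is inferred.
hook = HOOKS.build(dict(type='EarlyStoppingHook', monitor='loss'))
```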
141 changes: 141 additions & 0 deletions mmengine/hooks/early_stopping_hook.py
@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from math import inf, isfinite
from typing import Optional, Tuple, Union

from mmengine.registry import HOOKS
from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class EarlyStoppingHook(Hook):
Member: Consider recording the state of early stopping to support resuming training.

Collaborator: MessageHub saves all history metrics during training; maybe we could utilize it to resume training.

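A rough sketch of the state that would need recording for resume support (illustrative only; `state_dict`/`load_state_dict` are hypothetical method names here, not part of the current Hook interface):

```python
# Hypothetical additions to EarlyStoppingHook; method names are illustrative.
def state_dict(self) -> dict:
    # wait_count and best_score are the only mutable state that
    # future stop decisions depend on.
    return dict(wait_count=self.wait_count, best_score=self.best_score)

def load_state_dict(self, state: dict) -> None:
    self.wait_count = state['wait_count']
    self.best_score = state['best_score']
```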
"""Early stop the training when the monitored metric reached a plateau.

Args:
monitor (str): The monitored metric key to decide early stopping.
rule (str, optional): Comparison rule. Options are 'greater',
'less'. Defaults to None.
min_delta (float, optional): Minimum difference to continue the
training. Defaults to 0.01.
strict (bool, optional): Whether to crash the training when `monitor`
is not found in the `metrics`. Defaults to False.
check_finite: Whether to stop training when the monitor becomes NaN or
infinite. Defaults to True.
patience (int, optional): The times of validation with no improvement
after which training will be stopped. Defaults to 5.
stopping_threshold (float, optional): Stop training immediately once
the monitored quantity reaches this threshold. Defaults to None.

Note:
`New in version 0.6.0.`
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
"""
    priority = 'LOWEST'

    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(
        self,
        monitor: str,
        rule: Optional[str] = None,
        min_delta: float = 0.1,
        strict: bool = False,
        check_finite: bool = True,
        patience: int = 5,
        stopping_threshold: Optional[float] = None,
    ):

        self.monitor = monitor
        # Only infer the rule when the user did not specify one explicitly.
        if rule is None:
            if monitor in self._default_greater_keys:
                rule = 'greater'
            elif monitor in self._default_less_keys:
                rule = 'less'
        assert rule in ['greater', 'less'], \
            '`rule` should be either "greater" or "less".'
        self.rule = rule
        self.min_delta = min_delta if rule == 'greater' else -1 * min_delta
        self.strict = strict
        self.check_finite = check_finite
        self.patience = patience
        self.stopping_threshold = stopping_threshold

        self.wait_count = 0
        self.best_score = -inf if rule == 'greater' else inf

    def _check_stop_condition(self, current_score: float) -> Tuple[bool, str]:
        compare = self.rule_map[self.rule]
        stop_training = False
        reason_message = ''

        if self.check_finite and not isfinite(current_score):
            stop_training = True
            reason_message = (f'Monitored metric {self.monitor} = '
                              f'{current_score} is not finite. '
                              f'Previous best value was '
                              f'{self.best_score:.3f}.')
        elif self.stopping_threshold is not None and compare(
                current_score, self.stopping_threshold):
            stop_training = True
            self.best_score = current_score
            reason_message = (f'Stopping threshold reached: '
                              f'`{self.monitor}` = {current_score} is '
                              f'{self.rule} than {self.stopping_threshold}.')
        elif compare(self.best_score + self.min_delta, current_score):
            # No sufficient improvement over the best score so far.
            self.wait_count += 1

            if self.wait_count >= self.patience:
                reason_message = (f'The monitored metric did not improve '
                                  f'in the last {self.wait_count} records. '
                                  f'Best score: {self.best_score:.3f}.')
                stop_training = True
        else:
            self.best_score = current_score
            self.wait_count = 0

        return stop_training, reason_message

    def before_run(self, runner) -> None:
        """Check `stop_training` variable in `runner.train_loop`.

        Args:
            runner (Runner): The runner of the training process.
        """
        assert hasattr(runner.train_loop, 'stop_training'), \
            '`train_loop` should contain `stop_training` variable.'

    def after_val_epoch(self, runner, metrics):
        """Decide whether to stop the training process.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if self.monitor not in metrics:
            if self.strict:
                raise RuntimeError(
                    'Early stopping conditioned on metric '
                    f'`{self.monitor}` is not available. Please check '
                    f'available metrics {metrics}, or set `strict=False` in '
                    '`EarlyStoppingHook`.')
            warnings.warn(
                'Skip early stopping process since the evaluation '
                f'results ({metrics.keys()}) do not include `monitor` '
                f'({self.monitor}).')
            return

        current_score = metrics[self.monitor]

        stop_training, message = self._check_stop_condition(current_score)
        if stop_training:
            runner.train_loop.stop_training = True
            runner.logger.info(message)
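For context, a minimal configuration sketch showing how the hook could be enabled; the monitor key 'accuracy' is a placeholder and must match a key your evaluator actually reports:

```python
# Sketch of enabling the hook through a config's custom_hooks list.
# 'accuracy' is a hypothetical metric key.
custom_hooks = [
    dict(
        type='EarlyStoppingHook',
        monitor='accuracy',
        rule='greater',
        min_delta=0.01,
        patience=5,
    ),
]
```

Because the hook declares `priority = 'LOWEST'`, its `after_val_epoch` runs after the other hooks that consume the same validation metrics.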
10 changes: 8 additions & 2 deletions mmengine/runner/loops.py
@@ -49,6 +49,9 @@ def __init__(
         self._iter = 0
         self.val_begin = val_begin
         self.val_interval = val_interval
+        # This attribute will be updated by `EarlyStoppingHook`
+        # when it is enabled.
+        self.stop_training = False
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.runner.visualizer.dataset_meta = \
                 self.dataloader.dataset.metainfo
@@ -86,7 +89,7 @@ def run(self) -> torch.nn.Module:
         """Launch training."""
         self.runner.call_hook('before_train')

-        while self._epoch < self._max_epochs:
+        while self._epoch < self._max_epochs and not self.stop_training:
             self.run_epoch()

             self._decide_current_val_interval()
@@ -216,6 +219,9 @@ def __init__(
         self._iter = 0
         self.val_begin = val_begin
         self.val_interval = val_interval
+        # This attribute will be updated by `EarlyStoppingHook`
+        # when it is enabled.
+        self.stop_training = False
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.runner.visualizer.dataset_meta = \
                 self.dataloader.dataset.metainfo
@@ -257,7 +263,7 @@ def run(self) -> None:
         # In iteration-based training loop, we treat the whole training process
         # as a big epoch and execute the corresponding hook.
         self.runner.call_hook('before_train_epoch')
-        while self._iter < self._max_iters:
+        while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()

             data_batch = next(self.dataloader_iterator)
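To make the patience semantics concrete, here is a small standalone walk-through (a sketch only; it drives the private `_check_stop_condition` helper directly instead of going through a Runner):

```python
from mmengine.hooks import EarlyStoppingHook

# monitor='loss' infers rule='less'; min_delta=0.0 keeps the arithmetic plain.
hook = EarlyStoppingHook(monitor='loss', min_delta=0.0, patience=2)

for score in [1.00, 0.90, 0.91, 0.92]:  # improves twice, then stalls twice
    stop, reason = hook._check_stop_condition(score)
    print(f'{score:.2f} -> stop={stop}')
# 1.00 and 0.90 each update best_score and reset wait_count; 0.91 and 0.92
# both fail to beat best_score=0.90, so the second miss reaches patience=2
# and stop becomes True.
```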