-
Notifications
You must be signed in to change notification settings - Fork 323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support EarlyStoppingHook #739
Merged
Merged
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
753bf2c
[Feature] EarlyStoppingHook
nijkah 833d5fc
delete redundant line
nijkah 8e2bd76
Assert stop_training and rename tests
nijkah 3e07e10
Fix UT
nijkah b20112b
rename `metric` to `monitor`
nijkah 67dc666
Fix UT
nijkah 1808b79
Fix UT
nijkah 1107993
edit docstring on patience
nijkah 3e35b41
Draft for new code
nijkah de0ae9f
fix ut
nijkah c991f35
add test case
nijkah e4a20be
add test case
nijkah 7caec6f
fix ut
nijkah 4251d31
Apply suggestions from code review
nijkah ea87f73
Apply suggestions from code review
nijkah 79869c2
Append hook
nijkah 52edf9d
Append hook
nijkah caa7187
Apply suggestions
nijkah e1e812c
Merge branch 'feature/earlystop' of https://github.com/nijkah/mmengin…
nijkah fa03d57
Merge remote-tracking branch 'origin/main' into feature/earlystop
nijkah bbd482c
Update suggestions
nijkah 17a824c
Merge branch 'main' into feature/earlystop
zhouzaida 34a4f41
Update mmengine/hooks/__init__.py
zhouzaida b84bbce
fix min_delta
zhouzaida 6482ce6
Apply suggestions from code review
zhouzaida bccd43c
lint
nijkah bb6f31a
Apply suggestions from code review
nijkah 4b40655
delete save_last
nijkah 1ade543
infer rule more robust
zhouzaida 946a8a4
refine unit test
HAOCHENYE 046d7be
Update mmengine/hooks/early_stopping_hook.py
zhouzaida File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,4 @@ mmengine.hooks | |
ProfilerHook | ||
NPUProfilerHook | ||
PrepareTTAHook | ||
EarlyStoppingHook |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,4 @@ mmengine.hooks | |
ProfilerHook | ||
NPUProfilerHook | ||
PrepareTTAHook | ||
EarlyStoppingHook |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import warnings | ||
from math import inf, isfinite | ||
from typing import Optional, Tuple, Union | ||
|
||
from mmengine.registry import HOOKS | ||
from .hook import Hook | ||
|
||
DATA_BATCH = Optional[Union[dict, tuple, list]] | ||
|
||
|
||
@HOOKS.register_module() | ||
class EarlyStoppingHook(Hook): | ||
"""Early stop the training when the monitored metric reached a plateau. | ||
|
||
Args: | ||
monitor (str): The monitored metric key to decide early stopping. | ||
rule (str, optional): Comparison rule. Options are 'greater', | ||
'less'. Defaults to None. | ||
min_delta (float, optional): Minimum difference to continue the | ||
training. Defaults to 0.01. | ||
strict (bool, optional): Whether to crash the training when `monitor` | ||
is not found in the `metrics`. Defaults to False. | ||
check_finite: Whether to stop training when the monitor becomes NaN or | ||
infinite. Defaults to True. | ||
patience (int, optional): The times of validation with no improvement | ||
after which training will be stopped. Defaults to 5. | ||
stopping_threshold (float, optional): Stop training immediately once | ||
the monitored quantity reaches this threshold. Defaults to None. | ||
|
||
Note: | ||
`New in version 0.7.0.` | ||
""" | ||
priority = 'LOWEST' | ||
|
||
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} | ||
_default_greater_keys = [ | ||
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', | ||
'mAcc', 'aAcc' | ||
] | ||
_default_less_keys = ['loss'] | ||
|
||
def __init__( | ||
self, | ||
monitor: str, | ||
rule: Optional[str] = None, | ||
min_delta: float = 0.1, | ||
strict: bool = False, | ||
check_finite: bool = True, | ||
patience: int = 5, | ||
stopping_threshold: Optional[float] = None, | ||
): | ||
|
||
self.monitor = monitor | ||
if rule is not None: | ||
if rule not in ['greater', 'less']: | ||
raise ValueError( | ||
'`rule` should be either "greater" or "less", ' | ||
f'but got {rule}') | ||
else: | ||
rule = self._init_rule(monitor) | ||
self.rule = rule | ||
self.min_delta = min_delta if rule == 'greater' else -1 * min_delta | ||
self.strict = strict | ||
self.check_finite = check_finite | ||
self.patience = patience | ||
self.stopping_threshold = stopping_threshold | ||
|
||
self.wait_count = 0 | ||
self.best_score = -inf if rule == 'greater' else inf | ||
|
||
def _init_rule(self, monitor: str) -> str: | ||
greater_keys = {key.lower() for key in self._default_greater_keys} | ||
less_keys = {key.lower() for key in self._default_less_keys} | ||
monitor_lc = monitor.lower() | ||
if monitor_lc in greater_keys: | ||
rule = 'greater' | ||
elif monitor_lc in less_keys: | ||
rule = 'less' | ||
elif any(key in monitor_lc for key in greater_keys): | ||
rule = 'greater' | ||
elif any(key in monitor_lc for key in less_keys): | ||
rule = 'less' | ||
else: | ||
raise ValueError(f'Cannot infer the rule for {monitor}, thus rule ' | ||
'must be specified.') | ||
return rule | ||
|
||
def _check_stop_condition(self, current_score: float) -> Tuple[bool, str]: | ||
compare = self.rule_map[self.rule] | ||
stop_training = False | ||
reason_message = '' | ||
|
||
if self.check_finite and not isfinite(current_score): | ||
stop_training = True | ||
reason_message = (f'Monitored metric {self.monitor} = ' | ||
f'{current_score} is infinite. ' | ||
f'Previous best value was ' | ||
f'{self.best_score:.3f}.') | ||
|
||
elif self.stopping_threshold is not None and compare( | ||
current_score, self.stopping_threshold): | ||
stop_training = True | ||
self.best_score = current_score | ||
reason_message = (f'Stopping threshold reached: ' | ||
f'`{self.monitor}` = {current_score} is ' | ||
f'{self.rule} than {self.stopping_threshold}.') | ||
elif compare(self.best_score + self.min_delta, current_score): | ||
|
||
self.wait_count += 1 | ||
|
||
if self.wait_count >= self.patience: | ||
reason_message = (f'the monitored metric did not improve ' | ||
f'in the last {self.wait_count} records. ' | ||
f'best score: {self.best_score:.3f}. ') | ||
stop_training = True | ||
else: | ||
self.best_score = current_score | ||
self.wait_count = 0 | ||
|
||
return stop_training, reason_message | ||
|
||
def before_run(self, runner) -> None: | ||
"""Check `stop_training` variable in `runner.train_loop`. | ||
|
||
Args: | ||
runner (Runner): The runner of the training process. | ||
""" | ||
|
||
assert hasattr(runner.train_loop, 'stop_training'), \ | ||
'`train_loop` should contain `stop_training` variable.' | ||
|
||
def after_val_epoch(self, runner, metrics): | ||
nijkah marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Decide whether to stop the training process. | ||
|
||
Args: | ||
runner (Runner): The runner of the training process. | ||
metrics (dict): Evaluation results of all metrics | ||
""" | ||
|
||
if self.monitor not in metrics: | ||
if self.strict: | ||
raise RuntimeError( | ||
'Early stopping conditioned on metric ' | ||
f'`{self.monitor} is not available. Please check available' | ||
f' metrics {metrics}, or set `strict=False` in ' | ||
'`EarlyStoppingHook`.') | ||
warnings.warn( | ||
'Skip early stopping process since the evaluation ' | ||
f'results ({metrics.keys()}) do not include `monitor` ' | ||
f'({self.monitor}).') | ||
return | ||
|
||
current_score = metrics[self.monitor] | ||
|
||
stop_training, message = self._check_stop_condition(current_score) | ||
if stop_training: | ||
runner.train_loop.stop_training = True | ||
runner.logger.info(message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider recording the state of early stopping to support resuming training.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MessageHub
will save all history metrics during training, maybe we could utilize it to resume training