[Feature] Support EarlyStoppingHook #739

Merged Mar 6, 2023 · 31 commits

Changes from 2 commits

Commits
753bf2c
[Feature] EarlyStoppingHook
nijkah Nov 17, 2022
833d5fc
delete redundant line
nijkah Nov 17, 2022
8e2bd76
Assert stop_training and rename tests
nijkah Nov 18, 2022
3e07e10
Fix UT
nijkah Nov 18, 2022
b20112b
rename `metric` to `monitor`
nijkah Nov 18, 2022
67dc666
Fix UT
nijkah Nov 18, 2022
1808b79
Fix UT
nijkah Nov 19, 2022
1107993
edit docstring on patience
nijkah Nov 21, 2022
3e35b41
Draft for new code
nijkah Nov 27, 2022
de0ae9f
fix ut
nijkah Nov 27, 2022
c991f35
add test case
nijkah Nov 27, 2022
e4a20be
add test case
nijkah Nov 27, 2022
7caec6f
fix ut
nijkah Nov 27, 2022
4251d31
Apply suggestions from code review
nijkah Nov 28, 2022
ea87f73
Apply suggestions from code review
nijkah Nov 30, 2022
79869c2
Append hook
nijkah Nov 30, 2022
52edf9d
Append hook
nijkah Nov 30, 2022
caa7187
Apply suggestions
nijkah Nov 30, 2022
e1e812c
Merge branch 'feature/earlystop' of https://github.com/nijkah/mmengin…
nijkah Nov 30, 2022
fa03d57
Merge remote-tracking branch 'origin/main' into feature/earlystop
nijkah Feb 2, 2023
bbd482c
Update suggestions
nijkah Feb 2, 2023
17a824c
Merge branch 'main' into feature/earlystop
zhouzaida Feb 22, 2023
34a4f41
Update mmengine/hooks/__init__.py
zhouzaida Feb 22, 2023
b84bbce
fix min_delta
zhouzaida Feb 22, 2023
6482ce6
Apply suggestions from code review
zhouzaida Feb 23, 2023
bccd43c
lint
nijkah Feb 23, 2023
bb6f31a
Apply suggestions from code review
nijkah Feb 23, 2023
4b40655
delete save_last
nijkah Feb 28, 2023
1ade543
infer rule more robust
zhouzaida Mar 5, 2023
946a8a4
refine unit test
HAOCHENYE Mar 6, 2023
046d7be
Update mmengine/hooks/early_stopping_hook.py
zhouzaida Mar 6, 2023
16 changes: 13 additions & 3 deletions mmengine/hooks/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .checkpoint_hook import CheckpointHook
+from .early_stopping_hook import EarlyStoppingHook
from .ema_hook import EMAHook
from .empty_cache_hook import EmptyCacheHook
from .hook import Hook
Expand All @@ -12,7 +13,16 @@
from .sync_buffer_hook import SyncBuffersHook

__all__ = [
-    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
-    'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
+    'Hook',
+    'IterTimerHook',
+    'DistSamplerSeedHook',
+    'ParamSchedulerHook',
+    'SyncBuffersHook',
+    'EmptyCacheHook',
+    'CheckpointHook',
+    'LoggerHook',
+    'NaiveVisualizationHook',
+    'EMAHook',
+    'RuntimeInfoHook',
+    'EarlyStoppingHook',
]
98 changes: 98 additions & 0 deletions mmengine/hooks/early_stopping_hook.py
@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Union

from mmengine.registry import HOOKS
from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class EarlyStoppingHook(Hook):
Review comment (Member): Consider recording the state of early stopping to support resuming training.

Review comment (Collaborator): MessageHub will save all history metrics during training; maybe we could utilize it to resume training.
"""Early stop the training when the metric reached a plateau.

Args:
metric (str): The metric key to decide early stopping.
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
rule (str, optional): Comparison rule. Options are 'greater',
'less'. Defaults to None.
delta(float, optional): Minimum difference to continue the training.
Defaults to 0.01.
pool_size (int, optional): The number of experiments to consider.
Defaults to 5.
patience (int, optional): Maximum number of tolerance.
nijkah marked this conversation as resolved.
Show resolved Hide resolved
Defaults to 0.
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
"""
priority = 'LOWEST'

rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
_default_greater_keys = [
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
'mAcc', 'aAcc'
]
_default_less_keys = ['loss']

    def __init__(
        self,
        metric: str,
        rule: Optional[str] = None,
        delta: float = 0.1,
        pool_size: int = 5,
        patience: int = 0,
    ):
        self.metric = metric
        # Infer the comparison rule from the metric name only when it is
        # not specified explicitly.
        if rule is None:
            if metric in self._default_greater_keys:
                rule = 'greater'
            elif metric in self._default_less_keys:
                rule = 'less'
        assert rule in ['greater', 'less'], \
            "`rule` should be either 'greater' or 'less'."
        self.rule = rule
        self.delta = delta
        self.pool_size = pool_size
        self.patience = patience
        self.count = 0

        self.pool_values: List[float] = []

    def after_val_epoch(self, runner, metrics):
        """Decide whether to stop the training process.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if self.metric not in metrics:
            warnings.warn(
                f'Skip early stopping process since the evaluation '
                f'results ({metrics.keys()}) do not include `metric` '
                f'({self.metric}).')
            return

        latest_value = metrics[self.metric]
        compare = self.rule_map[self.rule]

        self.pool_values.append(latest_value)

        if self.rule == 'greater':
            # maintain the largest values
            self.pool_values = sorted(self.pool_values)[-self.pool_size:]
        else:
            # maintain the smallest values
            self.pool_values = sorted(self.pool_values)[:self.pool_size]

        if len(self.pool_values) == self.pool_size and compare(
                sum(self.pool_values) / self.pool_size + self.delta,
                latest_value):
            self.count += 1
            if self.count >= self.patience:
                runner.train_loop.stop_training = True
                runner.logger.info(
                    'The metric reached a plateau. '
                    'This training process will be stopped early.')
        else:
            self.count = 0
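
For intuition, the stopping rule above keeps a pool of the `pool_size` best scores and flags a plateau when even the pool average plus `delta` still beats the latest score. Below is a small standalone sketch of the same arithmetic on a made-up accuracy curve (plain Python; the curve and variable names are illustrative, not part of this PR):

# Standalone sketch of the pool-and-compare plateau check above.
# The accuracy curve and all names are made up for illustration.
pool_size, delta, patience = 5, 0.5, 0
pool, count = [], 0

for epoch, acc in enumerate([70.0, 75.0, 78.0, 80.0, 80.2, 80.3, 80.1, 80.2]):
    pool = sorted(pool + [acc])[-pool_size:]  # keep the largest values
    # Plateau: the pooled average plus `delta` still beats the latest score.
    if len(pool) == pool_size and sum(pool) / pool_size + delta > acc:
        count += 1
        if count >= patience:
            print(f'stop at epoch {epoch}')  # fires at epoch 6 here
            break
    else:
        count = 0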
6 changes: 4 additions & 2 deletions mmengine/runner/loops.py
@@ -49,6 +49,7 @@ def __init__(
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
+        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
@@ -86,7 +87,7 @@ def run(self) -> torch.nn.Module:
"""Launch training."""
self.runner.call_hook('before_train')

while self._epoch < self._max_epochs:
while self._epoch < self._max_epochs and not self.stop_training:
self.run_epoch()

self._decide_current_val_interval()
@@ -216,6 +217,7 @@ def __init__(
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
+        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
@@ -257,7 +259,7 @@ def run(self) -> None:
        # In iteration-based training loop, we treat the whole training process
        # as a big epoch and execute the corresponding hook.
        self.runner.call_hook('before_train_epoch')
-        while self._iter < self._max_iters:
+        while self._iter < self._max_iters and not self.stop_training:
            self.runner.model.train()

            data_batch = next(self.dataloader_iterator)
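
The loop changes above define a simple cooperative protocol: the training loop polls `stop_training` between epochs (or iterations), so any hook holding a reference to the runner can request a graceful stop. A minimal sketch of a hypothetical hook using the same flag (the class name and threshold are illustrative, not part of this PR):

from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class StopOnLossExplosionHook(Hook):
    """Hypothetical hook: request a stop once the training loss explodes."""

    def __init__(self, threshold: float = 1e4):
        self.threshold = threshold

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        # `outputs` holds the loss dict returned by `train_step`.
        if outputs is not None and float(outputs.get('loss', 0)) > self.threshold:
            runner.train_loop.stop_training = True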
179 changes: 179 additions & 0 deletions tests/test_hooks/test_early_stopping_hook.py
@@ -0,0 +1,179 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from unittest.mock import Mock

import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.evaluator import BaseMetric
from mmengine.hooks import EarlyStoppingHook
from mmengine.logging import MessageHub
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner


class ToyModel(BaseModel):

def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)

def forward(self, inputs, data_sample, mode='tensor'):
labels = torch.stack(data_sample)
inputs = torch.stack(inputs)
outputs = self.linear(inputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
else:
return outputs


class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)

@property
def metainfo(self):
return self.METAINFO

def __len__(self):
return self.data.size(0)

def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])


class TriangleMetric(BaseMetric):

default_prefix: str = 'test'

def __init__(self, length):
super().__init__()
self.length = length
self.best_idx = length // 2
self.cur_idx = 0
self.vals = [90, 91, 92, 93, 94, 93] * 2

def process(self, *args, **kwargs):
self.results.append(0)

def compute_metrics(self, *args, **kwargs):
acc = self.vals[self.cur_idx]
self.cur_idx += 1
return dict(acc=acc)


def get_mock_runner():
runner = Mock()
runner.train_loop = Mock()
runner.train_loop.stop_training = False
runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
return runner


class TestEarlyStoppingHook:

def test_init(self):

hook = EarlyStoppingHook(metric='acc')
assert hook.rule == 'greater'

hook = EarlyStoppingHook(metric='loss')
assert hook.rule == 'less'

with pytest.raises(AssertionError):
EarlyStoppingHook(metric='accuracy/top1', rule='the world')

def test_after_val_epoch(self, tmp_path):
runner = get_mock_runner()

# if `metric` does not match, skip the hook.
with pytest.warns(UserWarning) as record_warnings:
metrics = {'accuracy/top1': 0.5, 'loss': 0.23}
hook = EarlyStoppingHook(metric='acc', rule='greater')
hook.after_val_epoch(runner, metrics)

        # Since many warnings may be thrown, we just need to check that
        # the expected warning message is among them.
expected_message = (
f'Skip early stopping process since the evaluation results '
f'({metrics.keys()}) do not include `metric` ({hook.metric}).')
for warning in record_warnings:
if str(warning.message) == expected_message:
break
else:
assert False

# Check largest 5 values
runner = get_mock_runner()
metrics = [{'accuracy/top1': i / 10.} for i in range(8)]
hook = EarlyStoppingHook(metric='accuracy/top1', rule='greater')
for metric in metrics:
hook.after_val_epoch(runner, metric)
assert all([i / 10 in hook.pool_values for i in range(3, 8)])

        # Check smallest 3 values
runner = get_mock_runner()
metrics = [{'loss': i / 10.} for i in range(8)]
hook = EarlyStoppingHook(metric='loss', pool_size=3)
for metric in metrics:
hook.after_val_epoch(runner, metric)
assert all([i / 10 in hook.pool_values for i in range(3)])

# Check stop training
runner = get_mock_runner()
metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
hook = EarlyStoppingHook(
metric='accuracy/top1', rule='greater', delta=1)
for metric in metrics:
hook.after_val_epoch(runner, metric)
assert runner.train_loop.stop_training

# Check patience
runner = get_mock_runner()
metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
hook = EarlyStoppingHook(
metric='accuracy/top1', rule='greater', delta=1, patience=5)
for metric in metrics:
hook.after_val_epoch(runner, metric)
assert not runner.train_loop.stop_training

def test_with_runner(self, tmp_path):
max_epoch = 10
work_dir = osp.join(str(tmp_path), 'runner_test')
checkpoint_cfg = dict(
type='EarlyStoppingHook',
metric='test/acc',
rule='greater',
delta=0.4,
)
runner = Runner(
model=ToyModel(),
work_dir=work_dir,
train_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=dict(type=TriangleMetric, length=max_epoch),
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(
by_epoch=True, max_epochs=max_epoch, val_interval=1),
val_cfg=dict(),
default_hooks=dict(checkpoint=checkpoint_cfg))
runner.train()
assert runner.epoch == 7
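
For reference, a minimal sketch of how a user could enable the hook from a config once this is merged. `custom_hooks` is the standard MMEngine entry point for optional hooks; the metric key below is illustrative and depends on the configured evaluator:

# Illustrative config snippet; the metric key depends on your evaluator.
custom_hooks = [
    dict(
        type='EarlyStoppingHook',
        metric='accuracy/top1',
        rule='greater',
        delta=0.1,
        patience=3,
    )
]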