[CodeCamp #28] support ReduceOnPlateau #819

Merged: 48 commits, merged on Jan 16, 2023 (changes shown from 42 commits)

Commits
e41c8f6
[Feature] Add ReduceOnPlateauParamScheduler and change ParamScheduler…
LEFTeyex Dec 6, 2022
bbab5e3
[Feature] add ReduceOnPlateauLR and ReduceOnPlateauMomentum
LEFTeyex Dec 13, 2022
28ffc49
pre-commit check
LEFTeyex Dec 13, 2022
97419c9
Merge branch 'main' of github.com:open-mmlab/mmengine into LEFTeyes/a…
LEFTeyex Dec 13, 2022
aadaf4d
add a little docs
LEFTeyex Dec 13, 2022
3ac869c
change position
LEFTeyex Dec 13, 2022
a17217a
fix the conflict between isort and yapf
LEFTeyex Dec 13, 2022
68085fe
fix ParamSchedulerHook after_val_epoch execute without train_loop and…
LEFTeyex Dec 13, 2022
68bddd8
Apply suggestions from code review
LEFTeyex Dec 16, 2022
97f0468
update ReduceOnPlateauParamScheduler, ReduceOnPlateauMomentum and Par…
LEFTeyex Dec 17, 2022
83f4313
fix get need_step_args attribute error in ParamSchedulerHook
LEFTeyex Dec 17, 2022
f84fd9e
fix load_state_dict error for rule in ReduceOnPlateauParamScheduler
LEFTeyex Dec 20, 2022
f116a50
add docs for ParamSchedulerHook and fix a few codes
LEFTeyex Dec 21, 2022
57a1ee5
[Docs] add ReduceOnPlateauParamScheduler, ReduceOnPlateauMomentum and…
LEFTeyex Dec 22, 2022
e19e34e
[Refactor] adjust the order of import
LEFTeyex Dec 22, 2022
7b9e19f
[Fix] add init check for threshold in ReduceOnPlateauParamScheduler
LEFTeyex Dec 22, 2022
9e0e0c7
[Test] add test for ReduceOnPlateauParamScheduler, ReduceOnPlateauLR …
LEFTeyex Dec 22, 2022
0978665
Merge branch 'main' of github.com:open-mmlab/mmengine into LEFTeyes/a…
LEFTeyex Dec 22, 2022
70cf761
[Fix] fix no attribute self.min_value
LEFTeyex Dec 22, 2022
4ecde19
[Fix] fix numerical problem in tests
LEFTeyex Dec 23, 2022
c8fcfae
[Fix] fix error in tests
LEFTeyex Dec 23, 2022
bb2a535
[Fix] fix ignore first param in tests
LEFTeyex Dec 23, 2022
11497d0
[Fix] fix bug in tests
LEFTeyex Dec 23, 2022
284bb6c
[Fix] fix bug in tests
LEFTeyex Dec 23, 2022
c309081
[Fix] fix bug in tests
LEFTeyex Dec 23, 2022
c80e7f0
[Fix] increase coverage
LEFTeyex Dec 23, 2022
ce4b4e2
[Fix] fix count self._global_step bug and docs
LEFTeyex Dec 23, 2022
f07110c
[Fix] fix tests
LEFTeyex Dec 23, 2022
f93ea75
[Fix] modified ParamSchedulerHook test
LEFTeyex Dec 23, 2022
2acec59
Update mmengine/optim/scheduler/param_scheduler.py
LEFTeyex Dec 29, 2022
37c4fce
Apply suggestions from code review
LEFTeyex Dec 29, 2022
98280b2
[Fix] modified something according to commented
LEFTeyex Dec 29, 2022
66046ea
[Docs] add api for en and zh_cn
LEFTeyex Dec 29, 2022
9fcacbf
[Fix] fix bug in test_param_scheduler_hook.py
LEFTeyex Jan 2, 2023
6058dd1
[Test] support more complicated test modes(less, greater, rel, abs) f…
LEFTeyex Jan 2, 2023
f3ab5bf
[Docs] add docs for rule
LEFTeyex Jan 6, 2023
3d6cbd8
[Fix] fix pop from empty list bug in test
LEFTeyex Jan 6, 2023
00d40d4
[Fix] fix check param_schedulers is not built bug
LEFTeyex Jan 6, 2023
84c114e
Merge branch 'main' of github.com:open-mmlab/mmengine into LEFTeyes/a…
LEFTeyex Jan 6, 2023
c405e56
[Fix] fix step_args bug and without runner._train_loop bug
LEFTeyex Jan 6, 2023
8c69c5d
[Fix] fix step_args bug and without runner._train_loop bug
LEFTeyex Jan 6, 2023
d4f6641
[Fix] fix scheduler type bug
LEFTeyex Jan 6, 2023
85f98e3
[Test] rename step_args to step_kwargs
LEFTeyex Jan 9, 2023
516f9a3
[Fix] remove redundancy check
LEFTeyex Jan 9, 2023
61798cd
[Test] remove redundancy check
LEFTeyex Jan 9, 2023
9504c6d
Apply suggestions from code review
LEFTeyex Jan 15, 2023
9b561d7
[Test] fix some defects
LEFTeyex Jan 15, 2023
c96b84f
Merge branch 'main' of github.com:open-mmlab/mmengine into LEFTeyes/a…
LEFTeyex Jan 15, 2023
3 changes: 3 additions & 0 deletions docs/en/api/optim.rst
@@ -63,3 +63,6 @@ Scheduler
StepLR
StepMomentum
StepParamScheduler
ReduceOnPlateauLR
ReduceOnPlateauMomentum
ReduceOnPlateauParamScheduler
3 changes: 3 additions & 0 deletions docs/zh_cn/api/optim.rst
@@ -63,3 +63,6 @@ Scheduler
StepLR
StepMomentum
StepParamScheduler
ReduceOnPlateauLR
ReduceOnPlateauMomentum
ReduceOnPlateauParamScheduler
80 changes: 68 additions & 12 deletions mmengine/hooks/param_scheduler_hook.py
@@ -1,7 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
from typing import Dict, Optional, Union

from mmengine.optim import _ParamScheduler
from mmengine.registry import HOOKS
from mmengine.runner import BaseLoop
from mmengine.utils import is_list_of
from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -19,7 +22,7 @@ def after_train_iter(self,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Call step function for each scheduler after each iteration.
"""Call step function for each scheduler after each training iteration.

Args:
runner (Runner): The runner of the training process.
@@ -32,15 +35,15 @@ def after_train_iter(self,
keep ``data_batch`` here.
"""

def step(param_schedulers):
assert isinstance(param_schedulers, list)
for scheduler in param_schedulers:
if not scheduler.by_epoch:
scheduler.step()

if runner.param_schedulers is None:
return

def step(_param_schedulers):
assert isinstance(_param_schedulers, list)
for scheduler in _param_schedulers:
if not scheduler.by_epoch:
scheduler.step()

Member commented:
Why do we need to change the position of the step method?

Contributor Author (LEFTeyex) replied:
It saves the time of building step when runner.param_schedulers is None.

Member (@zhouzaida, Jan 9, 2023) replied:
Got it. BTW, it is unnecessary to rename param_schedulers to _param_schedulers in the step method.

Contributor Author (LEFTeyex) replied:
I will change _param_schedulers back to param_schedulers.

if isinstance(runner.param_schedulers, list):
step(runner.param_schedulers)
elif isinstance(runner.param_schedulers, dict):
@@ -53,21 +56,74 @@ def step(param_schedulers):
f'but got {runner.param_schedulers}')

def after_train_epoch(self, runner) -> None:
"""Call step function for each scheduler after each epoch.
"""Call step function for each scheduler after each training epoch.

Args:
runner (Runner): The runner of the training process.
"""

def step(param_schedulers):
assert isinstance(param_schedulers, list)
for scheduler in param_schedulers:
if runner.param_schedulers is None:
return

def step(_param_schedulers):
assert isinstance(_param_schedulers, list)
for scheduler in _param_schedulers:
if scheduler.by_epoch:
scheduler.step()

if isinstance(runner.param_schedulers, list):
step(runner.param_schedulers)
elif isinstance(runner.param_schedulers, dict):
for param_schedulers in runner.param_schedulers.values():
step(param_schedulers)
else:
raise TypeError(
'runner.param_schedulers should be list of ParamScheduler or '
'a dict containing list of ParamScheduler, '
f'but got {runner.param_schedulers}')

def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""Call step function for each scheduler which has attribute
``need_val_args`` after each validation epoch.

Args:
runner (Runner): The runner of the validation process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on validation dataset. The keys are the names of the
metrics, and the values are corresponding results.

Note:
If ``runner._train_loop`` or ``runner.param_schedulers`` has not been
built yet, the hook ``after_val_epoch`` will be skipped.
"""

if runner.param_schedulers is None:
return

# avoid counting scheduler._global_step twice;
# it has already been counted in the after_train_* hooks
if metrics is None:
return

# check train_loop is built to avoid stepping schedulers without training.
# Accessing runner.train_loop would trigger building it, so use
# runner._train_loop instead. This is a hacky approach.
if not isinstance(runner._train_loop, BaseLoop):
return

def step(_param_schedulers):
# check param_schedulers is list and built
if not is_list_of(_param_schedulers, _ParamScheduler):
return

for scheduler in _param_schedulers:
if (scheduler.by_epoch
and getattr(scheduler, 'need_val_args', False)):
scheduler.step(metrics)

if isinstance(runner.param_schedulers, list):
step(runner.param_schedulers)
elif isinstance(runner.param_schedulers, dict):
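
The overall behavior of the new after_val_epoch hook can be summarized with a minimal, self-contained sketch (not code from this PR); the monitored key 'accuracy/top1' and the metric value are made-up examples:

import torch
from mmengine.optim import ReduceOnPlateauLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = ReduceOnPlateauLR(
    optimizer,
    monitor='accuracy/top1',
    rule='greater',   # reduce the LR when the monitored value stops increasing
    factor=0.5,
    patience=3)

# After each validation epoch, ParamSchedulerHook passes the metrics dict only
# to epoch-based schedulers that expose ``need_val_args``; other schedulers
# keep their usual after_train_iter / after_train_epoch stepping.
metrics = {'accuracy/top1': 76.2}  # hypothetical validation results
if scheduler.by_epoch and getattr(scheduler, 'need_val_args', False):
    scheduler.step(metrics)
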
9 changes: 6 additions & 3 deletions mmengine/optim/__init__.py
@@ -11,8 +11,10 @@
MultiStepLR, MultiStepMomentum,
MultiStepParamScheduler, OneCycleLR,
OneCycleParamScheduler, PolyLR, PolyMomentum,
PolyParamScheduler, StepLR, StepMomentum,
StepParamScheduler, _ParamScheduler)
PolyParamScheduler, ReduceOnPlateauLR,
ReduceOnPlateauMomentum, ReduceOnPlateauParamScheduler,
StepLR, StepMomentum, StepParamScheduler,
_ParamScheduler)

# yapf: enable
__all__ = [
@@ -25,5 +27,6 @@
'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
'_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum',
'PolyParamScheduler'
'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum',
'ReduceOnPlateauParamScheduler'
]
14 changes: 8 additions & 6 deletions mmengine/optim/scheduler/__init__.py
@@ -2,21 +2,22 @@
# yapf: disable
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR,
ExponentialLR, LinearLR, MultiStepLR, OneCycleLR,
PolyLR, StepLR)
PolyLR, ReduceOnPlateauLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
CosineRestartMomentum, ExponentialMomentum,
LinearMomentum, MultiStepMomentum,
PolyMomentum, StepMomentum)
PolyMomentum, ReduceOnPlateauMomentum,
StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
CosineRestartParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, OneCycleParamScheduler,
PolyParamScheduler, StepParamScheduler,
_ParamScheduler)
PolyParamScheduler,
ReduceOnPlateauParamScheduler,
StepParamScheduler, _ParamScheduler)

# yapf: enable

__all__ = [
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',
@@ -26,5 +27,6 @@
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR',
'CosineRestartMomentum'
'CosineRestartMomentum', 'ReduceOnPlateauParamScheduler',
'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum'
]
64 changes: 63 additions & 1 deletion mmengine/optim/scheduler/lr_scheduler.py
@@ -1,11 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import PARAM_SCHEDULERS
# yapf: disable
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
CosineRestartParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, OneCycleParamScheduler,
PolyParamScheduler, StepParamScheduler)
PolyParamScheduler,
ReduceOnPlateauParamScheduler,
StepParamScheduler)

# yapf: enable


class LRSchedulerMixin:
@@ -314,3 +319,60 @@ class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""


@PARAM_SCHEDULERS.register_module()
class ReduceOnPlateauLR(LRSchedulerMixin, ReduceOnPlateauParamScheduler):
"""Reduce the learning rate of each parameter group when a metric has
stopped improving. Models often benefit from reducing the learning rate by
a factor of 2-10 once learning stagnates. This scheduler reads a metrics
quantity and if no improvement is seen for a ``patience`` number of epochs,
the learning rate is reduced.

Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
monitor (str): Key name of the value to monitor in metrics dict.
rule (str): One of `less`, `greater`. In `less` rule, learning rate
will be reduced when the quantity monitored has stopped
decreasing; in `greater` rule it will be reduced when the
quantity monitored has stopped increasing. Defaults to 'less'.
The ``rule`` here corresponds to ``mode`` in PyTorch.
factor (float): Factor by which the learning rate will be
reduced. new_param = param * factor. Defaults to 0.1.
patience (int): Number of epochs with no improvement after
which learning rate will be reduced. For example, if
``patience = 2``, then we will ignore the first 2 epochs
with no improvement, and will only decrease the learning rate after
the 3rd epoch if the monitor value still hasn't improved then.
Defaults to 10.
threshold (float): Threshold for measuring the new optimum,
to only focus on significant changes. Defaults to 1e-4.
threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
dynamic_threshold = best * ( 1 + threshold ) in 'greater'
rule or best * ( 1 - threshold ) in `less` rule.
In `abs` rule, dynamic_threshold = best + threshold in
`greater` rule or best - threshold in `less` rule.
Defaults to 'rel'.
cooldown (int): Number of epochs to wait before resuming
normal operation after learning rate has been reduced.
Defaults to 0.
min_value (float or list[float]): A scalar or a sequence of scalars.
A lower bound on the learning rate of each parameter group
respectively. Defaults to 0.
eps (float): Minimal decay applied to learning rate. If the difference
between new and old learning rate is smaller than eps, the update
is ignored. Defaults to 1e-8.
begin (int): Step at which to start triggering the scheduler
to monitor in val within the interval calculated
according to epoch of training. Defaults to 0.
end (int): Step at which to stop triggering the scheduler
to monitor in val within the interval calculated
according to epoch of training. Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
82 changes: 80 additions & 2 deletions mmengine/optim/scheduler/momentum_scheduler.py
@@ -1,12 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import PARAM_SCHEDULERS
# yapf: disable
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
CosineRestartParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, PolyParamScheduler,
ReduceOnPlateauParamScheduler,
StepParamScheduler)

# yapf: enable


class MomentumSchedulerMixin:
"""A mixin class for momentum schedulers.
@@ -32,8 +36,8 @@ def __init__(self, optimizer, *args, **kwargs):
super().__init__(optimizer, param_name, *args, **kwargs)

def step(self):
"""Adjusts the parameter value of each parameter group based on the
specified schedule."""
"""Adjusts the momentum of each parameter group based on the specified
schedule."""
super().step()
if self.use_betas:
for group in self.optimizer.param_groups:
@@ -281,3 +285,77 @@ class CosineRestartMomentum(MomentumSchedulerMixin,
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""


@PARAM_SCHEDULERS.register_module()
class ReduceOnPlateauMomentum(MomentumSchedulerMixin,
ReduceOnPlateauParamScheduler):
"""Reduce the momentum of each parameter group when a metric has stopped
improving. Models often benefit from reducing the momentum by a factor of
2-10 once learning stagnates. This scheduler reads a metrics quantity and
if no improvement is seen for a ``patience`` number of epochs, the momentum
is reduced.

Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
monitor (str): Key name of the value to monitor in metrics dict.
rule (str): One of `less`, `greater`. In `less` rule, momentum will
be reduced when the quantity monitored has stopped
decreasing; in `greater` rule it will be reduced when the
quantity monitored has stopped increasing. Defaults to 'less'.
The ``rule`` here corresponds to ``mode`` in PyTorch.
factor (float): Factor by which the momentum will be
reduced. new_param = param * factor. Defaults to 0.1.
patience (int): Number of epochs with no improvement after
which momentum will be reduced. For example, if
``patience = 2``, then we will ignore the first 2 epochs
with no improvement, and will only decrease the momentum after
the 3rd epoch if the monitor value still hasn't improved then.
Defaults to 10.
threshold (float): Threshold for measuring the new optimum,
to only focus on significant changes. Defaults to 1e-4.
threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
dynamic_threshold = best * ( 1 + threshold ) in 'greater'
rule or best * ( 1 - threshold ) in `less` rule.
In `abs` rule, dynamic_threshold = best + threshold in
`greater` rule or best - threshold in `less` rule.
Defaults to 'rel'.
cooldown (int): Number of epochs to wait before resuming
normal operation after momentum has been reduced. Defaults to 0.
min_value (float or list[float]): A scalar or a sequence of scalars.
A lower bound on the momentum of each parameter group
respectively. Defaults to 0.
eps (float): Minimal decay applied to momentum. If the difference
between new and old momentum is smaller than eps, the update is
ignored. Defaults to 1e-8.
begin (int): Step at which to start triggering the scheduler
to monitor in val within the interval calculated
according to epoch of training. Defaults to 0.
end (int): Step at which to stop triggering the scheduler
to monitor in val within the interval calculated
according to epoch of training. Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""

def step(self, metrics=None):
"""Adjusts the momentum of each parameter group based on the specified
schedule.

Args:
metrics (Dict[str, float], optional): Evaluation results of all
metrics on validation dataset. The keys are the names of the
metrics, and the values are corresponding results.
Defaults to None.
"""
super(MomentumSchedulerMixin, self).step(metrics)
if self.use_betas:
for group in self.optimizer.param_groups:
_, beta_1 = group['betas']
# update the betas with the calculated value
group['betas'] = (group['momentum'], beta_1)
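
To illustrate the betas handling above, here is a rough usage sketch (not part of the PR, and assuming an Adam-style optimizer whose param groups carry a 'betas' tuple, which MomentumSchedulerMixin maps to a synthetic 'momentum' entry):

import torch
from mmengine.optim import ReduceOnPlateauMomentum

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

scheduler = ReduceOnPlateauMomentum(
    optimizer,
    monitor='loss',
    rule='less',
    factor=0.5,
    patience=2)

# Each call receives the validation metrics; once 'loss' stops improving for
# `patience` epochs, the 'momentum' value is scaled by `factor` and written
# back as beta_1, while beta_2 stays untouched.
metrics = {'loss': 0.42}  # hypothetical validation result
scheduler.step(metrics)
print(optimizer.param_groups[0]['betas'])
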