
[CodeCamp2023-325] Find the proper learning rate #1318

Open

wants to merge 42 commits into base: main

Changes from 31 commits (42 commits total)

Commits
67e9dc2
Init tuner for finding best lr
yhna940 Aug 10, 2023
b714913
Apply lint
yhna940 Aug 11, 2023
a923847
Add ex for tuning
yhna940 Aug 11, 2023
4c9ef09
Refactor to rpc
yhna940 Aug 17, 2023
3580dd8
Apply lint
yhna940 Aug 17, 2023
55364e0
Add logger to tune
yhna940 Aug 18, 2023
882271a
Fix searcher init args
yhna940 Aug 18, 2023
6285928
Apply lint
yhna940 Aug 18, 2023
0431eb0
Fix typo
yhna940 Aug 18, 2023
b5985fb
Fix minor
yhna940 Aug 21, 2023
4b5a249
Fix rpc init
yhna940 Aug 21, 2023
6846aba
Fix env for rpc
yhna940 Aug 22, 2023
a320ee1
fix rpc device map
yhna940 Aug 23, 2023
ccb8f07
Del rpc
yhna940 Aug 23, 2023
bae8605
Fix examples
yhna940 Aug 23, 2023
18fd768
Fix minor
yhna940 Aug 23, 2023
71b4b2a
Fix typo
yhna940 Aug 23, 2023
fecfacb
Split seachers
yhna940 Aug 28, 2023
010a3f1
Comment the tuner
yhna940 Aug 28, 2023
69e62a7
Rename solver of nevergrad
yhna940 Aug 28, 2023
23d1f97
Comment the report hook
yhna940 Aug 28, 2023
482a9e5
Comment the searchers
yhna940 Aug 28, 2023
04b46a3
Add readme for tune
yhna940 Aug 28, 2023
3418ddc
Add error logging
yhna940 Aug 29, 2023
92ad439
Add unittest for tune
yhna940 Aug 30, 2023
308ece3
Apply lint
yhna940 Aug 30, 2023
0767f52
Add random searcher
yhna940 Aug 30, 2023
70d91e4
Fix unittest bug
yhna940 Aug 30, 2023
1e12211
Fix tuner unittest
yhna940 Aug 31, 2023
c4a7e04
Add tuning interface for runner
yhna940 Aug 31, 2023
4d71002
Fix minor
yhna940 Aug 31, 2023
cfc3f6a
Refactor report op
yhna940 Sep 1, 2023
3488ae1
Fix report bug
yhna940 Sep 1, 2023
c0d8e45
Update mmengine/tune/_report_hook.py
yhna940 Sep 9, 2023
62d6777
Merge branch 'open-mmlab:main' into feature/hyper-naive
yhna940 Sep 9, 2023
27bb08b
Fix comment for report hook
yhna940 Sep 9, 2023
afb5af2
Specify phase in monitor
yhna940 Sep 9, 2023
0cdf020
Fix comment on tuner for monitor
yhna940 Sep 9, 2023
48e2abc
Apply reduce operation to score in trial
yhna940 Sep 9, 2023
eb6b387
Fix comment on tuner
yhna940 Sep 9, 2023
16d5186
Enhance safe trial during tune
yhna940 Sep 9, 2023
8f5ee32
Fix unittest bug
yhna940 Sep 9, 2023
23 changes: 23 additions & 0 deletions examples/tune/README.md
@@ -0,0 +1,23 @@
# Find the Optimal Learning Rate

## Install external dependencies

First, install `nevergrad`, which the `NevergradSearcher` used in this example depends on.

```bash
pip install nevergrad
```

## Run the example

Single-device tuning

```bash
python examples/tune/find_lr.py
```

Distributed data parallel tuning

```bash
torchrun --nnodes 1 --nproc_per_node 8 examples/tune/find_lr.py --launcher pytorch
```
147 changes: 147 additions & 0 deletions examples/tune/find_lr.py
@@ -0,0 +1,147 @@
import argparse
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, METRICS, MODELS
from mmengine.runner import Runner


class ToyModel(BaseModel):

def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor)
self.linear1 = nn.Linear(2, 32)
self.linear2 = nn.Linear(32, 64)
self.linear3 = nn.Linear(64, 1)

def forward(self, inputs, data_samples=None, mode='tensor'):
if isinstance(inputs, list):
inputs = torch.stack(inputs)
if isinstance(data_samples, list):
data_samples = torch.stack(data_samples)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)
outputs = self.linear3(outputs)

if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = ((data_samples - outputs)**2).mean()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
return outputs


class ToyDataset(Dataset):
METAINFO = dict() # type: ignore
num_samples = 100
data = torch.rand(num_samples, 2) * 10
label = 3 * data[:, 0] + 4 * data[:, 1] + torch.randn(num_samples) * 0.1

@property
def metainfo(self):
return self.METAINFO

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return dict(inputs=self.data[index], data_samples=self.label[index])


class ToyMetric(BaseMetric):

def __init__(self, collect_device='cpu'):
super().__init__(collect_device=collect_device)
self.results = []

def process(self, data_batch, predictions):
true_values = data_batch['data_samples']
sqe = [(t - p)**2 for t, p in zip(true_values, predictions)]
self.results.extend(sqe)

def compute_metrics(self, results=None):
mse = torch.tensor(self.results).mean().item()
return dict(mse=mse)


def parse_args():
parser = argparse.ArgumentParser(description='Distributed Tuning')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)

args = parser.parse_args()
return args


def main():
args = parse_args()

MODELS.register_module(module=ToyModel, force=True)
METRICS.register_module(module=ToyMetric, force=True)
DATASETS.register_module(module=ToyDataset, force=True)

temp_dir = tempfile.TemporaryDirectory()

runner_cfg = dict(
work_dir=temp_dir.name,
model=dict(type='ToyModel'),
train_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=[dict(type='ToyMetric')],
test_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
test_evaluator=[dict(type='ToyMetric')],
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
val_cfg=dict(),
test_cfg=dict(),
launcher=args.launcher,
default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
custom_hooks=[],
env_cfg=dict(dist_cfg=dict(backend='nccl')),
experiment_name='test1')

runner = Runner.from_tuning(
runner_cfg=runner_cfg,
hparam_spec={
'optim_wrapper.optimizer.lr': {
'type': 'continuous',
'lower': 1e-5,
'upper': 1e-3
}
},
monitor='loss',
rule='less',
num_trials=16,
tuning_epoch=2,
searcher_cfg=dict(type='NevergradSearcher'),
)
runner.train()

temp_dir.cleanup()


if __name__ == '__main__':
main()
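
The `hparam_spec` above declares a single continuous search dimension for the learning rate. As a rough, hypothetical sketch of how such a spec could be sampled by a random-search style searcher (the helper `sample_continuous` below is invented for illustration and is not part of this PR):

```python
import random

# Hypothetical helper: draw one candidate from a 'continuous' hparam_spec
# entry. The name and the uniform sampling choice are assumptions for
# illustration only, not the PR's actual searcher implementation.
def sample_continuous(spec: dict) -> float:
    assert spec['type'] == 'continuous'
    return random.uniform(spec['lower'], spec['upper'])

hparam_spec = {
    'optim_wrapper.optimizer.lr': {
        'type': 'continuous',
        'lower': 1e-5,
        'upper': 1e-3,
    }
}

# Draw a few candidate learning rates from the declared range.
for _ in range(3):
    lr = sample_continuous(hparam_spec['optim_wrapper.optimizer.lr'])
    print(f'candidate lr: {lr:.2e}')
```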
55 changes: 55 additions & 0 deletions mmengine/runner/runner.py
@@ -37,6 +37,7 @@
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.tune import Tuner
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
@@ -474,6 +475,60 @@ def from_cfg(cls, cfg: ConfigType) -> 'Runner':

return runner

@classmethod
def from_tuning(
cls,
runner_cfg: ConfigType,
hparam_spec: Dict,
monitor: str,
rule: str,
num_trials: int,
tuning_iter: Optional[int] = None,
tuning_epoch: Optional[int] = None,
report_op: str = 'latest',
searcher_cfg: Dict = dict(type='RandomSearcher')
) -> 'Runner':
"""Build a runner from tuning.

Args:
runner_cfg (ConfigType): A config used for building the runner. The
accepted keys are described in :meth:`__init__`.
hparam_spec (Dict): A dict of hyper parameters to be tuned.
monitor (str): The metric name to be monitored.
rule (str): The rule used to compare monitored scores, e.g. 'less' when a lower value is better.
num_trials (int): The maximum number of trials for tuning.
tuning_iter (Optional[int]): The maximum number of iterations for each
trial. If specified, tuning stops after reaching this limit.
Default is None, indicating no iteration limit.
tuning_epoch (Optional[int]): The maximum number of epochs for each
trial. If specified, tuning stops after reaching this number
of epochs. Default is None, indicating no epoch limit.
report_op (str): Operation mode for metric reporting. Options are
'latest' and 'mean'. Default is 'latest'.
searcher_cfg (Dict): Configuration for the searcher.
Default is `dict(type='RandomSearcher')`.

Returns:
Runner: A runner built from ``runner_cfg`` with the tuned hyperparameters applied.
"""

runner_cfg = copy.deepcopy(runner_cfg)
tuner = Tuner(
runner_cfg=runner_cfg,
hparam_spec=hparam_spec,
monitor=monitor,
rule=rule,
num_trials=num_trials,
tuning_iter=tuning_iter,
tuning_epoch=tuning_epoch,
report_op=report_op,
searcher_cfg=searcher_cfg)
hparam = tuner.tune()['hparam']
assert isinstance(hparam, dict), 'hparam should be a dict'
for k, v in hparam.items():
Tuner.inject_config(runner_cfg, k, v)
return cls.from_cfg(runner_cfg)

@property
def experiment_name(self):
"""str: Name of experiment."""
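A short usage sketch of the new classmethod, complementing `examples/tune/find_lr.py` by bounding each trial with an iteration limit and relying on the default random searcher; `runner_cfg` is assumed to be a complete Runner config such as the one built in that script:

```python
from mmengine.runner import Runner

# Sketch only: `runner_cfg` is assumed to be a complete Runner config,
# e.g. the dict built in examples/tune/find_lr.py.
runner = Runner.from_tuning(
    runner_cfg=runner_cfg,
    hparam_spec={
        'optim_wrapper.optimizer.lr': {
            'type': 'continuous',
            'lower': 1e-5,
            'upper': 1e-3,
        }
    },
    monitor='loss',
    rule='less',        # a lower monitored loss is better
    num_trials=8,
    tuning_iter=200,    # stop each trial after 200 training iterations
    report_op='mean',   # aggregate recorded scores by their mean
)
runner.train()          # train with the tuned learning rate injected
```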
5 changes: 5 additions & 0 deletions mmengine/tune/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .searchers import * # noqa F403
from .tuner import Tuner

__all__ = ['Tuner']
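
Since `Tuner` is exported at the package level, it can also be driven directly, mirroring what `Runner.from_tuning` does internally (a sketch based on the runner.py change above; `runner_cfg` and `hparam_spec` are assumed to be defined as in the example script):

```python
from mmengine.tune import Tuner

# Direct use of the Tuner, mirroring Runner.from_tuning above.
# `runner_cfg` and `hparam_spec` are assumed to be defined as in
# examples/tune/find_lr.py.
tuner = Tuner(
    runner_cfg=runner_cfg,
    hparam_spec=hparam_spec,
    monitor='loss',
    rule='less',
    num_trials=16,
    tuning_epoch=2,
    searcher_cfg=dict(type='NevergradSearcher'),
)
result = tuner.tune()
best_hparam = result['hparam']  # e.g. {'optim_wrapper.optimizer.lr': 3.2e-4}

# Inject the tuned values back into the config before building a Runner.
for key, value in best_hparam.items():
    Tuner.inject_config(runner_cfg, key, value)
```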
135 changes: 135 additions & 0 deletions mmengine/tune/_report_hook.py
@@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, List, Optional, Sequence, Union

from mmengine.hooks import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


class ReportingHook(Hook):
"""Auxiliary hook to report the score to tuner.

If tuning limit is specified, this hook will mark the loop to stop.

Args:
monitor (str): The monitored metric key to report.
tuning_iter (int, optional): The iteration limit to stop tuning.
Defaults to None.
tuning_epoch (int, optional): The epoch limit to stop tuning.
Defaults to None.
report_op (str, optional): The operation used to aggregate the reported
score. 'latest' reports the most recent score in the scoreboard,
while 'mean' reports the mean of all recorded scores.
Defaults to 'latest'.
[Review comment (Collaborator)] We need to describe the meaning of latest, mean here.

[Reply (Contributor Author)] In line with your suggestion, I have added comments on the meaning and role of report_op, explaining the options like latest and mean.

max_scoreboard_len (int, optional): The maximum length of the scoreboard,
a bounded list that holds the most recently reported scores; the oldest
entries are dropped once this limit is exceeded. Defaults to 1024.
[Review comment (Collaborator)] scoreboard is a new concept for users. We need to introduce it here.

[Reply (Contributor Author)] To clarify the newly introduced concept of the scoreboard, I have added comments in the relevant section explaining its purpose and usage.

"""

report_op_supported = ['latest', 'mean']

def __init__(self,
monitor: str,
tuning_iter: Optional[int] = None,
tuning_epoch: Optional[int] = None,
report_op: str = 'latest',
max_scoreboard_len: int = 1024):
assert report_op in self.report_op_supported, \
f'report_op {report_op} is not supported'
if tuning_iter is not None and tuning_epoch is not None:
raise ValueError(
'tuning_iter and tuning_epoch cannot be set at the same time')
self.report_op = report_op
self.tuning_iter = tuning_iter
self.tuning_epoch = tuning_epoch

self.monitor = monitor
self.max_scoreboard_len = max_scoreboard_len
self.scoreboard: List[float] = []

def _append_score(self, score: float):
"""Append the score to the scoreboard."""
self.scoreboard.append(score)
if len(self.scoreboard) > self.max_scoreboard_len:
self.scoreboard.pop(0)

def _should_stop(self, runner):
"""Check if the training should be stopped.

Args:
runner (Runner): The runner of the training process.
"""
if self.tuning_iter is not None:
if runner.iter + 1 >= self.tuning_iter:
return True
elif self.tuning_epoch is not None:
if runner.epoch + 1 >= self.tuning_epoch:
return True
else:
return False

def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[dict, Sequence]] = None,
mode: str = 'train') -> None:
"""Record the score after each iteration.

Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
"""

tag, _ = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'train')
score = tag.get(self.monitor, None)
[Review comment (Collaborator)] Suggested change: `score = tag.get(self.monitor, None)` -> `score = tag.get(self.monitor)`.

[Review comment (Collaborator)] Suggest adding a prefix to monitor, like train/loss and val/accuracy. We would only check the monitor at the specific phase (train or validation) matching its prefix, and raise an error immediately if it is not defined in tag. Reporting an error to users immediately is better than raising it after the whole tuning round.

[Review comment (Collaborator)] We also need to check that the monitored value is a number.

[Reply (Contributor Author)] Following your advice, I have enhanced the monitoring process by adding prefixes to the monitor key. Moreover, I've added logic to verify that the monitored values are numerical to prevent potential errors.

if score is not None:
self._append_score(score)

if self._should_stop(runner):
runner.train_loop.stop_training = True

def after_train_epoch(self, runner) -> None:
"""Record the score after each epoch.

Args:
runner (Runner): The runner of the training process.
"""
if self._should_stop(runner):
runner.train_loop.stop_training = True

def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""Record the score after each validation epoch.

Args:
runner (Runner): The runner of the validation process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on validation dataset. The keys are the names of the
metrics, and the values are corresponding results.
"""
if metrics is None:
return
score = metrics.get(self.monitor, None)
if score is not None:
self._append_score(score)

def report_score(self) -> Optional[float]:
"""Aggregate the scores in the scoreboard.

Returns:
Optional[float]: The aggregated score.
"""
if not self.scoreboard:
score = None
elif self.report_op == 'latest':
score = self.scoreboard[-1]
else:
score = sum(self.scoreboard) / len(self.scoreboard)
return score

def clear(self):
"""Clear the scoreboard."""
self.scoreboard.clear()
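
As a quick, standalone illustration of the two `report_op` modes, feeding scores into the scoreboard directly rather than through a real training loop:

```python
from mmengine.tune._report_hook import ReportingHook

# Feed a few fake monitored scores straight into the scoreboard to show
# how the two report_op modes aggregate them (no Runner involved here).
hook = ReportingHook(monitor='loss', report_op='mean')
for score in [0.9, 0.7, 0.5]:
    hook._append_score(score)

print(hook.report_score())  # mean of the scoreboard, roughly 0.7

hook.report_op = 'latest'
print(hook.report_score())  # most recent score, 0.5

hook.clear()
print(hook.report_score())  # empty scoreboard, None
```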
8 changes: 8 additions & 0 deletions mmengine/tune/searchers/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .nevergrad import NevergradSearcher
from .random import RandomSearcher
from .searcher import HYPER_SEARCHERS, Searcher

__all__ = [
'Searcher', 'HYPER_SEARCHERS', 'NevergradSearcher', 'RandomSearcher'
]
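
Both exported searchers are selected by name through `searcher_cfg`; a minimal sketch (constructor options beyond `type` are not shown in this diff and are omitted here):

```python
# Either built-in searcher can be selected by name via the registry config.
# Constructor options other than 'type' are not shown in this diff, so only
# the 'type' key is used here.
searcher_cfg = dict(type='RandomSearcher')       # the default in Runner.from_tuning
# searcher_cfg = dict(type='NevergradSearcher')  # requires `pip install nevergrad`
```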