diff --git a/examples/tune/README.md b/examples/tune/README.md new file mode 100644 index 0000000000..d94aa479b3 --- /dev/null +++ b/examples/tune/README.md @@ -0,0 +1,23 @@ +# Find the Optimal Learning Rate + +## Install external dependencies + +First, you should install `nevergrad` for tuning. + +```bash +pip install nevergrad +``` + +## Run the example + +Single device training + +```bash +python examples/tune/find_lr.py +``` + +Distributed data parallel tuning + +```bash +torchrun --nnodes 1 --nproc_per_node 8 examples/tune/find_lr.py --launcher pytorch +``` diff --git a/examples/tune/find_lr.py b/examples/tune/find_lr.py new file mode 100644 index 0000000000..3c4f3fa92d --- /dev/null +++ b/examples/tune/find_lr.py @@ -0,0 +1,147 @@ +import argparse +import tempfile + +import torch +import torch.nn as nn +from torch.utils.data import Dataset + +from mmengine.evaluator import BaseMetric +from mmengine.model import BaseModel +from mmengine.registry import DATASETS, METRICS, MODELS +from mmengine.runner import Runner + + +class ToyModel(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + self.linear1 = nn.Linear(2, 32) + self.linear2 = nn.Linear(32, 64) + self.linear3 = nn.Linear(64, 1) + + def forward(self, inputs, data_samples=None, mode='tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + if isinstance(data_samples, list): + data_samples = torch.stack(data_samples) + outputs = self.linear1(inputs) + outputs = self.linear2(outputs) + outputs = self.linear3(outputs) + + if mode == 'tensor': + return outputs + elif mode == 'loss': + loss = ((data_samples - outputs)**2).mean() + outputs = dict(loss=loss) + return outputs + elif mode == 'predict': + return outputs + + +class ToyDataset(Dataset): + METAINFO = dict() # type: ignore + num_samples = 100 + data = torch.rand(num_samples, 2) * 10 + label = 3 * data[:, 0] + 4 * data[:, 1] + torch.randn(num_samples) * 0.1 + + @property + def metainfo(self): + return self.METAINFO + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return dict(inputs=self.data[index], data_samples=self.label[index]) + + +class ToyMetric(BaseMetric): + + def __init__(self, collect_device='cpu'): + super().__init__(collect_device=collect_device) + self.results = [] + + def process(self, data_batch, predictions): + true_values = data_batch['data_samples'] + sqe = [(t - p)**2 for t, p in zip(true_values, predictions)] + self.results.extend(sqe) + + def compute_metrics(self, results=None): + mse = torch.tensor(self.results).mean().item() + return dict(mse=mse) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Distributed Tuning') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + MODELS.register_module(module=ToyModel, force=True) + METRICS.register_module(module=ToyMetric, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + + temp_dir = tempfile.TemporaryDirectory() + + runner_cfg = dict( + work_dir=temp_dir.name, + model=dict(type='ToyModel'), + train_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', 
shuffle=False), + batch_size=3, + num_workers=0), + val_evaluator=[dict(type='ToyMetric')], + test_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + test_evaluator=[dict(type='ToyMetric')], + optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)), + train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1), + val_cfg=dict(), + test_cfg=dict(), + launcher=args.launcher, + default_hooks=dict(logger=dict(type='LoggerHook', interval=1)), + custom_hooks=[], + env_cfg=dict(dist_cfg=dict(backend='nccl')), + experiment_name='test1') + + runner = Runner.from_tuning( + runner_cfg=runner_cfg, + hparam_spec={ + 'optim_wrapper.optimizer.lr': { + 'type': 'continuous', + 'lower': 1e-5, + 'upper': 1e-3 + } + }, + monitor='train/loss', + rule='less', + num_trials=16, + tuning_epoch=2, + searcher_cfg=dict(type='NevergradSearcher'), + ) + runner.train() + + temp_dir.cleanup() + + +if __name__ == '__main__': + main() diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index bd6757a844..f2c29e0987 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -37,6 +37,7 @@ HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS, VISUALIZERS, DefaultScope) +from mmengine.tune import Tuner from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, set_multi_processing) @@ -475,6 +476,60 @@ def from_cfg(cls, cfg: ConfigType) -> 'Runner': return runner + @classmethod + def from_tuning( + cls, + runner_cfg: ConfigType, + hparam_spec: Dict, + monitor: str, + rule: str, + num_trials: int, + tuning_iter: Optional[int] = None, + tuning_epoch: Optional[int] = None, + report_op: str = 'latest', + searcher_cfg: Dict = dict(type='RandomSearcher') + ) -> 'Runner': + """Build a runner from tuning. + + Args: + runner_cfg (ConfigType): A config used for building runner. See + :meth:`__init__` for the keys of ``runner_cfg``. + hparam_spec (Dict): A dict of hyper parameters to be tuned. + monitor (str): The metric name to be monitored. + rule (str): The rule to measure the best metric. + num_trials (int): The maximum number of trials for tuning. + tuning_iter (Optional[int]): The maximum iterations for each trial. + If specified, tuning stops after reaching this limit. + Default is None, indicating no specific iteration limit. + tuning_epoch (Optional[int]): The maximum epochs for each trial. + If specified, tuning stops after reaching this number + of epochs. Default is None, indicating no epoch limit. + report_op (str): + Operation mode for metric reporting. Default is 'latest'. + searcher_cfg (Dict): Configuration for the searcher. + Default is `dict(type='RandomSearcher')`. + + Returns: + Runner: A runner built from ``runner_cfg`` tuned by trials.
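+ + The example below mirrors the call in ``examples/tune/find_lr.py`` above; + ``runner_cfg`` is assumed to be a complete runner config built as in + :meth:`__init__`. + + Example: + >>> runner = Runner.from_tuning( + >>> runner_cfg=runner_cfg, + >>> hparam_spec={ + >>> 'optim_wrapper.optimizer.lr': { + >>> 'type': 'continuous', + >>> 'lower': 1e-5, + >>> 'upper': 1e-3 + >>> } + >>> }, + >>> monitor='train/loss', + >>> rule='less', + >>> num_trials=16, + >>> tuning_epoch=2, + >>> searcher_cfg=dict(type='NevergradSearcher')) + >>> runner.train()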
+ """ + + runner_cfg = copy.deepcopy(runner_cfg) + tuner = Tuner( + runner_cfg=runner_cfg, + hparam_spec=hparam_spec, + monitor=monitor, + rule=rule, + num_trials=num_trials, + tuning_iter=tuning_iter, + tuning_epoch=tuning_epoch, + report_op=report_op, + searcher_cfg=searcher_cfg) + hparam = tuner.tune()['hparam'] + assert isinstance(hparam, dict), 'hparam should be a dict' + for k, v in hparam.items(): + Tuner.inject_config(runner_cfg, k, v) + return cls.from_cfg(runner_cfg) + @property def experiment_name(self): """str: Name of experiment.""" diff --git a/mmengine/tune/__init__.py b/mmengine/tune/__init__.py new file mode 100644 index 0000000000..3d921e9ebc --- /dev/null +++ b/mmengine/tune/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .searchers import * # noqa F403 +from .tuner import Tuner + +__all__ = ['Tuner'] diff --git a/mmengine/tune/_report_hook.py b/mmengine/tune/_report_hook.py new file mode 100644 index 0000000000..e0fdde168c --- /dev/null +++ b/mmengine/tune/_report_hook.py @@ -0,0 +1,175 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Callable, Dict, List, Optional, Sequence, Union + +from mmengine.hooks import Hook + +DATA_BATCH = Optional[Union[dict, tuple, list]] + + +class ReportingHook(Hook): + """Auxiliary hook to report the score to tuner. The ReportingHook maintains + a "scoreboard" which keeps track of the monitored metrics' scores during + the training process. The scores are aggregated based on the method + specified by the 'report_op' parameter. If tuning limit is specified, this + hook will mark the loop to stop. + + Args: + monitor (str): The monitored metric key prefixed with either 'train/' + or 'val/' to indicate the specific phase where the metric should + be monitored. For instance, 'train/loss' will monitor the 'loss' + metric during the training phase, and 'val/accuracy' will monitor + the 'accuracy' metric during the validation phase. + The actual metric key (i.e., the part following the prefix) + should correspond to a key in the logs produced during + training or validation. + tuning_iter (int, optional): The iteration limit to stop tuning. + Defaults to None. + tuning_epoch (int, optional): The epoch limit to stop tuning. + Defaults to None. + report_op (str, optional): The method for aggregating scores + in the scoreboard. Accepts the following options: + - 'latest': Returns the most recent score in the scoreboard. + - 'mean': Returns the mean of all scores in the scoreboard. + - 'max': Returns the highest score in the scoreboard. + - 'min': Returns the lowest score in the scoreboard. + Defaults to 'latest'. + max_scoreboard_len (int, optional): + Specifies the maximum number of scores that can be retained + on the scoreboard, helping to manage memory and computational + overhead. Defaults to 1024. 
+ """ + + report_op_supported: Dict[str, Callable[[List[float]], float]] = { + 'latest': lambda x: x[-1], + 'mean': lambda x: sum(x) / len(x), + 'max': max, + 'min': min + } + + def __init__(self, + monitor: str, + tuning_iter: Optional[int] = None, + tuning_epoch: Optional[int] = None, + report_op: str = 'latest', + max_scoreboard_len: int = 1024): + if not monitor.startswith('train/') and not monitor.startswith('val/'): + raise ValueError("The 'monitor' parameter should start " + "with 'train/' or 'val/' to specify the phase.") + if report_op not in self.report_op_supported: + raise ValueError(f'report_op {report_op} is not supported') + if tuning_iter is not None and tuning_epoch is not None: + raise ValueError( + 'tuning_iter and tuning_epoch cannot be set at the same time') + self.monitor_prefix, self.monitor_metric = monitor.split('/', 1) + self.report_op = report_op + self.tuning_iter = tuning_iter + self.tuning_epoch = tuning_epoch + + self.max_scoreboard_len = max_scoreboard_len + self.scoreboard: List[float] = [] + + def _append_score(self, score: float): + """Append the score to the scoreboard.""" + self.scoreboard.append(score) + if len(self.scoreboard) > self.max_scoreboard_len: + self.scoreboard.pop(0) + + def _should_stop(self, runner): + """Check if the training should be stopped. + + Args: + runner (Runner): The runner of the training process. + """ + if self.tuning_iter is not None: + if runner.iter + 1 >= self.tuning_iter: + return True + elif self.tuning_epoch is not None: + if runner.epoch + 1 >= self.tuning_epoch: + return True + else: + return False + + def after_train_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[Union[dict, Sequence]] = None, + mode: str = 'train') -> None: + """Record the score after each iteration. + + Args: + runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (dict, optional): Outputs from model. + """ + if self.monitor_prefix != 'train': + return + tag, _ = runner.log_processor.get_log_after_iter( + runner, batch_idx, 'train') + score = tag.get(self.monitor_metric) + if not isinstance(score, (int, float)): + raise ValueError(f"The monitored value '{self.monitor_metric}' " + 'should be a number.') + self._append_score(score) + + if self._should_stop(runner): + runner.train_loop.stop_training = True + + def after_train_epoch(self, runner) -> None: + """Record the score after each epoch. + + Args: + runner (Runner): The runner of the training process. + """ + if self._should_stop(runner): + runner.train_loop.stop_training = True + + def after_val_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """Record the score after each validation epoch. + + Args: + runner (Runner): The runner of the validation process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + if self.monitor_prefix != 'val' or metrics is None: + return + score = metrics.get(self.monitor_metric) + if not isinstance(score, (int, float)): + raise ValueError(f"The monitored value '{self.monitor_metric}' " + 'should be a number.') + self._append_score(score) + + def report_score(self) -> Optional[float]: + """Aggregate the scores in the scoreboard. + + Returns: + Optional[float]: The aggregated score. 
+ """ + if not self.scoreboard: + score = None + else: + operation = self.report_op_supported[self.report_op] + score = operation(self.scoreboard) + return score + + @classmethod + def register_report_op(cls, name: str, func: Callable[[List[float]], + float]): + """Register a new report operation. + + Args: + name (str): The name of the report operation. + func (Callable[[List[float]], float]): The function to aggregate + the scores. + """ + cls.report_op_supported[name] = func + + def clear(self): + """Clear the scoreboard.""" + self.scoreboard.clear() diff --git a/mmengine/tune/searchers/__init__.py b/mmengine/tune/searchers/__init__.py new file mode 100644 index 0000000000..a7077efb4b --- /dev/null +++ b/mmengine/tune/searchers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .nevergrad import NevergradSearcher +from .random import RandomSearcher +from .searcher import HYPER_SEARCHERS, Searcher + +__all__ = [ + 'Searcher', 'HYPER_SEARCHERS', 'NevergradSearcher', 'RandomSearcher' +] diff --git a/mmengine/tune/searchers/nevergrad.py b/mmengine/tune/searchers/nevergrad.py new file mode 100644 index 0000000000..bf13140599 --- /dev/null +++ b/mmengine/tune/searchers/nevergrad.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import hashlib +import json +from typing import Dict + +from .searcher import HYPER_SEARCHERS, Searcher + +try: + import nevergrad as ng +except ImportError: + ng = None # type: ignore + + +@HYPER_SEARCHERS.register_module() +class NevergradSearcher(Searcher): + """Support hyper parameter searchering with nevergrad. + + Note: + The detailed usage of nevergrad can be found at + https://facebookresearch.github.io/nevergrad/. + + Args: + rule (str): The rule to compare the score. + Options are 'greater', 'less'. + hparam_spec (Dict[str, Dict]): The hyper parameter specification. + num_trials (int): The number of trials. + solver_type (str): The type of solver. + """ + + def __init__(self, + rule: str, + hparam_spec: Dict[str, Dict], + num_trials: int, + solver_type: str = 'NGOpt', + *args, + **kwargs): + super().__init__(rule, hparam_spec) + assert ng is not None, 'nevergrad is not installed' + self._solver = self._build_solver(solver_type, num_trials) + self._records = dict() # type: ignore + + if self.rule == 'less': + self._rule_op = 1.0 + else: + self._rule_op = -1.0 + + def _build_solver(self, solver_type: str, num_trials: int): + """Build the solver of nevergrad. + + Args: + solver_type (str): The type of solver. + num_trials (int): The number of trials. + """ + converted_hparam_spec = ng.p.Dict( + **{ + k: ng.p.Scalar(lower=v['lower'], upper=v['upper']) + if v['type'] == 'continuous' else ng.p.Choice(v['values']) + for k, v in self.hparam_spec.items() + }) + solver = ng.optimization.optimizerlib.registry[solver_type]( + parametrization=converted_hparam_spec, budget=num_trials) + return solver + + def _hash_dict(self, d: dict) -> str: + """Hash the dict. + + Args: + d (dict): The dict to be hashed. + + Returns: + str: The hashed string. + """ + serialized_data = json.dumps(d, sort_keys=True).encode() + hashed = hashlib.md5(serialized_data).hexdigest() + return hashed + + def record(self, hparam: Dict, score: float): + """Record hparam and score to solver. 
+ + Args: + hparam (Dict): The hparam to be updated + score (float): The score to be updated + """ + hash_key = self._hash_dict(hparam) + assert hash_key in self._records, \ + f'hparam {hparam} is not in the record' + self._solver.tell(self._records[hash_key], score * self._rule_op) + + def suggest(self) -> Dict: + """Suggest a new hparam based on solver's strategy. + + Returns: + Dict: suggested hparam + """ + hparam = self._solver.ask() + hash_key = self._hash_dict(hparam.value) + self._records[hash_key] = hparam + return hparam.value diff --git a/mmengine/tune/searchers/random.py b/mmengine/tune/searchers/random.py new file mode 100644 index 0000000000..421940c756 --- /dev/null +++ b/mmengine/tune/searchers/random.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import Dict + +from .searcher import HYPER_SEARCHERS, Searcher + + +@HYPER_SEARCHERS.register_module() +class RandomSearcher(Searcher): + + def __init__(self, rule: str, hparam_spec: Dict[str, Dict], *args, + **kwargs): + super().__init__(rule, hparam_spec) + + def suggest(self) -> Dict: + """Suggest a new hparam based on random selection. + + Returns: + Dict: suggested hparam + """ + suggestion = {} + for key, spec in self._hparam_spec.items(): + if spec['type'] == 'discrete': + suggestion[key] = random.choice(spec['values']) + elif spec['type'] == 'continuous': + suggestion[key] = random.uniform(spec['lower'], spec['upper']) + + return suggestion diff --git a/mmengine/tune/searchers/searcher.py b/mmengine/tune/searchers/searcher.py new file mode 100644 index 0000000000..ea81bed267 --- /dev/null +++ b/mmengine/tune/searchers/searcher.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Dict + +from mmengine.registry import Registry + +HYPER_SEARCHERS = Registry('hyper parameter searcher') + + +class Searcher: + """Base hyper parameter searcher class. + + All hyper parameter searchers inherit from this class. + """ + + rules_supported = ['greater', 'less'] + + def __init__(self, rule: str, hparam_spec: Dict[str, Dict], *args, + **kwargs): + assert rule in self.rules_supported, \ + f"rule must be 'less' or 'greater', but got {rule}" + self._rule = rule + self._validate_hparam_spec(hparam_spec) + self._hparam_spec = hparam_spec + + def _validate_hparam_spec(self, hparam_spec: Dict[str, Dict]): + """Validate hparam_spec. + + An example of hparam_spec: + + 1. discrete: + hparam_spec = { + 'lr': { + 'type': 'discrete', + 'values': [0.01, 0.02, 0.03] + } + } + + 2. continuous: + hparam_spec = { + 'lr': { + 'type': 'continuous', + 'lower': 0.01, + 'upper': 0.1 + } + } + + Args: + hparam_spec (Dict[str, Dict]): The hyper parameter specification. 
+ """ + for _, v in hparam_spec.items(): + assert v.get('type', None) in [ + 'discrete', 'continuous' + ], \ + 'hparam_spec must have a key "type" and ' \ + f'its value must be "discrete" or "continuous", but got {v}' + if v['type'] == 'discrete': + assert 'values' in v and isinstance(v['values'], list) and \ + v['values'], 'Expected a non-empty "values" list for ' + \ + 'discrete type, but got {v}' + else: + assert 'lower' in v and 'upper' in v, \ + 'Expected keys "lower" and "upper" for continuous ' + \ + f'type, but got {v}' + assert isinstance(v['lower'], (int, float)) and \ + isinstance(v['upper'], (int, float)), \ + f'Expected "lower" and "upper" to be numbers, but got {v}' + assert v['lower'] < v['upper'], \ + f'Expected "lower" to be less than "upper", but got {v}' + + @property + def hparam_spec(self) -> Dict[str, Dict]: + """Dict: The hyper parameter specification.""" + return self._hparam_spec + + @property + def rule(self) -> str: + """str: The rule of the searcher, 'less' or 'greater'.""" + return self._rule + + def record(self, hparam: Dict, score: float): + """Record hparam and score to solver. + + Args: + hparam (Dict): The hparam to be updated + score (float): The score to be updated + """ + + def suggest(self) -> Dict: + """Suggest a new hparam based on solver's strategy. + + Returns: + Dict: suggested hparam + """ diff --git a/mmengine/tune/tuner.py b/mmengine/tune/tuner.py new file mode 100644 index 0000000000..177aab265c --- /dev/null +++ b/mmengine/tune/tuner.py @@ -0,0 +1,344 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import tempfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch + +from mmengine.config import Config, ConfigDict +from mmengine.dist import (all_reduce, broadcast_object_list, init_dist, + is_distributed, is_main_process) +from mmengine.logging import MessageHub, MMLogger +from mmengine.registry import DefaultScope +from mmengine.visualization import Visualizer +from ._report_hook import ReportingHook +from .searchers import HYPER_SEARCHERS, Searcher + +ConfigType = Union[Dict, Config, ConfigDict] + + +class Tuner: + """A helper for hyperparameter tuning. + + By specifying a hyperparameter search space and a monitor metric, + this `Tuner` will automatically identify the optimal hyperparameters + for the runner. + + The `Tuner` employs black-box optimization techniques, ensuring + a systematic search for the best hyperparameters within a defined space. + The tuning process iteratively: + + 1. Searches for hyperparameters considering + the outcomes of previous trials. + 2. Constructs and trains the runner using the given hyperparameters. + 3. Assesses the performance of the trained runner's model + and logs it in the searcher. + + Upon the conclusion of all trials, the hyperparameters that yielded + the peak performance are returned. + + Args: + runner_cfg (ConfigType): + Configuration for the runner. + hparam_spec (Dict[str, Dict]): + The hyperparameter search space definition. + monitor (str): The metric to be monitored during the tuning process. + rule (str): The criterion used to determine the best hyperparameters. + Only 'greater' and 'less' are currently supported. + num_trials (int): Total number of trials to execute. + tuning_iter (Optional[int]): The maximum iterations for each trial. + If specified, tuning stops after reaching this limit. + Default is None, indicating no specific iteration limit. + tuning_epoch (Optional[int]): The maximum epochs for each trial. 
+ If specified, tuning stops after reaching this number of epochs. + Default is None, indicating no epoch limit. + report_op (str): + Operation mode for metric reporting. Default is 'latest'. + searcher_cfg (ConfigType): Configuration for the searcher. + Default is `dict(type='RandomSearcher')`. + + Note: + The black-box optimization depends on external packages, + such as `nevergrad`. Ensure the necessary packages are installed + before use. + + Example: + >>> from mmengine.tune import Tuner + >>> runner_cfg = {"...": "..."} + >>> hparam_spec = { + >>> 'optim_wrapper.optimizer.lr': { + >>> 'type': 'continuous', + >>> 'lower': 1e-5, + >>> 'upper': 1e-3 + >>> } + >>> } + >>> tuner = Tuner( + >>> runner_cfg, + >>> hparam_spec=hparam_spec, + >>> monitor='train/loss', + >>> rule='less', + >>> num_trials=32, + >>> ) + >>> result = tuner.tune() + >>> print(result['hparam']) + >>> print(result['score']) + """ + rules_supported = ['greater', 'less'] + + def __init__(self, + runner_cfg: ConfigType, + hparam_spec: Dict[str, Dict], + monitor: str, + rule: str, + num_trials: int, + tuning_iter: Optional[int] = None, + tuning_epoch: Optional[int] = None, + report_op: str = 'latest', + searcher_cfg: ConfigType = dict(type='RandomSearcher')): + + self._runner_cfg = runner_cfg.copy() + self._hparam_spec = hparam_spec + self._monitor = monitor + + if rule not in self.rules_supported: + raise ValueError(f'Rule {rule} is not supported') + self._rule = rule + + self._num_trials = num_trials + self._tuning_iter = tuning_iter + self._tuning_epoch = tuning_epoch + self._reporting_op = report_op + self._history: List[Tuple[Dict, float]] = [] + + # Initialize distributed environment if necessary + # This adjustment ensures consistent hyperparameter searching and + # performance recording across all processes. + launcher = self._runner_cfg.get('launcher', 'none') + self._distributed = launcher != 'none' + if self._distributed and not is_distributed(): + env_cfg = runner_cfg.get('env_cfg', {}) + dist_cfg = env_cfg.get('dist_cfg', {}) + init_dist(launcher, **dist_cfg) + + # Build logger to record tuning process + self._logger = MMLogger.get_instance( + 'Tuner', log_level='INFO', distributed=self._distributed) + self._logger.info( + f'Tuner initialized with rule: {rule} and monitor: {monitor}') + + # Build searcher to search for optimal hyperparameters + self._searcher = self._build_searcher(searcher_cfg) + + @property + def hparam_spec(self) -> Dict[str, Dict]: + """Dict[str, Dict]: The hyperparameter search space definition.""" + return self._hparam_spec + + @property + def monitor(self) -> str: + """str: The metric to be monitored during the tuning process.""" + return self._monitor + + @property + def rule(self) -> str: + """str: The criterion used to determine the best hyperparameters.""" + return self._rule + + @property + def num_trials(self) -> int: + """int: Total number of trials to execute.""" + return self._num_trials + + @property + def tuning_iter(self) -> Optional[int]: + """Optional[int]: The maximum iterations for each trial. + If specified, tuning stops after reaching this limit. + """ + return self._tuning_iter + + @property + def tuning_epoch(self) -> Optional[int]: + """Optional[int]: The maximum epochs for each trial. + If specified, tuning stops after reaching this number of epochs. + """ + return self._tuning_epoch + + @property + def reporting_op(self) -> str: + """str: Operation mode for metric reporting.
Default is 'latest'.""" + return self._reporting_op + + @property + def history(self) -> List[Tuple[Dict, float]]: + """List[Tuple[Dict, float]]: The history of hyperparameters and + scores.""" + return self._history + + @property + def searcher(self) -> Searcher: + """Searcher: The searcher used for hyperparameter tuning.""" + return self._searcher + + @staticmethod + def inject_config(cfg: ConfigType, key: str, value: Any): + """Inject a value into a config. + + The name can be multi-level, like 'optimizer.lr'. + + Args: + cfg (ConfigType): The config to be injected. + key (str): The key of the value to be injected. + value (Any): The value to be injected. + """ + keys = key.split('.') + for k in keys[:-1]: + if isinstance(cfg, list): + idx = int(k) + if idx >= len(cfg): + raise KeyError(f'Index {idx} is out of range in {cfg}') + cfg = cfg[idx] + else: + if k not in cfg: + raise KeyError(f"Key '{k}' not found in {cfg}") + cfg = cfg[k] + + if isinstance(cfg, list): + idx = int(keys[-1]) + if idx >= len(cfg): + raise KeyError(f'Index {idx} is out of range in {cfg}') + cfg[idx] = value + else: + if keys[-1] not in cfg: + raise KeyError(f"Key '{keys[-1]}' not found in {cfg}") + else: + cfg[keys[-1]] = value + return + + def _build_searcher(self, searcher_cfg: ConfigType) -> Searcher: + """Build searcher from searcher_cfg. + + An Example of ``searcher_cfg``:: + + searcher_cfg = dict( + type='NevergradSearcher', + solver_type='CMA' + ) + + Args: + searcher_cfg (ConfigType): The searcher config. + """ + searcher_cfg = searcher_cfg.copy() + self._logger.info(f'Building searcher of type: {searcher_cfg["type"]}') + searcher_cfg.update( + dict( + rule=self.rule, + hparam_spec=self.hparam_spec, + num_trials=self._num_trials)) + return HYPER_SEARCHERS.build(searcher_cfg) + + def _tear_down_runner(self, runner): + """Clear the global states of a runner.""" + + # Set the runner's cls attributes to None + runner.cfg = None + runner._train_loop = None + runner._val_loop = None + runner._test_loop = None + + # Remove the instance managed by the ManagerMixin + MMLogger._instance_dict.pop(runner.logger.instance_name) + MessageHub._instance_dict.pop(runner.message_hub.instance_name) + Visualizer._instance_dict.pop(runner.visualizer.instance_name) + DefaultScope._instance_dict.pop(runner.default_scope.instance_name) + + def _run_trial(self) -> Tuple[Dict, float, Optional[Exception]]: + """Retrieve hyperparameters from searcher and run a trial.""" + from mmengine.runner import Runner + + # Retrieve hyperparameters for the trial: + # Only the main process invokes the searcher's suggest method + # to mitigate the potential randomness that might occur in methods + # like Bayesian optimization or evolutionary algorithms. + # These methods might introduce randomness in the selection of + # hyperparameters, potentially leading to inconsistent suggestions + # across different processes. By centralizing the suggestion + # to the main process, we ensure a consistent set of hyperparameters + # is used for each trial. + if is_main_process(): + hparams_to_broadcast = [self._searcher.suggest()] + else: + hparams_to_broadcast = [None] # type: ignore + broadcast_object_list(hparams_to_broadcast, src=0) + hparam = hparams_to_broadcast[0] + + # Inject hyperparameters into runner config. 
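+ # For example, a suggestion of {'optim_wrapper.optimizer.lr': 1e-4} sets + # self._runner_cfg['optim_wrapper']['optimizer']['lr'] = 1e-4 through the + # dotted-key traversal implemented in `inject_config`.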
+ for k, v in hparam.items(): + self.inject_config(self._runner_cfg, k, v) + runner = Runner.from_cfg(self._runner_cfg) + report_hook = ReportingHook(self._monitor, self._tuning_iter, + self._tuning_epoch, self._reporting_op) + runner.register_hook(report_hook, priority='VERY_LOW') + default_score = float('inf') if self._rule == 'less' else -float('inf') + + # Run a trial. + # If an exception occurs during the trial, the score is set + # to default_score. + score: float + error: Optional[Exception] = None + try: + runner.train() + score = report_hook.report_score() # type: ignore + if score is None or math.isnan(score) or math.isinf(score): + score = default_score + except Exception as e: + score = default_score + error = e + finally: + self._tear_down_runner(runner) + + # Synchronize and average scores across all processes + score_tensor = torch.tensor(score, dtype=torch.float64) + all_reduce(score_tensor, op='mean') + score = score_tensor.item() + + if is_main_process(): + self._searcher.record(hparam, score) + return hparam, score, error + + def tune(self) -> Dict[str, Union[Dict[str, Any], float]]: + """Launch tuning. + + Returns: + Dict[str, Union[Dict[str, Any], float]]: + A dictionary containing the best hyperparameters under the key + 'hparam' and the corresponding score under the key 'score'. + """ + temp_dir = tempfile.TemporaryDirectory() + self._runner_cfg['work_dir'] = temp_dir.name + self._logger.info(f'Starting tuning for {self._num_trials} trials...') + for trial_idx in range(self._num_trials): + hparam, score, error = self._run_trial() + log_msg = f'Trial [{trial_idx + 1}/{self._num_trials}]' + if error is not None: + log_msg += f' failed. Error: {error}' + else: + log_msg += f' finished. Score obtained: {score}' + log_msg += f' Hyperparameters used: {hparam}' + self._logger.info(log_msg) + self._history.append((hparam, score)) + + best_hparam: dict + best_score: float + if self._rule == 'greater': + best_hparam, best_score = max(self._history, key=lambda x: x[1]) + else: + best_hparam, best_score = min(self._history, key=lambda x: x[1]) + self._logger.info(f'Best hyperparameters obtained: {best_hparam}') + self._logger.info(f'Best score obtained: {best_score}') + self._logger.info('Tuning completed.') + temp_dir.cleanup() + return dict(hparam=best_hparam, score=best_score) + + def clear(self): + """Clear the history of hyperparameters and scores.""" + self._history.clear() diff --git a/tests/test_tune/test_report_hook.py b/tests/test_tune/test_report_hook.py new file mode 100644 index 0000000000..c9fefdef0e --- /dev/null +++ b/tests/test_tune/test_report_hook.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import MagicMock + +from mmengine.testing import RunnerTestCase +from mmengine.tune._report_hook import ReportingHook + + +class TestReportingHook(RunnerTestCase): + + def test_append_score(self): + hook = ReportingHook(monitor='train/acc', max_scoreboard_len=3) + + # Adding scores to the scoreboard + hook._append_score(0.5) + hook._append_score(0.6) + hook._append_score(0.7) + self.assertEqual(hook.scoreboard, [0.5, 0.6, 0.7]) + + # When exceeding max length, it should pop the first item + hook._append_score(0.8) + self.assertEqual(hook.scoreboard, [0.6, 0.7, 0.8]) + + def test_should_stop(self): + runner = MagicMock(iter=3, epoch=1) + + # Test with tuning_iter + hook1 = ReportingHook(monitor='train/cc', tuning_iter=5) + self.assertFalse(hook1._should_stop(runner)) + runner.iter = 4 + self.assertTrue(hook1._should_stop(runner)) + + # Test with tuning_epoch + hook2 = ReportingHook(monitor='train/acc', tuning_epoch=3) + self.assertFalse(hook2._should_stop(runner)) + runner.epoch = 2 + self.assertTrue(hook2._should_stop(runner)) + + def test_report_score(self): + hook1 = ReportingHook(monitor='train/acc', report_op='latest') + hook1.scoreboard = [0.5, 0.6, 0.7] + self.assertEqual(hook1.report_score(), 0.7) + + hook2 = ReportingHook(monitor='train/acc', report_op='mean') + hook2.scoreboard = [0.5, 0.6, 0.7] + self.assertEqual(hook2.report_score(), 0.6) + + # Test with an empty scoreboard + hook3 = ReportingHook(monitor='train/acc', report_op='mean') + self.assertIsNone(hook3.report_score()) + + def test_clear(self): + hook = ReportingHook(monitor='train/acc') + hook.scoreboard = [0.5, 0.6, 0.7] + hook.clear() + self.assertEqual(hook.scoreboard, []) + + def test_after_train_iter(self): + runner = MagicMock(iter=3, epoch=1) + runner.log_processor.get_log_after_iter = MagicMock( + return_value=({ + 'acc': 0.9 + }, 'log_str')) + + # Check if the monitored score gets appended correctly + hook = ReportingHook(monitor='train/acc') + hook.after_train_iter(runner, 0) + self.assertEqual(hook.scoreboard[-1], 0.9) + + # Check the error raised when the monitored score is missing from logs + hook2 = ReportingHook(monitor='train/non_existent') + with self.assertRaises(ValueError): + hook2.after_train_iter(runner, 0) + + # Check that training stops if tuning_iter is reached + runner.iter = 5 + hook3 = ReportingHook(monitor='train/acc', tuning_iter=5) + hook3.after_train_iter(runner, 0) + self.assertTrue(runner.train_loop.stop_training) + + def test_after_val_epoch(self): + runner = MagicMock(iter=3, epoch=1) + + # Check if the monitored score gets appended correctly from metrics + metrics = {'acc': 0.9} + hook = ReportingHook(monitor='val/acc') + hook.after_val_epoch(runner, metrics=metrics) + self.assertEqual(hook.scoreboard[-1], 0.9) + + # Check the error raised when the monitored score is missing from logs + metrics = {'loss': 0.1} + hook2 = ReportingHook(monitor='val/acc') + with self.assertRaises(ValueError): + hook2.after_val_epoch(runner, metrics=metrics) + + def test_with_runner(self): + runner = self.build_runner(self.epoch_based_cfg) + acc_hook = ReportingHook(monitor='val/acc', tuning_epoch=1) + runner.register_hook(acc_hook, priority='VERY_LOW') + runner.train() + self.assertEqual(runner.epoch, 1) + score = acc_hook.report_score() + self.assertAlmostEqual(score, 1) diff --git a/tests/test_tune/test_searchers/test_nevergrad.py b/tests/test_tune/test_searchers/test_nevergrad.py new file mode 100644 index 0000000000..08c92ec7c0 --- /dev/null +++ 
b/tests/test_tune/test_searchers/test_nevergrad.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import List +from unittest import TestCase, skipIf + +from mmengine.tune.searchers import NevergradSearcher + +try: + import nevergrad # noqa: F401 + NEVERGRAD_AVAILABLE = True +except ImportError: + NEVERGRAD_AVAILABLE = False + + +@skipIf(not NEVERGRAD_AVAILABLE, 'nevergrad is not installed') +class TestNevergradSearcher(TestCase): + + def noisy_sphere_function(self, x: List[float]): + """Sphere function with noise: f(x) = sum(x_i^2) + noise""" + noise = random.gauss(0, 0.1) # Gaussian noise with mean 0 and std 0.1 + return sum([x_i**2 for x_i in x]) + noise + + def one_max_function(self, x: List[int]): + """OneMax function: f(x) = sum(x_i) for binary x_i""" + return sum(x) + + def test_hash_dict(self): + searcher = NevergradSearcher( + rule='less', + hparam_spec={'test': { + 'type': 'discrete', + 'values': [0, 1] + }}, + num_trials=100, + solver_type='OnePlusOne') + + # Check different dicts yield different hashes + d1 = {'x': 1, 'y': 2} + d2 = {'x': 1, 'y': 3} + self.assertNotEqual(searcher._hash_dict(d1), searcher._hash_dict(d2)) + + # Check same dict yields same hash + self.assertEqual(searcher._hash_dict(d1), searcher._hash_dict(d1)) + + # Check order doesn't matter + d3 = {'y': 2, 'x': 1} + self.assertEqual(searcher._hash_dict(d1), searcher._hash_dict(d3)) + + def test_noisy_sphere_function(self): + hparam_continuous_space = { + 'x1': { + 'type': 'continuous', + 'lower': -5.0, + 'upper': 5.0 + }, + 'x2': { + 'type': 'continuous', + 'lower': -5.0, + 'upper': 5.0 + } + } + searcher = NevergradSearcher( + rule='less', + hparam_spec=hparam_continuous_space, + num_trials=100, + solver_type='CMA') + for _ in range(100): + hparam = searcher.suggest() + score = self.noisy_sphere_function([v for _, v in hparam.items()]) + searcher.record(hparam, score) + # For the noisy sphere function, + # the optimal should be close to x1=0 and x2=0 + hparam = searcher.suggest() + self.assertAlmostEqual(hparam['x1'], 0.0, delta=1) + self.assertAlmostEqual(hparam['x2'], 0.0, delta=1) + + def test_one_max_function(self): + # Define the discrete search space for OneMax + hparam_discrete_space = { + f'x{i}': { + 'type': 'discrete', + 'values': [0, 1] + } + for i in range(1, 8) + } + searcher = NevergradSearcher( + rule='greater', + hparam_spec=hparam_discrete_space, + num_trials=300, + solver_type='NGO') + for _ in range(300): + hparam = searcher.suggest() + score = self.one_max_function([v for _, v in hparam.items()]) + searcher.record(hparam, score) + hparam = searcher.suggest() + self.assertGreaterEqual(score, 6) diff --git a/tests/test_tune/test_searchers/test_random.py b/tests/test_tune/test_searchers/test_random.py new file mode 100644 index 0000000000..6b57843aaa --- /dev/null +++ b/tests/test_tune/test_searchers/test_random.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +from mmengine.tune.searchers import RandomSearcher + + +class TestRandomSearcher(TestCase): + + def test_suggest(self): + searcher = RandomSearcher( + rule='greater', + hparam_spec={ + 'x1': { + 'type': 'discrete', + 'values': [0.01, 0.02, 0.03] + }, + 'x2': { + 'type': 'continuous', + 'lower': 0.01, + 'upper': 0.1 + } + }) + + for _ in range(100): + hparam = searcher.suggest() + self.assertTrue(hparam['x1'] in [0.01, 0.02, 0.03]) + self.assertTrue(hparam['x2'] >= 0.01 and hparam['x2'] <= 0.1) diff --git a/tests/test_tune/test_searchers/test_searcher.py b/tests/test_tune/test_searchers/test_searcher.py new file mode 100644 index 0000000000..a010fa0397 --- /dev/null +++ b/tests/test_tune/test_searchers/test_searcher.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from mmengine.tune.searchers import Searcher + + +class TestSearcher(TestCase): + + def test_rule(self): + valid_hparam_spec_1 = { + 'lr': { + 'type': 'discrete', + 'values': [0.01, 0.02, 0.03] + } + } + # Invalid cases + with self.assertRaises(AssertionError): + Searcher(rule='invalid_rule', hparam_spec=valid_hparam_spec_1) + Searcher(rule='greater', hparam_spec=valid_hparam_spec_1) + Searcher(rule='less', hparam_spec=valid_hparam_spec_1) + + def test_validate_hparam_spec(self): + # Unknown hparam spec type + invalid_hparam_spec_1 = { + 'lr': { + 'type': 'unknown_type', + 'values': [0.01, 0.02, 0.03] + } + } + with self.assertRaises(AssertionError): + Searcher(rule='greater', hparam_spec=invalid_hparam_spec_1) + + # Missing keys in continuous hparam_spec + invalid_hparam_spec_2 = {'lr': {'type': 'continuous', 'lower': 0.01}} + with self.assertRaises(AssertionError): + Searcher(rule='less', hparam_spec=invalid_hparam_spec_2) + + # Invalid discrete hparam_spec + invalid_hparam_spec_3 = { + 'lr': { + 'type': 'discrete', + 'values': [] # Empty list + } + } + with self.assertRaises(AssertionError): + Searcher(rule='greater', hparam_spec=invalid_hparam_spec_3) + + # Invalid continuous hparam_spec + invalid_hparam_spec_4 = { + 'lr': { + 'type': 'continuous', + 'lower': 0.1, + 'upper': 0.01 # lower is greater than upper + } + } + with self.assertRaises(AssertionError): + Searcher(rule='less', hparam_spec=invalid_hparam_spec_4) + + # Invalid data type in continuous hparam_spec + invalid_hparam_spec_5 = { + 'lr': { + 'type': 'continuous', + 'lower': '0.01', # String instead of number + 'upper': 0.1 + } + } + with self.assertRaises(AssertionError): + Searcher(rule='less', hparam_spec=invalid_hparam_spec_5) + + def test_hparam_spec_property(self): + hparam_spec = { + 'lr': { + 'type': 'discrete', + 'values': [0.01, 0.02, 0.03] + } + } + searcher = Searcher(rule='greater', hparam_spec=hparam_spec) + self.assertEqual(searcher.hparam_spec, hparam_spec) + + def test_rule_property(self): + searcher = Searcher(rule='greater', hparam_spec={}) + self.assertEqual(searcher.rule, 'greater') diff --git a/tests/test_tune/test_tuner.py b/tests/test_tune/test_tuner.py new file mode 100644 index 0000000000..1af905976e --- /dev/null +++ b/tests/test_tune/test_tuner.py @@ -0,0 +1,240 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict +from unittest import mock + +from mmengine.testing import RunnerTestCase +from mmengine.tune import Tuner +from mmengine.tune.searchers import HYPER_SEARCHERS, Searcher + + +class ToySearcher(Searcher): + + def suggest(self) -> Dict: + hparam = dict() + for k, v in self.hparam_spec.items(): + if v['type'] == 'discrete': + hparam[k] = v['values'][0] + else: + hparam[k] = (v['lower'] + v['upper']) / 2 + return hparam + + +class TestTuner(RunnerTestCase): + + def setUp(self) -> None: + super().setUp() + HYPER_SEARCHERS.register_module(module=ToySearcher) + self.hparam_spec = { + 'optim_wrapper.optimizer.lr': { + 'type': 'discrete', + 'values': [0.1, 0.2, 0.3] + } + } + + def tearDown(self): + super().tearDown() + HYPER_SEARCHERS.module_dict.pop('ToySearcher', None) + + def test_init(self): + with self.assertRaises(ValueError): + Tuner( + runner_cfg=dict(), + hparam_spec=dict(), + monitor='loss', + rule='invalid_rule', + num_trials=2, + searcher_cfg=dict(type='ToySearcher')) + + # Initializing with correct parameters + tuner = Tuner( + runner_cfg=self.epoch_based_cfg, + hparam_spec=self.hparam_spec, + monitor='loss', + rule='less', + num_trials=2, + searcher_cfg=dict(type='ToySearcher')) + + # Verify the properties + self.assertEqual(tuner.hparam_spec, self.hparam_spec) + self.assertEqual(tuner.monitor, 'loss') + self.assertEqual(tuner.rule, 'less') + self.assertEqual(tuner.num_trials, 2) + + # Ensure a searcher of type ToySearcher is used + self.assertIsInstance(tuner.searcher, ToySearcher) + + def mock_is_main_process(self, return_value=True): + return mock.patch( + 'mmengine.dist.is_main_process', return_value=return_value) + + def mock_broadcast(self, side_effect=None): + return mock.patch( + 'mmengine.dist.broadcast_object_list', side_effect=side_effect) + + def test_inject_config(self): + # Inject into a single level + cfg = {'a': 1} + Tuner.inject_config(cfg, 'a', 2) + self.assertEqual(cfg['a'], 2) + + # Inject into a nested level + cfg = {'level1': {'level2': {'level3': 3}}} + Tuner.inject_config(cfg, 'level1.level2.level3', 4) + self.assertEqual(cfg['level1']['level2']['level3'], 4) + + # Inject into a non-existent key + cfg = {} + with self.assertRaises(KeyError): + Tuner.inject_config(cfg, 'a', 1) + + # Inject into a sequence + cfg = {'sequence': [1, 2, 3]} + Tuner.inject_config(cfg, 'sequence.1', 5) + self.assertEqual(cfg['sequence'][1], 5) + + @mock.patch('mmengine.runner.Runner.train') + @mock.patch('mmengine.tune._report_hook.ReportingHook.report_score') + def test_successful_run(self, mock_report_score, mock_train): + tuner = Tuner( + runner_cfg=self.epoch_based_cfg, + hparam_spec=self.hparam_spec, + monitor='train/loss', + rule='less', + num_trials=2, + searcher_cfg=dict(type='ToySearcher')) + + tuner.searcher.suggest = mock.MagicMock( + return_value={'optim_wrapper.optimizer.lr': 0.1}) + tuner.searcher.record = mock.MagicMock() + + mock_report_score.return_value = 0.05 + + with self.mock_is_main_process(), self.mock_broadcast(): + hparam, score, error = tuner._run_trial() + + self.assertEqual(hparam, {'optim_wrapper.optimizer.lr': 0.1}) + self.assertEqual(score, 0.05) + self.assertIsNone(error) + tuner.searcher.record.assert_called_with( + {'optim_wrapper.optimizer.lr': 0.1}, 0.05) + + @mock.patch('mmengine.runner.Runner.train') + @mock.patch('mmengine.tune._report_hook.ReportingHook.report_score') + def test_run_with_exception(self, mock_report_score, mock_train): + mock_train.side_effect = Exception('Error during training') + + tuner = Tuner( + 
runner_cfg=self.epoch_based_cfg, + hparam_spec=self.hparam_spec, + monitor='train/loss', + rule='less', + num_trials=2, + searcher_cfg=dict(type='ToySearcher')) + + tuner.searcher.suggest = mock.MagicMock( + return_value={'optim_wrapper.optimizer.lr': 0.1}) + tuner.searcher.record = mock.MagicMock() + + with self.mock_is_main_process(), self.mock_broadcast(): + hparam, score, error = tuner._run_trial() + + self.assertEqual(hparam, {'optim_wrapper.optimizer.lr': 0.1}) + self.assertEqual(score, float('inf')) + self.assertTrue(isinstance(error, Exception)) + tuner.searcher.record.assert_called_with( + {'optim_wrapper.optimizer.lr': 0.1}, float('inf')) + + @mock.patch('mmengine.runner.Runner.train') + @mock.patch('mmengine.tune._report_hook.ReportingHook.report_score') + def test_tune(self, mock_report_score, mock_train): + mock_scores = [0.05, 0.03, 0.04, 0.06] + mock_hparams = [{ + 'optim_wrapper.optimizer.lr': 0.1 + }, { + 'optim_wrapper.optimizer.lr': 0.05 + }, { + 'optim_wrapper.optimizer.lr': 0.2 + }, { + 'optim_wrapper.optimizer.lr': 0.3 + }] + + mock_report_score.side_effect = mock_scores + + tuner = Tuner( + runner_cfg=self.epoch_based_cfg, + hparam_spec=self.hparam_spec, + monitor='loss', + rule='less', + num_trials=4, + searcher_cfg=dict(type='ToySearcher')) + + mock_run_trial_return_values = [ + (mock_hparams[0], mock_scores[0], None), + (mock_hparams[1], mock_scores[1], + Exception('Error during training')), + (mock_hparams[2], mock_scores[2], None), + (mock_hparams[3], mock_scores[3], None) + ] + tuner._run_trial = mock.MagicMock( + side_effect=mock_run_trial_return_values) + + with self.mock_is_main_process(), self.mock_broadcast(): + result = tuner.tune() + + self.assertEqual(tuner._history, [(mock_hparams[0], mock_scores[0]), + (mock_hparams[1], mock_scores[1]), + (mock_hparams[2], mock_scores[2]), + (mock_hparams[3], mock_scores[3])]) + + self.assertEqual(result, { + 'hparam': mock_hparams[1], + 'score': mock_scores[1] + }) + + tuner = Tuner( + runner_cfg=self.epoch_based_cfg, + hparam_spec=self.hparam_spec, + monitor='loss', + rule='greater', + num_trials=4, + searcher_cfg=dict(type='ToySearcher')) + tuner._run_trial = mock.MagicMock( + side_effect=mock_run_trial_return_values) + with self.mock_is_main_process(), self.mock_broadcast(): + result = tuner.tune() + self.assertEqual(result, { + 'hparam': mock_hparams[3], + 'score': mock_scores[3] + }) + + def test_clear(self): + tuner = Tuner( + runner_cfg=self.epoch_based_cfg, + hparam_spec=self.hparam_spec, + monitor='loss', + rule='less', + num_trials=2, + searcher_cfg=dict(type='ToySearcher')) + + tuner.history.append(({'optim_wrapper.optimizer.lr': 0.1}, 0.05)) + tuner.clear() + self.assertEqual(tuner.history, []) + + def test_with_runner(self): + tuner = Tuner( + runner_cfg=self.epoch_based_cfg, + hparam_spec=self.hparam_spec, + monitor='val/acc', + rule='greater', + num_trials=10, + searcher_cfg=dict(type='ToySearcher')) + + with self.mock_is_main_process(), self.mock_broadcast(): + result = tuner.tune() + + self.assertTrue({ + hparam['optim_wrapper.optimizer.lr'] + for hparam, _ in tuner.history + }.issubset( + set(self.hparam_spec['optim_wrapper.optimizer.lr']['values']))) + self.assertEqual(result['score'], 1)
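The scoreboard aggregation in `ReportingHook` is extensible: `ReportingHook.register_report_op` accepts any callable that reduces a list of scores to a single value. A minimal sketch follows, assuming this patch is installed; the import path `mmengine.tune._report_hook` is internal (it is the one used by the tests above), and scores are appended through the private `_append_score` purely to illustrate the aggregation.

```python
from mmengine.tune._report_hook import ReportingHook

# Register a custom aggregation: the median of the recorded scores.
ReportingHook.register_report_op(
    'median', lambda scores: sorted(scores)[len(scores) // 2])

# 'median' is now accepted as a report_op by any ReportingHook instance.
hook = ReportingHook(monitor='train/loss', report_op='median')
for score in (0.3, 0.9, 0.5):
    hook._append_score(score)  # normally fed by after_train_iter/after_val_epoch
print(hook.report_score())  # 0.5
```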