diff --git a/mmrazor/engine/hooks/estimate_resources_hook.py b/mmrazor/engine/hooks/estimate_resources_hook.py index dc27f2906..e8c4d8446 100644 --- a/mmrazor/engine/hooks/estimate_resources_hook.py +++ b/mmrazor/engine/hooks/estimate_resources_hook.py @@ -7,7 +7,7 @@ from mmengine.registry import HOOKS from mmengine.structures import BaseDataElement -from mmrazor.models.task_modules import ResourceEstimator +from mmrazor.registry import TASK_UTILS DATA_BATCH = Optional[Sequence[dict]] @@ -23,7 +23,7 @@ class EstimateResourcesHook(Hook): by_epoch (bool): Saving checkpoints by epoch or by iteration. Default to True. estimator_cfg (Dict[str, Any]): Used for building a resource estimator. - Default to dict(). + Default to None. Example: >>> add the `EstimatorResourcesHook` in custom_hooks as follows: @@ -41,11 +41,14 @@ class EstimateResourcesHook(Hook): def __init__(self, interval: int = -1, by_epoch: bool = True, - estimator_cfg: Dict[str, Any] = dict(), + estimator_cfg: Dict[str, Any] = None, **kwargs) -> None: self.interval = interval self.by_epoch = by_epoch - self.estimator = ResourceEstimator(**estimator_cfg) + estimator_cfg = dict() if estimator_cfg is None else estimator_cfg + if 'type' not in estimator_cfg: + estimator_cfg['type'] = 'mmrazor.ResourceEstimator' + self.estimator = TASK_UTILS.build(estimator_cfg) def after_val_epoch(self, runner, diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 8edf1be3a..fc907f3aa 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import os import os.path as osp import random @@ -15,7 +14,6 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS, TASK_UTILS from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet @@ -46,8 +44,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): crossover_prob (float): The probability of crossover. Defaults to 0.5. constraints_range (Dict[str, Any]): Constraints to be used for screening candidates. Defaults to dict(flops=(0, 330)). - resource_estimator_cfg (dict, Optional): Used for building a - resource estimator. Defaults to None. + estimator_cfg (dict, Optional): Used for building a resource estimator. + Defaults to None. predictor_cfg (dict, Optional): Used for building a metric predictor. Defaults to None. score_key (str): Specify one metric in evaluation results to score @@ -71,7 +69,7 @@ def __init__(self, mutate_prob: float = 0.1, crossover_prob: float = 0.5, constraints_range: Dict[str, Any] = dict(flops=(0., 330.)), - resource_estimator_cfg: Optional[Dict] = None, + estimator_cfg: Optional[Dict] = None, predictor_cfg: Optional[Dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: @@ -113,66 +111,26 @@ def __init__(self, else: self.model = runner.model - # Build resource estimator. - resource_estimator_cfg = dict( - ) if resource_estimator_cfg is None else resource_estimator_cfg - self.estimator = self.build_resource_estimator(resource_estimator_cfg) + # initialize estimator + estimator_cfg = dict() if estimator_cfg is None else estimator_cfg + if 'type' not in estimator_cfg: + estimator_cfg['type'] = 'mmrazor.ResourceEstimator' + self.estimator = TASK_UTILS.build(estimator_cfg) + # initialize predictor self.use_predictor = False self.predictor_cfg = predictor_cfg - - def build_resource_estimator( - self, resource_estimator: Union[ResourceEstimator, - Dict]) -> ResourceEstimator: - """Build resource estimator for search loop. - - Examples of ``resource_estimator``: - - # `ResourceEstimator` will be used - resource_estimator = dict() - - # custom resource_estimator - resource_estimator = dict(type='mmrazor.ResourceEstimator') - - Args: - resource_estimator (ResourceEstimator or dict): A - resource_estimator or a dict to build resource estimator. - If ``resource_estimator`` is a resource estimator object, - just returns itself. - - Returns: - :obj:`ResourceEstimator`: Resource estimator object build from - ``resource_estimator``. - """ - if isinstance(resource_estimator, ResourceEstimator): - return resource_estimator - elif not isinstance(resource_estimator, dict): - raise TypeError( - 'resource estimator should be a ResourceEstimator object or' - f'dict, but got {resource_estimator}') - - resource_estimator_cfg = copy.deepcopy( - resource_estimator) # type: ignore - - if 'type' in resource_estimator_cfg: - estimator = TASK_UTILS.build(resource_estimator_cfg) - else: - estimator = ResourceEstimator( - **resource_estimator_cfg) # type: ignore - - return estimator # type: ignore + if self.predictor_cfg is not None: + self.predictor_cfg['score_key'] = self.score_key + self.predictor_cfg['search_groups'] = \ + self.model.mutator.search_groups + self.predictor = TASK_UTILS.build(self.predictor_cfg) def run(self) -> None: """Launch searching.""" self.runner.call_hook('before_train') if self.predictor_cfg is not None: - # initialize predictor - self.predictor_cfg['score_key'] = self.score_key - self.predictor_cfg['search_groups'] = \ - self.model.mutator.search_groups - - self.predictor = TASK_UTILS.build(self.predictor_cfg) self._init_predictor() if self.resume_from: diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index 273561568..56c4f893c 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import math import os import random @@ -13,7 +12,6 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS, TASK_UTILS from mmrazor.structures import Candidates from mmrazor.utils import SupportRandomSubnet @@ -102,8 +100,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): candidates. Defaults to 'accuracy_top-1'. constraints_range (Dict[str, Any]): Constraints to be used for screening candidates. Defaults to dict(flops=(0, 330)). - resource_estimator_cfg (dict, Optional): Used for building a - resource estimator. Defaults to None. + estimator_cfg (dict, Optional): Used for building a resource estimator. + Defaults to None. num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. num_samples (int): The number of sample in each sampling subnet. @@ -138,7 +136,7 @@ def __init__(self, val_interval: int = 1000, score_key: str = 'accuracy/top1', constraints_range: Dict[str, Any] = dict(flops=(0, 330)), - resource_estimator_cfg: Optional[Dict] = None, + estimator_cfg: Optional[Dict] = None, num_candidates: int = 1000, num_samples: int = 10, top_k: int = 5, @@ -176,51 +174,11 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() - # Build resource estimator. - resource_estimator_cfg = dict( - ) if resource_estimator_cfg is None else resource_estimator_cfg - self.estimator = self.build_resource_estimator(resource_estimator_cfg) - - def build_resource_estimator( - self, resource_estimator: Union[ResourceEstimator, - Dict]) -> ResourceEstimator: - """Build resource estimator for search loop. - - Examples of ``resource_estimator``: - - # `ResourceEstimator` will be used - resource_estimator = dict() - - # custom resource_estimator - resource_estimator = dict(type='mmrazor.ResourceEstimator') - - Args: - resource_estimator (ResourceEstimator or dict): - A resource_estimator or a dict to build resource estimator. - If ``resource_estimator`` is a resource estimator object, - just returns itself. - - Returns: - :obj:`ResourceEstimator`: Resource estimator object build from - ``resource_estimator``. - """ - if isinstance(resource_estimator, ResourceEstimator): - return resource_estimator - elif not isinstance(resource_estimator, dict): - raise TypeError( - 'resource estimator should be a ResourceEstimator object or' - f'dict, but got {resource_estimator}') - - resource_estimator_cfg = copy.deepcopy( - resource_estimator) # type: ignore - - if 'type' in resource_estimator_cfg: - estimator = TASK_UTILS.build(resource_estimator_cfg) - else: - estimator = ResourceEstimator( - **resource_estimator_cfg) # type: ignore - - return estimator # type: ignore + # initialize estimator + estimator_cfg = dict() if estimator_cfg is None else estimator_cfg + if 'type' not in estimator_cfg: + estimator_cfg['type'] = 'mmrazor.ResourceEstimator' + self.estimator = TASK_UTILS.build(estimator_cfg) def run(self) -> None: """Launch training.""" diff --git a/mmrazor/models/algorithms/nas/dsnas.py b/mmrazor/models/algorithms/nas/dsnas.py index 5434ce0ac..a763c75ea 100644 --- a/mmrazor/models/algorithms/nas/dsnas.py +++ b/mmrazor/models/algorithms/nas/dsnas.py @@ -68,8 +68,10 @@ def __init__(self, **kwargs): super().__init__(architecture, data_preprocessor, **kwargs) - if estimator_cfg is None: - estimator_cfg = dict(type='mmrazor.ResourceEstimator') + # initialize estimator + estimator_cfg = dict() if estimator_cfg is None else estimator_cfg + if 'type' not in estimator_cfg: + estimator_cfg['type'] = 'mmrazor.ResourceEstimator' self.estimator = TASK_UTILS.build(estimator_cfg) if fix_subnet: # Avoid circular import