Skip to content

Commit

Permalink
move predictor-build in __init__ & simplify estimator-build
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyang07 committed Nov 14, 2022
1 parent e2fdefc commit 7b46037
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 112 deletions.
11 changes: 7 additions & 4 deletions mmrazor/engine/hooks/estimate_resources_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import TASK_UTILS

DATA_BATCH = Optional[Sequence[dict]]

Expand All @@ -23,7 +23,7 @@ class EstimateResourcesHook(Hook):
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default to True.
estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
Default to dict().
Default to None.
Example:
>>> add the `EstimatorResourcesHook` in custom_hooks as follows:
Expand All @@ -41,11 +41,14 @@ class EstimateResourcesHook(Hook):
def __init__(self,
interval: int = -1,
by_epoch: bool = True,
estimator_cfg: Dict[str, Any] = dict(),
estimator_cfg: Dict[str, Any] = None,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
self.estimator = ResourceEstimator(**estimator_cfg)
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)

def after_val_epoch(self,
runner,
Expand Down
70 changes: 14 additions & 56 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import random
Expand All @@ -15,7 +14,6 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
Expand Down Expand Up @@ -46,8 +44,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
crossover_prob (float): The probability of crossover. Defaults to 0.5.
constraints_range (Dict[str, Any]): Constraints to be used for
screening candidates. Defaults to dict(flops=(0, 330)).
resource_estimator_cfg (dict, Optional): Used for building a
resource estimator. Defaults to None.
estimator_cfg (dict, Optional): Used for building a resource estimator.
Defaults to None.
predictor_cfg (dict, Optional): Used for building a metric predictor.
Defaults to None.
score_key (str): Specify one metric in evaluation results to score
Expand All @@ -71,7 +69,7 @@ def __init__(self,
mutate_prob: float = 0.1,
crossover_prob: float = 0.5,
constraints_range: Dict[str, Any] = dict(flops=(0., 330.)),
resource_estimator_cfg: Optional[Dict] = None,
estimator_cfg: Optional[Dict] = None,
predictor_cfg: Optional[Dict] = None,
score_key: str = 'accuracy/top1',
init_candidates: Optional[str] = None) -> None:
Expand Down Expand Up @@ -113,66 +111,26 @@ def __init__(self,
else:
self.model = runner.model

# Build resource estimator.
resource_estimator_cfg = dict(
) if resource_estimator_cfg is None else resource_estimator_cfg
self.estimator = self.build_resource_estimator(resource_estimator_cfg)
# initialize estimator
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)

# initialize predictor
self.use_predictor = False
self.predictor_cfg = predictor_cfg

def build_resource_estimator(
self, resource_estimator: Union[ResourceEstimator,
Dict]) -> ResourceEstimator:
"""Build resource estimator for search loop.
Examples of ``resource_estimator``:
# `ResourceEstimator` will be used
resource_estimator = dict()
# custom resource_estimator
resource_estimator = dict(type='mmrazor.ResourceEstimator')
Args:
resource_estimator (ResourceEstimator or dict): A
resource_estimator or a dict to build resource estimator.
If ``resource_estimator`` is a resource estimator object,
just returns itself.
Returns:
:obj:`ResourceEstimator`: Resource estimator object build from
``resource_estimator``.
"""
if isinstance(resource_estimator, ResourceEstimator):
return resource_estimator
elif not isinstance(resource_estimator, dict):
raise TypeError(
'resource estimator should be a ResourceEstimator object or'
f'dict, but got {resource_estimator}')

resource_estimator_cfg = copy.deepcopy(
resource_estimator) # type: ignore

if 'type' in resource_estimator_cfg:
estimator = TASK_UTILS.build(resource_estimator_cfg)
else:
estimator = ResourceEstimator(
**resource_estimator_cfg) # type: ignore

return estimator # type: ignore
if self.predictor_cfg is not None:
self.predictor_cfg['score_key'] = self.score_key
self.predictor_cfg['search_groups'] = \
self.model.mutator.search_groups
self.predictor = TASK_UTILS.build(self.predictor_cfg)

def run(self) -> None:
"""Launch searching."""
self.runner.call_hook('before_train')

if self.predictor_cfg is not None:
# initialize predictor
self.predictor_cfg['score_key'] = self.score_key
self.predictor_cfg['search_groups'] = \
self.model.mutator.search_groups

self.predictor = TASK_UTILS.build(self.predictor_cfg)
self._init_predictor()

if self.resume_from:
Expand Down
58 changes: 8 additions & 50 deletions mmrazor/engine/runner/subnet_sampler_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os
import random
Expand All @@ -13,7 +12,6 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates
from mmrazor.utils import SupportRandomSubnet
Expand Down Expand Up @@ -102,8 +100,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
candidates. Defaults to 'accuracy_top-1'.
constraints_range (Dict[str, Any]): Constraints to be used for
screening candidates. Defaults to dict(flops=(0, 330)).
resource_estimator_cfg (dict, Optional): Used for building a
resource estimator. Defaults to None.
estimator_cfg (dict, Optional): Used for building a resource estimator.
Defaults to None.
num_candidates (int): The number of the candidates consist of samples
from supernet and itself. Defaults to 1000.
num_samples (int): The number of sample in each sampling subnet.
Expand Down Expand Up @@ -138,7 +136,7 @@ def __init__(self,
val_interval: int = 1000,
score_key: str = 'accuracy/top1',
constraints_range: Dict[str, Any] = dict(flops=(0, 330)),
resource_estimator_cfg: Optional[Dict] = None,
estimator_cfg: Optional[Dict] = None,
num_candidates: int = 1000,
num_samples: int = 10,
top_k: int = 5,
Expand Down Expand Up @@ -176,51 +174,11 @@ def __init__(self,
self.candidates = Candidates()
self.top_k_candidates = Candidates()

# Build resource estimator.
resource_estimator_cfg = dict(
) if resource_estimator_cfg is None else resource_estimator_cfg
self.estimator = self.build_resource_estimator(resource_estimator_cfg)

def build_resource_estimator(
self, resource_estimator: Union[ResourceEstimator,
Dict]) -> ResourceEstimator:
"""Build resource estimator for search loop.
Examples of ``resource_estimator``:
# `ResourceEstimator` will be used
resource_estimator = dict()
# custom resource_estimator
resource_estimator = dict(type='mmrazor.ResourceEstimator')
Args:
resource_estimator (ResourceEstimator or dict):
A resource_estimator or a dict to build resource estimator.
If ``resource_estimator`` is a resource estimator object,
just returns itself.
Returns:
:obj:`ResourceEstimator`: Resource estimator object build from
``resource_estimator``.
"""
if isinstance(resource_estimator, ResourceEstimator):
return resource_estimator
elif not isinstance(resource_estimator, dict):
raise TypeError(
'resource estimator should be a ResourceEstimator object or'
f'dict, but got {resource_estimator}')

resource_estimator_cfg = copy.deepcopy(
resource_estimator) # type: ignore

if 'type' in resource_estimator_cfg:
estimator = TASK_UTILS.build(resource_estimator_cfg)
else:
estimator = ResourceEstimator(
**resource_estimator_cfg) # type: ignore

return estimator # type: ignore
# initialize estimator
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)

def run(self) -> None:
"""Launch training."""
Expand Down
6 changes: 4 additions & 2 deletions mmrazor/models/algorithms/nas/dsnas.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def __init__(self,
**kwargs):
super().__init__(architecture, data_preprocessor, **kwargs)

if estimator_cfg is None:
estimator_cfg = dict(type='mmrazor.ResourceEstimator')
# initialize estimator
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)
if fix_subnet:
# Avoid circular import
Expand Down

0 comments on commit 7b46037

Please sign in to comment.