diff --git a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py index 4f5edb316..87553ec39 100644 --- a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 465.), + constraints_range=dict(flops=(0., 465.)), score_key='accuracy/top1') diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py index f3f963e40..f5a5e88f4 100644 --- a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 330.), + constraints_range=dict(flops=(0, 330)), score_key='accuracy/top1') diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py index d1dd1637a..689618362 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py @@ -13,5 +13,5 @@ num_mutation=20, num_crossover=20, mutate_prob=0.1, - flops_range=(0., 300.), + constraints_range=dict(flops=(0, 330)), score_key='coco/bbox_mAP') diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index a9a76b383..009280778 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os import os.path as osp import random import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from mmengine import fileio @@ -14,10 +15,10 @@ from torch.utils.data import DataLoader from mmrazor.models.task_modules import ResourceEstimator -from mmrazor.registry import LOOPS +from mmrazor.registry import LOOPS, TASK_UTILS from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet -from .utils import check_subnet_flops, crossover +from .utils import check_subnet_resources, crossover @LOOPS.register_module() @@ -41,10 +42,11 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): num_crossover (int): The number of candidates got by crossover. Defaults to 25. mutate_prob (float): The probability of mutation. Defaults to 0.1. - flops_range (tuple, optional): It is used for screening candidates. - resource_estimator_cfg (dict): The config for building estimator, which - is be used to estimate the flops of sampled subnet. Defaults to - None, which means default config is used. + crossover_prob (float): The probability of crossover. Defaults to 0.5. + constraints_range (Dict[str, Any]): Constraints to be used for + screening candidates. Defaults to dict(flops=(0, 330)). + resource_estimator_cfg (dict, Optional): Used for building a + resource estimator. Defaults to None. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. init_candidates (str, optional): The candidates file path, which is @@ -64,8 +66,9 @@ def __init__(self, num_mutation: int = 25, num_crossover: int = 25, mutate_prob: float = 0.1, - flops_range: Optional[Tuple[float, float]] = (0., 330.), - resource_estimator_cfg: Optional[dict] = None, + crossover_prob: float = 0.5, + constraints_range: Dict[str, Any] = dict(flops=(0., 330.)), + resource_estimator_cfg: Optional[Dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -83,11 +86,12 @@ def __init__(self, self.num_candidates = num_candidates self.top_k = top_k - self.flops_range = flops_range + self.constraints_range = constraints_range self.score_key = score_key self.num_mutation = num_mutation self.num_crossover = num_crossover self.mutate_prob = mutate_prob + self.crossover_prob = crossover_prob self.max_keep_ckpts = max_keep_ckpts self.resume_from = resume_from @@ -99,16 +103,58 @@ def __init__(self, correct init candidates file' self.top_k_candidates = Candidates() - if resource_estimator_cfg is None: - self.estimator = ResourceEstimator() - else: - self.estimator = ResourceEstimator(**resource_estimator_cfg) if self.runner.distributed: self.model = runner.model.module else: self.model = runner.model + # Build resource estimator. + resource_estimator_cfg = dict( + ) if resource_estimator_cfg is None else resource_estimator_cfg + self.estimator = self.build_resource_estimator(resource_estimator_cfg) + + def build_resource_estimator( + self, resource_estimator: Union[ResourceEstimator, + Dict]) -> ResourceEstimator: + """Build resource estimator for search loop. + + Examples of ``resource_estimator``: + + # `ResourceEstimator` will be used + resource_estimator = dict() + + # custom resource_estimator + resource_estimator = dict(type='mmrazor.ResourceEstimator') + + Args: + resource_estimator (ResourceEstimator or dict): A + resource_estimator or a dict to build resource estimator. + If ``resource_estimator`` is a resource estimator object, + just returns itself. + + Returns: + :obj:`ResourceEstimator`: Resource estimator object build from + ``resource_estimator``. + """ + if isinstance(resource_estimator, ResourceEstimator): + return resource_estimator + elif not isinstance(resource_estimator, dict): + raise TypeError( + 'resource estimator should be a ResourceEstimator object or' + f'dict, but got {resource_estimator}') + + resource_estimator_cfg = copy.deepcopy( + resource_estimator) # type: ignore + + if 'type' in resource_estimator_cfg: + estimator = TASK_UTILS.build(resource_estimator_cfg) + else: + estimator = ResourceEstimator( + **resource_estimator_cfg) # type: ignore + + return estimator # type: ignore + def run(self) -> None: """Launch searching.""" self.runner.call_hook('before_train') @@ -144,33 +190,49 @@ def run_epoch(self) -> None: f'{scores_before}') self.candidates.extend(self.top_k_candidates) - self.candidates.sort(key=lambda x: x[1], reverse=True) - self.top_k_candidates = Candidates(self.candidates[:self.top_k]) + self.candidates.sort_by(key_indicator='score', reverse=True) + self.top_k_candidates = Candidates(self.candidates.data[:self.top_k]) scores_after = self.top_k_candidates.scores self.runner.logger.info(f'top k scores after update: ' f'{scores_after}') mutation_candidates = self.gen_mutation_candidates() + self.candidates_mutator_crossover = Candidates(mutation_candidates) crossover_candidates = self.gen_crossover_candidates() - candidates = mutation_candidates + crossover_candidates - assert len(candidates) <= self.num_candidates, 'Total of mutation and \ - crossover should be no more than the number of candidates.' + self.candidates_mutator_crossover.extend(crossover_candidates) - self.candidates = Candidates(candidates) + assert len(self.candidates_mutator_crossover + ) <= self.num_candidates, 'Total of mutation and \ + crossover should be less than the number of candidates.' + + self.candidates = self.candidates_mutator_crossover self._epoch += 1 def sample_candidates(self) -> None: """Update candidate pool contains specified number of candicates.""" + candidates_resources = [] + init_candidates = len(self.candidates) if self.runner.rank == 0: while len(self.candidates) < self.num_candidates: candidate = self.model.sample_subnet() - if self._check_constraints(random_subnet=candidate): + is_pass, result = self._check_constraints( + random_subnet=candidate) + if is_pass: self.candidates.append(candidate) + candidates_resources.append(result) + self.candidates = Candidates(self.candidates.data) else: - self.candidates = Candidates([None] * self.num_candidates) + self.candidates = Candidates([dict()] * self.num_candidates) + + if len(candidates_resources) > 0: + self.candidates.update_resources( + candidates_resources, + start=len(self.candidates.data) - len(candidates_resources)) # broadcast candidates to val with multi-GPUs. broadcast_object_list(self.candidates.data) + assert init_candidates + len( + candidates_resources) == self.num_candidates def update_candidates_scores(self) -> None: """Validate candicate one by one from the candicate pool, and update @@ -180,14 +242,18 @@ def update_candidates_scores(self) -> None: metrics = self._val_candidate() score = metrics[self.score_key] \ if len(metrics) != 0 else 0. - self.candidates.set_score(i, score) + self.candidates.set_resource(i, score, 'score') self.runner.logger.info( f'Epoch:[{self._epoch}/{self._max_epochs}] ' f'Candidate:[{i + 1}/{self.num_candidates}] ' - f'Score:{score}') + f'Flops: {self.candidates.resources("flops")[i]} ' + f'Params: {self.candidates.resources("params")[i]} ' + f'Latency: {self.candidates.resources("latency")[i]} ' + f'Score: {self.candidates.scores} ') - def gen_mutation_candidates(self) -> List: + def gen_mutation_candidates(self): """Generate specified number of mutation candicates.""" + mutation_resources = [] mutation_candidates: List = [] max_mutate_iters = self.num_mutation * 10 mutate_iter = 0 @@ -198,12 +264,20 @@ def gen_mutation_candidates(self) -> List: mutation_candidate = self._mutation() - if self._check_constraints(random_subnet=mutation_candidate): + is_pass, result = self._check_constraints( + random_subnet=mutation_candidate) + if is_pass: mutation_candidates.append(mutation_candidate) + mutation_resources.append(result) + + mutation_candidates = Candidates(mutation_candidates) + mutation_candidates.update_resources(mutation_resources) + return mutation_candidates - def gen_crossover_candidates(self) -> List: + def gen_crossover_candidates(self): """Generate specofied number of crossover candicates.""" + crossover_resources = [] crossover_candidates: List = [] crossover_iter = 0 max_crossover_iters = self.num_crossover * 10 @@ -214,8 +288,15 @@ def gen_crossover_candidates(self) -> List: crossover_candidate = self._crossover() - if self._check_constraints(random_subnet=crossover_candidate): + is_pass, result = self._check_constraints( + random_subnet=crossover_candidate) + if is_pass: crossover_candidates.append(crossover_candidate) + crossover_resources.append(result) + + crossover_candidates = Candidates(crossover_candidates) + crossover_candidates.update_resources(crossover_resources) + return crossover_candidates def _mutation(self) -> SupportRandomSubnet: @@ -229,7 +310,7 @@ def _crossover(self) -> SupportRandomSubnet: """Crossover.""" candidate1 = random.choice(self.top_k_candidates.subnets) candidate2 = random.choice(self.top_k_candidates.subnets) - candidate = crossover(candidate1, candidate2) + candidate = crossover(candidate1, candidate2, prob=self.crossover_prob) return candidate def _resume(self): @@ -263,7 +344,7 @@ def _val_candidate(self) -> Dict: self.runner.model.eval() for data_batch in self.dataloader: outputs = self.runner.model.val_step(data_batch) - self.evaluator.process(data_samples=outputs, data_batch=data_batch) + self.evaluator.process(outputs, data_batch) metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) return metrics @@ -295,16 +376,17 @@ def _save_searcher_ckpt(self) -> None: if osp.isfile(ckpt_path): os.remove(ckpt_path) - def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: + def _check_constraints( + self, random_subnet: SupportRandomSubnet) -> Tuple[bool, Dict]: """Check whether is beyond constraints. Returns: - bool: The result of checking. + bool, result: The result of checking. """ - is_pass = check_subnet_flops( + is_pass, results = check_subnet_resources( model=self.model, subnet=random_subnet, estimator=self.estimator, - flops_range=self.flops_range) + constraints_range=self.constraints_range) - return is_pass + return is_pass, results diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index 1127aab21..273561568 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import math import os import random from abc import abstractmethod -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch from mmengine import fileio @@ -13,10 +14,10 @@ from torch.utils.data import DataLoader from mmrazor.models.task_modules import ResourceEstimator -from mmrazor.registry import LOOPS +from mmrazor.registry import LOOPS, TASK_UTILS from mmrazor.structures import Candidates from mmrazor.utils import SupportRandomSubnet -from .utils import check_subnet_flops +from .utils import check_subnet_resources class BaseSamplerTrainLoop(IterBasedTrainLoop): @@ -77,18 +78,15 @@ def run_iter(self, data_batch: Sequence[dict]) -> None: @LOOPS.register_module() class GreedySamplerTrainLoop(BaseSamplerTrainLoop): """IterBasedTrainLoop for greedy sampler. - In GreedySamplerTrainLoop, `Greedy` means that only use some top sampled candidates to train the supernet. So GreedySamplerTrainLoop mainly picks the top candidates based on their val socres, then use them to train the supernet one by one. - Steps: 1. Sample from the supernet and the candidates. 2. Validate these sampled candidates to get each candidate's score. 3. Get top-k candidates based on their scores, then use them to train the supernet one by one. - Args: runner (Runner): A reference of runner. dataloader (Dataloader or dict): A dataloader object or a dict to @@ -102,10 +100,10 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): val_interval (int): Validation interval. Defaults to 1000. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. - flops_range (dict): Constraints to be used for screening candidates. - resource_estimator_cfg (dict): The config for building estimator, which - is be used to estimate the flops of sampled subnet. Defaults to - None, which means default config is used. + constraints_range (Dict[str, Any]): Constraints to be used for + screening candidates. Defaults to dict(flops=(0, 330)). + resource_estimator_cfg (dict, Optional): Used for building a + resource estimator. Defaults to None. num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. num_samples (int): The number of sample in each sampling subnet. @@ -139,8 +137,8 @@ def __init__(self, val_begin: int = 1, val_interval: int = 1000, score_key: str = 'accuracy/top1', - flops_range: Optional[Tuple[float, float]] = (0., 330), - resource_estimator_cfg: Optional[dict] = None, + constraints_range: Dict[str, Any] = dict(flops=(0, 330)), + resource_estimator_cfg: Optional[Dict] = None, num_candidates: int = 1000, num_samples: int = 10, top_k: int = 5, @@ -163,7 +161,7 @@ def __init__(self, self.evaluator = evaluator self.score_key = score_key - self.flops_range = flops_range + self.constraints_range = constraints_range self.num_candidates = num_candidates self.num_samples = num_samples self.top_k = top_k @@ -177,10 +175,52 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() - if resource_estimator_cfg is None: - self.estimator = ResourceEstimator() + + # Build resource estimator. + resource_estimator_cfg = dict( + ) if resource_estimator_cfg is None else resource_estimator_cfg + self.estimator = self.build_resource_estimator(resource_estimator_cfg) + + def build_resource_estimator( + self, resource_estimator: Union[ResourceEstimator, + Dict]) -> ResourceEstimator: + """Build resource estimator for search loop. + + Examples of ``resource_estimator``: + + # `ResourceEstimator` will be used + resource_estimator = dict() + + # custom resource_estimator + resource_estimator = dict(type='mmrazor.ResourceEstimator') + + Args: + resource_estimator (ResourceEstimator or dict): + A resource_estimator or a dict to build resource estimator. + If ``resource_estimator`` is a resource estimator object, + just returns itself. + + Returns: + :obj:`ResourceEstimator`: Resource estimator object build from + ``resource_estimator``. + """ + if isinstance(resource_estimator, ResourceEstimator): + return resource_estimator + elif not isinstance(resource_estimator, dict): + raise TypeError( + 'resource estimator should be a ResourceEstimator object or' + f'dict, but got {resource_estimator}') + + resource_estimator_cfg = copy.deepcopy( + resource_estimator) # type: ignore + + if 'type' in resource_estimator_cfg: + estimator = TASK_UTILS.build(resource_estimator_cfg) else: - self.estimator = ResourceEstimator(**resource_estimator_cfg) + estimator = ResourceEstimator( + **resource_estimator_cfg) # type: ignore + + return estimator # type: ignore def run(self) -> None: """Launch training.""" @@ -230,9 +270,11 @@ def sample_subnet(self) -> SupportRandomSubnet: self.update_candidates_scores() - self.candidates.sort(key=lambda x: x[1], reverse=True) - self.candidates = Candidates(self.candidates[:self.num_candidates]) - self.top_k_candidates = Candidates(self.candidates[:self.top_k]) + self.candidates.sort_by(key_indicator='score', reverse=True) + self.candidates = Candidates( + self.candidates.data[:self.num_candidates]) + self.top_k_candidates = Candidates( + self.candidates.data[:self.top_k]) top1_score = self.top_k_candidates.scores[0] if (self._iter % self.val_interval) < self.top_k: @@ -243,7 +285,7 @@ def sample_subnet(self) -> SupportRandomSubnet: f'{num_sample_from_supernet}/{self.num_samples} ' f'top1_score {top1_score:.3f} ' f'cur_num_candidates: {len(self.candidates)}') - return self.top_k_candidates.pop(0)[0] + return self.top_k_candidates.subnets[0] def update_cur_prob(self, cur_iter: int) -> None: """update current probablity of sampling from the candidates, which is @@ -278,7 +320,8 @@ def get_candidates_with_sample(self, for _ in range(num_samples): if random.random() >= self.cur_prob or len(self.candidates) == 0: subnet = self._sample_from_supernet() - if self._check_constraints(subnet): + is_pass, _ = self._check_constraints(subnet) + if is_pass: sampled_candidates.append(subnet) num_sample_from_supernet += 1 else: @@ -292,7 +335,7 @@ def update_candidates_scores(self) -> None: self.model.set_subnet(candidate) metrics = self._val_candidate() score = metrics[self.score_key] if len(metrics) != 0 else 0. - self.candidates.set_score(i, score) + self.candidates.set_resource(i, score, 'score') @torch.no_grad() def _val_candidate(self) -> Dict: @@ -312,22 +355,22 @@ def _sample_from_supernet(self) -> SupportRandomSubnet: def _sample_from_candidates(self) -> SupportRandomSubnet: """Sample from the candidates.""" assert len(self.candidates) > 0 - subnet = random.choice(self.candidates) + subnet = random.choice(self.candidates.data) return subnet - def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: + def _check_constraints(self, random_subnet: SupportRandomSubnet): """Check whether is beyond constraints. Returns: - bool: The result of checking. + bool, result: The result of checking. """ - is_pass = check_subnet_flops( + is_pass, results = check_subnet_resources( model=self.model, subnet=random_subnet, estimator=self.estimator, - flops_range=self.flops_range) + constraints_range=self.constraints_range) - return is_pass + return is_pass, results def _save_candidates(self) -> None: """Save the candidates to init the next searching.""" diff --git a/mmrazor/engine/runner/utils/__init__.py b/mmrazor/engine/runner/utils/__init__.py index ec2f2cb29..557002e2c 100644 --- a/mmrazor/engine/runner/utils/__init__.py +++ b/mmrazor/engine/runner/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .check import check_subnet_flops +from .check import check_subnet_resources from .genetic import crossover -__all__ = ['crossover', 'check_subnet_flops'] +__all__ = ['crossover', 'check_subnet_resources'] diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index e2fdcfcc6..b1c581708 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -1,8 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import Optional, Tuple - -import torch.nn as nn +from typing import Any, Dict, Tuple from mmrazor.models import ResourceEstimator from mmrazor.structures import export_fix_subnet, load_fix_subnet @@ -15,18 +13,19 @@ BaseDetector = get_placeholder('mmdet') -def check_subnet_flops( - model: nn.Module, - subnet: SupportRandomSubnet, - estimator: ResourceEstimator, - flops_range: Optional[Tuple[float, float]] = None) -> bool: - """Check whether is beyond flops constraints. +def check_subnet_resources( + model, + subnet: SupportRandomSubnet, + estimator: ResourceEstimator, + constraints_range: Dict[str, Any] = dict(flops=(0, 330)) +) -> Tuple[bool, Dict]: + """Check whether is beyond resources constraints. Returns: - bool: The result of checking. + bool, result: The result of checking. """ - if flops_range is None: - return True + if constraints_range is None: + return True, dict() assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture') model.set_subnet(subnet) @@ -40,9 +39,9 @@ def check_subnet_flops( else: results = estimator.estimate(model=model_to_check) - flops = results['flops'] - flops_mix, flops_max = flops_range - if flops_mix <= flops <= flops_max: # type: ignore - return True - else: - return False + for k, v in constraints_range.items(): + if not isinstance(v, (list, tuple)): + v = (0, v) + if results[k] < v[0] or results[k] > v[1]: + return False, results + return True, results diff --git a/mmrazor/structures/subnet/candidate.py b/mmrazor/structures/subnet/candidate.py index f65f0b48b..ca0c8236e 100644 --- a/mmrazor/structures/subnet/candidate.py +++ b/mmrazor/structures/subnet/candidate.py @@ -1,35 +1,44 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import UserList -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union class Candidates(UserList): - """The data structure of sampled candidate. The format is [(any, float), - (any, float), ...]. - + """The data structure of sampled candidate. The format is Union[Dict[str, + Dict], List[Dict[str, Dict]]]. Examples: >>> candidates = Candidates() - >>> subnet_1 = {'choice_1': 'layer_1', 'choice_2': 'layer_2'} + >>> subnet_1 = {'1': 'choice1', '2': 'choice2'} >>> candidates.append(subnet_1) >>> candidates - [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.0)] - >>> candidates.set_score(0, 0.9) + [{"{'1': 'choice1', '2': 'choice2'}": + {'score': 0.0, 'flops': 0.0, 'params': 0.0, 'latency': 0.0}}] + >>> candidates.set_resources(0, 49.9, 'flops') + >>> candidates.set_score(0, 100.) >>> candidates - [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9)] + [{"{'1': 'choice1', '2': 'choice2'}": + {'score': 100.0, 'flops': 49.9, 'params': 0.0, 'latency': 0.0}}] >>> subnet_2 = {'choice_3': 'layer_3', 'choice_4': 'layer_4'} - >>> candidates.append((subnet_2, 0.5)) + >>> candidates.append(subnet_2) >>> candidates - [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9), - ({'choice_3': 'layer_3', 'choice_4': 'layer_4'}, 0.5)] + [{"{'1': 'choice1', '2': 'choice2'}": + {'score': 100.0, 'flops': 49.9, 'params': 0.0, 'latency': 0.0}}, + {"{'choice_3': 'layer_3', 'choice_4':'layer_4'}": + {'score': 0.0, 'flops': 0.0, 'params': 0.0, 'latency': 0.0}}] >>> candidates.subnets - [{'choice_1': 'layer_1', 'choice_2': 'layer_2'}, + [{'1': 'choice1', '2': 'choice2'}, {'choice_3': 'layer_3', 'choice_4': 'layer_4'}] + >>> candidates.resources('flops') + [49.9, 0.0] >>> candidates.scores - [0.9, 0.5] + [100.0, 0.0] """ - _format_return = Union[Tuple[Any, float], List[Tuple[Any, float]]] + _format_return = Union[Dict[str, Dict], List[Dict[str, Dict]]] + _format_input = Union[Dict, List[Dict], Dict[str, Dict], List[Dict[str, + Dict]]] + _indicators = ('score', 'flops', 'params', 'latency') - def __init__(self, initdata: Optional[Any] = None): + def __init__(self, initdata: Optional[_format_input] = None): self.data = [] if initdata is not None: initdata = self._format(initdata) @@ -41,23 +50,59 @@ def __init__(self, initdata: Optional[Any] = None): @property def scores(self) -> List[float]: """The scores of candidates.""" - return [item[1] for item in self.data] + return [ + value.get('score', 0.) for item in self.data + for _, value in item.items() + ] + + def resources(self, key_indicator: str = 'flops') -> List[float]: + """The resources of candidates.""" + assert key_indicator in ['flops', 'params', 'latency'] + return [ + value.get(key_indicator, 0.) for item in self.data + for _, value in item.items() + ] @property def subnets(self) -> List[Dict]: """The subnets of candidates.""" - return [item[0] for item in self.data] + return [eval(key) for item in self.data for key, _ in item.items()] - def _format(self, data: Any) -> _format_return: - """Transform [any, ...] to [tuple(any, float), ...] Transform any to - tuple(any, float).""" + def _format(self, data: _format_input) -> _format_return: + """Transform [Dict, ...] to Union[Dict[str, Dict], List[Dict[str, + Dict]]]. - def _format_item(item: Any): - """Transform any to tuple(any, float).""" - if isinstance(item, tuple): - return (item[0], float(item[1])) + Args: + data: Four types of input are supported: + 1. Dict: only include network information. + 2. List[Dict]: multiple candidates only include network + information. + 3. Dict[str, Dict]: network information and the corresponding + resources. + 4. List[Dict[str, Dict]]: multiple candidate information. + Returns: + Union[Dict[str, Dict], UserList[Dict[str, Dict]]]: + A dict or a list of dict that contains a pair of network + information and the corresponding Score | FLOPs | Params | + Latency results in each candidate. + Notes: + Score | FLOPs | Params | Latency: + 1. a candidate resources with a default value of -1 indicates + that it has not been estimated. + 2. a candidate resources with a default value of 0 indicates + that some indicators have been evaluated. + """ + + def _format_item( + cond: Union[Dict, Dict[str, Dict]]) -> Dict[str, Dict]: + """Transform Dict to Dict[str, Dict].""" + if isinstance(list(cond.values())[0], dict): + for value in list(cond.values()): + for key in list(self._indicators): + value.setdefault(key, 0.) + return cond else: - return (item, 0.) + return {str(cond): {}.fromkeys(self._indicators, -1)} if isinstance(data, UserList): return [_format_item(i) for i in data.data] @@ -68,12 +113,15 @@ def _format_item(item: Any): else: return _format_item(data) - def append(self, item: Any) -> None: + def append(self, item: _format_input) -> None: """Append operation.""" item = self._format(item) - self.data.append(item) + if isinstance(item, list): + self.data = self.data + item + else: + self.data.append(item) - def insert(self, i: int, item: Any) -> None: + def insert(self, i: int, item: _format_input) -> None: """Insert operation.""" item = self._format(item) self.data.insert(i, item) @@ -88,4 +136,35 @@ def extend(self, other: Any) -> None: def set_score(self, i: int, score: float) -> None: """Set score to the specified subnet by index.""" - self.data[i] = (self.data[i][0], float(score)) + self.set_resource(i, score, 'score') + + def set_resource(self, + i: int, + resources: float, + key_indicator: str = 'flops') -> None: + """Set resources to the specified subnet by index.""" + assert key_indicator in ['score', 'flops', 'params', 'latency'] + for _, value in self.data[i].items(): + value[key_indicator] = resources + + def update_resources(self, resources: list, start: int = 0) -> None: + """Update resources to the specified candidate.""" + end = start + len(resources) + assert len( + self.data) >= end, 'Check the number of candidate resources.' + for i, item in enumerate(self.data[start:end]): + for _, value in item.items(): + value.update(resources[i]) + + def sort_by(self, + key_indicator: str = 'score', + reverse: bool = True) -> None: + """Sort by a specific indicator in descending order. + + Args: + key_indicator (str): sort all candidates by key_indicator. + Defaults to 'score'. + reverse (bool): sort all candidates in descending order. + """ + self.data.sort( + key=lambda x: list(x.values())[0][key_indicator], reverse=reverse) diff --git a/tests/test_models/test_subnet/test_candidate.py b/tests/test_models/test_subnet/test_candidate.py index 4cf44846d..7f8bfe640 100644 --- a/tests/test_models/test_subnet/test_candidate.py +++ b/tests/test_models/test_subnet/test_candidate.py @@ -10,7 +10,27 @@ class TestCandidates(TestCase): def setUp(self) -> None: self.fake_subnet = {'1': 'choice1', '2': 'choice2'} - self.fake_subnet_with_score = (self.fake_subnet, 1.) + self.fake_subnet_with_resource = { + str(self.fake_subnet): { + 'score': 0., + 'flops': 50., + 'params': 0., + 'latency': 0. + } + } + self.fake_subnet_with_score = { + str(self.fake_subnet): { + 'score': 99., + 'flops': 0., + 'params': 0., + 'latency': 0. + } + } + self.has_flops_network = { + str(self.fake_subnet): { + 'flops': 50., + } + } def test_init(self): # initlist is None @@ -23,16 +43,25 @@ def test_init(self): # initlist is UserList data = UserList([self.fake_subnet] * 2) self.assertEqual(len(candidates.data), 2) + self.assertEqual(candidates.resources('flops'), [-1, -1]) + # initlist is list(Dict[str, Dict]) + candidates = Candidates([self.has_flops_network] * 2) + self.assertEqual(candidates.resources('flops'), [50., 50.]) def test_scores(self): # test property: scores data = [self.fake_subnet_with_score] * 2 candidates = Candidates(data) - self.assertEqual(candidates.scores, [1., 1.]) + self.assertEqual(candidates.scores, [99., 99.]) + + def test_resources(self): + data = [self.fake_subnet_with_resource] * 2 + candidates = Candidates(data) + self.assertEqual(candidates.resources('flops'), [50., 50.]) def test_subnets(self): # test property: subnets - data = [self.fake_subnet_with_score] * 2 + data = [self.fake_subnet] * 2 candidates = Candidates(data) self.assertEqual(candidates.subnets, [self.fake_subnet] * 2) @@ -41,17 +70,20 @@ def test_append(self): candidates = Candidates() candidates.append(self.fake_subnet) self.assertEqual(len(candidates), 1) - # item is tuple + # item is List candidates = Candidates() - candidates.append(self.fake_subnet_with_score) - self.assertEqual(len(candidates), 1) + candidates.append([self.fake_subnet_with_score]) + # item is Candidates + candidates_2 = Candidates([self.fake_subnet_with_resource]) + candidates.append(candidates_2) + self.assertEqual(len(candidates), 2) def test_insert(self): # item is dict - candidates = Candidates([self.fake_subnet_with_score]) + candidates = Candidates(self.fake_subnet_with_score) candidates.insert(1, self.fake_subnet) self.assertEqual(len(candidates), 2) - # item is tuple + # item is List candidates = Candidates([self.fake_subnet_with_score]) candidates.insert(1, self.fake_subnet_with_score) self.assertEqual(len(candidates), 2) @@ -61,13 +93,60 @@ def test_extend(self): candidates = Candidates([self.fake_subnet_with_score]) candidates.extend([self.fake_subnet]) self.assertEqual(len(candidates), 2) - # other is UserList + # other is Candidates candidates = Candidates([self.fake_subnet_with_score]) - candidates.extend(UserList([self.fake_subnet_with_score])) + candidates_2 = Candidates([self.fake_subnet_with_resource]) + candidates.extend(candidates_2) + self.assertEqual(len(candidates), 2) + + def test_set_resource(self): + # test set_resource + candidates = Candidates([self.fake_subnet]) + for kk in ['flops', 'params', 'latency']: + self.assertEqual(candidates.resources(kk)[0], -1) + candidates.set_resource(0, 49.9, kk) + self.assertEqual(candidates.resources(kk)[0], 49.9) + candidates.insert(0, self.fake_subnet_with_resource) self.assertEqual(len(candidates), 2) + self.assertEqual(candidates.resources('flops'), [50., 49.9]) + self.assertEqual(candidates.resources('latency'), [0., 49.9]) + candidates = Candidates([self.fake_subnet_with_score]) + candidates.set_resource(0, 100.0, 'score') + self.assertEqual(candidates.scores[0], 100.) + candidates = Candidates([self.fake_subnet_with_score]) + candidates.set_resource(0, 100.0, 'score') + candidates.extend(UserList([self.fake_subnet_with_resource])) + candidates.set_resource(1, 99.9, 'score') + self.assertEqual(candidates.scores, [100., 99.9]) + + def test_update_resources(self): + # test update_resources + candidates = Candidates([self.fake_subnet]) + candidates.append([self.fake_subnet_with_score]) + candidates_2 = Candidates(self.fake_subnet_with_resource) + candidates.append(candidates_2) + self.assertEqual(len(candidates), 3) + self.assertEqual(candidates.resources('flops'), [-1, 0., 50.]) + self.assertEqual(candidates.resources('latency'), [-1, 0., 0.]) + resources = [{'flops': -2}, {'latency': 4.}] + candidates.update_resources(resources, start=1) + self.assertEqual(candidates.resources('flops'), [-1, -2, 50.]) + self.assertEqual(candidates.resources('latency'), [-1, 0., 4]) + candidates.update_resources(resources, start=0) + self.assertEqual(candidates.resources('flops'), [-2, -2, 50.]) + self.assertEqual(candidates.resources('latency'), [-1, 4., 4.]) - def test_set_score(self): - # test set_score + def test_sort(self): + # test set_sort candidates = Candidates([self.fake_subnet_with_score]) - candidates.set_score(0, 0.5) - self.assertEqual(candidates[0][1], 0.5) + candidates.extend(UserList([self.fake_subnet_with_resource])) + candidates.insert(0, self.fake_subnet) + candidates.set_resource(0, 100., 'score') + candidates.set_resource(2, 98., 'score') + self.assertEqual(candidates.scores, [100., 99., 98.]) + candidates.sort_by(key_indicator='score', reverse=False) + self.assertEqual(candidates.scores, [98., 99., 100.]) + candidates.sort_by(key_indicator='latency') + self.assertEqual(candidates.scores, [98., 99., 100.]) + candidates.sort_by(key_indicator='flops', reverse=False) + self.assertEqual(candidates.scores, [100., 99., 98.]) diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py index f30019274..6d8814a7b 100644 --- a/tests/test_runners/test_evolution_search_loop.py +++ b/tests/test_runners/test_evolution_search_loop.py @@ -82,7 +82,7 @@ def setUp(self): num_mutation=2, num_crossover=2, mutate_prob=0.1, - flops_range=None, + constraints_range=dict(flops=(0, 330)), score_key='coco/bbox_mAP') self.train_cfg = Config(train_cfg) self.runner = MagicMock(spec=ToyRunner) @@ -103,7 +103,7 @@ def test_init(self): # test init_candidates is not None fake_subnet = {'1': 'choice1', '2': 'choice2'} - fake_candidates = Candidates((fake_subnet, 0.)) + fake_candidates = Candidates(fake_subnet) init_candidates_path = os.path.join(self.temp_dir, 'candidates.yaml') fileio.dump(fake_candidates, init_candidates_path) loop_cfg.init_candidates = init_candidates_path @@ -111,8 +111,12 @@ def test_init(self): self.assertIsInstance(loop, EvolutionSearchLoop) self.assertEqual(loop.candidates, fake_candidates) - @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet') - def test_run_epoch(self, mock_export_fix_subnet): + @patch('mmrazor.engine.runner.utils.check.load_fix_subnet') + @patch('mmrazor.engine.runner.utils.check.export_fix_subnet') + @patch('mmrazor.models.task_modules.estimators.resource_estimator.' + 'get_model_flops_params') + def test_run_epoch(self, flops_params, mock_export_fix_subnet, + load_status): # test_run_epoch: distributed == False loop_cfg = copy.deepcopy(self.train_cfg) loop_cfg.runner = self.runner @@ -120,20 +124,20 @@ def test_run_epoch(self, mock_export_fix_subnet): loop_cfg.evaluator = self.evaluator loop = LOOPS.build(loop_cfg) self.runner.rank = 0 - loop._epoch = 1 self.runner.distributed = False self.runner.work_dir = self.temp_dir fake_subnet = {'1': 'choice1', '2': 'choice2'} - self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet) + loop.model.sample_subnet = MagicMock(return_value=fake_subnet) + load_status.return_value = True + flops_params.return_value = 0, 0 loop.run_epoch() self.assertEqual(len(loop.candidates), 4) self.assertEqual(len(loop.top_k_candidates), 2) - self.assertEqual(loop._epoch, 2) + self.assertEqual(loop._epoch, 1) # test_run_epoch: distributed == True loop = LOOPS.build(loop_cfg) self.runner.rank = 0 - loop._epoch = 1 self.runner.distributed = True self.runner.work_dir = self.temp_dir fake_subnet = {'1': 'choice1', '2': 'choice2'} @@ -141,26 +145,27 @@ def test_run_epoch(self, mock_export_fix_subnet): loop.run_epoch() self.assertEqual(len(loop.candidates), 4) self.assertEqual(len(loop.top_k_candidates), 2) - self.assertEqual(loop._epoch, 2) + self.assertEqual(loop._epoch, 1) # test_check_constraints - loop_cfg.flops_range = (0, 100) + loop_cfg.constraints_range = dict(params=(0, 100)) loop = LOOPS.build(loop_cfg) self.runner.rank = 0 - loop._epoch = 1 self.runner.distributed = True self.runner.work_dir = self.temp_dir fake_subnet = {'1': 'choice1', '2': 'choice2'} loop.model.sample_subnet = MagicMock(return_value=fake_subnet) - loop._check_constraints = MagicMock(return_value=True) + flops_params.return_value = (50., 1) mock_export_fix_subnet.return_value = fake_subnet loop.run_epoch() self.assertEqual(len(loop.candidates), 4) self.assertEqual(len(loop.top_k_candidates), 2) - self.assertEqual(loop._epoch, 2) + self.assertEqual(loop._epoch, 1) - @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet') - def test_run(self, mock_export_fix_subnet): + @patch('mmrazor.engine.runner.utils.check.export_fix_subnet') + @patch('mmrazor.models.task_modules.estimators.resource_estimator.' + 'get_model_flops_params') + def test_run_loop(self, mock_flops, mock_export_fix_subnet): # test a new search: resume == None loop_cfg = copy.deepcopy(self.train_cfg) loop_cfg.runner = self.runner @@ -169,16 +174,26 @@ def test_run(self, mock_export_fix_subnet): loop = LOOPS.build(loop_cfg) self.runner.rank = 0 loop._epoch = 1 + fake_subnet = {'1': 'choice1', '2': 'choice2'} self.runner.work_dir = self.temp_dir loop.update_candidate_pool = MagicMock() loop.val_candidate_pool = MagicMock() + + mutation_candidates = Candidates([fake_subnet] * loop.num_mutation) + for i in range(loop.num_mutation): + mutation_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops') + mutation_candidates.set_resource(i, 99 + i, 'score') + crossover_candidates = Candidates([fake_subnet] * loop.num_crossover) + for i in range(loop.num_crossover): + crossover_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops') + crossover_candidates.set_resource(i, 99 + i, 'score') loop.gen_mutation_candidates = \ - MagicMock(return_value=[fake_subnet]*loop.num_mutation) + MagicMock(return_value=mutation_candidates) loop.gen_crossover_candidates = \ - MagicMock(return_value=[fake_subnet]*loop.num_crossover) - loop.top_k_candidates = Candidates([(fake_subnet, 1.0), - (fake_subnet, 0.9)]) + MagicMock(return_value=crossover_candidates) + loop.candidates = Candidates([fake_subnet] * 4) + mock_flops.return_value = (0.5, 101) mock_export_fix_subnet.return_value = fake_subnet loop.run() assert os.path.exists( diff --git a/tests/test_runners/test_subnet_sampler_loop.py b/tests/test_runners/test_subnet_sampler_loop.py index fca29b823..1c9422fc1 100644 --- a/tests/test_runners/test_subnet_sampler_loop.py +++ b/tests/test_runners/test_subnet_sampler_loop.py @@ -119,7 +119,7 @@ def setUp(self): max_iters=12, val_interval=2, score_key='acc', - flops_range=None, + constraints_range=None, num_candidates=4, num_samples=2, top_k=2, @@ -190,7 +190,7 @@ def test_sample_subnet(self): loop._iter = loop.val_interval subnet = loop.sample_subnet() self.assertEqual(subnet, fake_subnet) - self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1) + self.assertEqual(len(loop.top_k_candidates), loop.top_k) def test_run(self): # test run with _check_constraints @@ -200,7 +200,7 @@ def test_run(self): fake_subnet = {'1': 'choice1', '2': 'choice2'} runner.model.sample_subnet = MagicMock(return_value=fake_subnet) loop = runner.build_train_loop(cfg.train_cfg) - loop._check_constraints = MagicMock(return_value=True) + loop._check_constraints = MagicMock(return_value=(True, dict())) runner.train() self.assertEqual(runner.iter, runner.max_iters) diff --git a/tests/test_runners/test_utils/test_check.py b/tests/test_runners/test_utils/test_check.py index b9bd57989..2f3a80eaa 100644 --- a/tests/test_runners/test_utils/test_check.py +++ b/tests/test_runners/test_utils/test_check.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest.mock import patch -from mmrazor.engine.runner.utils import check_subnet_flops +from mmrazor.engine.runner.utils import check_subnet_resources try: from mmdet.models.detectors import BaseDetector @@ -12,29 +12,33 @@ @patch('mmrazor.models.ResourceEstimator') @patch('mmrazor.models.SPOS') -def test_check_subnet_flops(mock_model, mock_estimator): - # flops_range = None - flops_range = None +def test_check_subnet_resources(mock_model, mock_estimator): + # constraints_range = dict() + constraints_range = dict() fake_subnet = {'1': 'choice1', '2': 'choice2'} - result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, - flops_range) - assert result is True + is_pass, _ = check_subnet_resources(mock_model, fake_subnet, + mock_estimator, constraints_range) + assert is_pass is True - # flops_range is not None + # constraints_range is not None # architecturte is BaseDetector - flops_range = (0., 100.) + constraints_range = dict(flops=(0, 330)) mock_model.architecture = BaseDetector fake_results = {'flops': 50.} mock_estimator.estimate.return_value = fake_results - result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, - flops_range) - assert result is True + is_pass, _ = check_subnet_resources( + mock_model, + fake_subnet, + mock_estimator, + constraints_range, + ) + assert is_pass is True - # flops_range is not None + # constraints_range is not None # architecturte is BaseDetector - flops_range = (0., 100.) + constraints_range = dict(flops=(0, 330)) fake_results = {'flops': -50.} mock_estimator.estimate.return_value = fake_results - result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, - flops_range) - assert result is False + is_pass, _ = check_subnet_resources(mock_model, fake_subnet, + mock_estimator, constraints_range) + assert is_pass is False