Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement] Update Candidate with multi-dim search constraints. #322

Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
num_mutation=25,
num_crossover=25,
mutate_prob=0.1,
flops_range=(0., 465.),
constraints_range=dict(flops=(0., 465.)),
score_key='accuracy/top1')
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
num_mutation=25,
num_crossover=25,
mutate_prob=0.1,
flops_range=(0., 330.),
constraints_range=dict(flops=(0, 330)),
score_key='accuracy/top1')
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
num_mutation=20,
num_crossover=20,
mutate_prob=0.1,
flops_range=(0., 300.),
constraints_range=dict(flops=(0, 330)),
score_key='coco/bbox_mAP')
152 changes: 117 additions & 35 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import random
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from mmengine import fileio
Expand All @@ -14,10 +15,10 @@
from torch.utils.data import DataLoader

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .utils import check_subnet_flops, crossover
from .utils import check_subnet_resources, crossover


@LOOPS.register_module()
Expand All @@ -41,10 +42,11 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
num_crossover (int): The number of candidates got by crossover.
Defaults to 25.
mutate_prob (float): The probability of mutation. Defaults to 0.1.
aptsunny marked this conversation as resolved.
Show resolved Hide resolved
flops_range (tuple, optional): It is used for screening candidates.
resource_estimator_cfg (dict): The config for building estimator, which
is be used to estimate the flops of sampled subnet. Defaults to
None, which means default config is used.
crossover_prob (float): The probability of crossover. Defaults to 0.5.
constraints_range (Dict[str, Any]): Constraints to be used for
aptsunny marked this conversation as resolved.
Show resolved Hide resolved
screening candidates. Defaults to dict(flops=(0, 330)).
resource_estimator_cfg (dict, Optional): Used for building a
resource estimator. Defaults to None.
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
init_candidates (str, optional): The candidates file path, which is
Expand All @@ -64,8 +66,9 @@ def __init__(self,
num_mutation: int = 25,
num_crossover: int = 25,
mutate_prob: float = 0.1,
flops_range: Optional[Tuple[float, float]] = (0., 330.),
resource_estimator_cfg: Optional[dict] = None,
crossover_prob: float = 0.5,
constraints_range: Dict[str, Any] = dict(flops=(0., 330.)),
resource_estimator_cfg: Optional[Dict] = None,
score_key: str = 'accuracy/top1',
init_candidates: Optional[str] = None) -> None:
super().__init__(runner, dataloader, max_epochs)
Expand All @@ -83,11 +86,12 @@ def __init__(self,

self.num_candidates = num_candidates
self.top_k = top_k
self.flops_range = flops_range
self.constraints_range = constraints_range
self.score_key = score_key
self.num_mutation = num_mutation
self.num_crossover = num_crossover
self.mutate_prob = mutate_prob
self.crossover_prob = crossover_prob
self.max_keep_ckpts = max_keep_ckpts
self.resume_from = resume_from

Expand All @@ -99,16 +103,58 @@ def __init__(self,
correct init candidates file'

self.top_k_candidates = Candidates()
if resource_estimator_cfg is None:
self.estimator = ResourceEstimator()
else:
self.estimator = ResourceEstimator(**resource_estimator_cfg)

if self.runner.distributed:
self.model = runner.model.module
else:
self.model = runner.model

# Build resource estimator.
resource_estimator_cfg = dict(
) if resource_estimator_cfg is None else resource_estimator_cfg
self.estimator = self.build_resource_estimator(resource_estimator_cfg)

def build_resource_estimator(
self, resource_estimator: Union[ResourceEstimator,
Dict]) -> ResourceEstimator:
"""Build resource estimator for search loop.

Examples of ``resource_estimator``:

# `ResourceEstimator` will be used
resource_estimator = dict()

# custom resource_estimator
resource_estimator = dict(type='mmrazor.ResourceEstimator')

Args:
resource_estimator (ResourceEstimator or dict): A
resource_estimator or a dict to build resource estimator.
If ``resource_estimator`` is a resource estimator object,
just returns itself.

Returns:
:obj:`ResourceEstimator`: Resource estimator object build from
``resource_estimator``.
"""
if isinstance(resource_estimator, ResourceEstimator):
return resource_estimator
elif not isinstance(resource_estimator, dict):
raise TypeError(
'resource estimator should be a ResourceEstimator object or'
f'dict, but got {resource_estimator}')

resource_estimator_cfg = copy.deepcopy(
resource_estimator) # type: ignore

if 'type' in resource_estimator_cfg:
estimator = TASK_UTILS.build(resource_estimator_cfg)
else:
estimator = ResourceEstimator(
**resource_estimator_cfg) # type: ignore

return estimator # type: ignore

def run(self) -> None:
"""Launch searching."""
self.runner.call_hook('before_train')
Expand Down Expand Up @@ -144,33 +190,49 @@ def run_epoch(self) -> None:
f'{scores_before}')

self.candidates.extend(self.top_k_candidates)
self.candidates.sort(key=lambda x: x[1], reverse=True)
self.top_k_candidates = Candidates(self.candidates[:self.top_k])
self.candidates.sort_by(key_indicator='score', reverse=True)
self.top_k_candidates = Candidates(self.candidates.data[:self.top_k])

scores_after = self.top_k_candidates.scores
self.runner.logger.info(f'top k scores after update: '
f'{scores_after}')

mutation_candidates = self.gen_mutation_candidates()
self.candidates_mutator_crossover = Candidates(mutation_candidates)
crossover_candidates = self.gen_crossover_candidates()
candidates = mutation_candidates + crossover_candidates
assert len(candidates) <= self.num_candidates, 'Total of mutation and \
crossover should be no more than the number of candidates.'
self.candidates_mutator_crossover.extend(crossover_candidates)

self.candidates = Candidates(candidates)
assert len(self.candidates_mutator_crossover
) <= self.num_candidates, 'Total of mutation and \
crossover should be less than the number of candidates.'

self.candidates = self.candidates_mutator_crossover
self._epoch += 1

def sample_candidates(self) -> None:
"""Update candidate pool contains specified number of candicates."""
candidates_resources = []
init_candidates = len(self.candidates)
if self.runner.rank == 0:
while len(self.candidates) < self.num_candidates:
candidate = self.model.sample_subnet()
if self._check_constraints(random_subnet=candidate):
is_pass, result = self._check_constraints(
random_subnet=candidate)
if is_pass:
self.candidates.append(candidate)
candidates_resources.append(result)
self.candidates = Candidates(self.candidates.data)
else:
self.candidates = Candidates([None] * self.num_candidates)
self.candidates = Candidates([dict()] * self.num_candidates)

if len(candidates_resources) > 0:
self.candidates.update_resources(
candidates_resources,
start=len(self.candidates.data) - len(candidates_resources))
# broadcast candidates to val with multi-GPUs.
broadcast_object_list(self.candidates.data)
assert init_candidates + len(
candidates_resources) == self.num_candidates

def update_candidates_scores(self) -> None:
"""Validate candicate one by one from the candicate pool, and update
Expand All @@ -180,14 +242,18 @@ def update_candidates_scores(self) -> None:
metrics = self._val_candidate()
score = metrics[self.score_key] \
if len(metrics) != 0 else 0.
self.candidates.set_score(i, score)
self.candidates.set_resource(i, score, 'score')
self.runner.logger.info(
f'Epoch:[{self._epoch}/{self._max_epochs}] '
f'Candidate:[{i + 1}/{self.num_candidates}] '
f'Score:{score}')
f'Flops: {self.candidates.resources("flops")[i]} '
f'Params: {self.candidates.resources("params")[i]} '
f'Latency: {self.candidates.resources("latency")[i]} '
f'Score: {self.candidates.scores} ')

def gen_mutation_candidates(self) -> List:
def gen_mutation_candidates(self):
"""Generate specified number of mutation candicates."""
mutation_resources = []
mutation_candidates: List = []
max_mutate_iters = self.num_mutation * 10
mutate_iter = 0
Expand All @@ -198,12 +264,20 @@ def gen_mutation_candidates(self) -> List:

mutation_candidate = self._mutation()

if self._check_constraints(random_subnet=mutation_candidate):
is_pass, result = self._check_constraints(
random_subnet=mutation_candidate)
if is_pass:
mutation_candidates.append(mutation_candidate)
mutation_resources.append(result)

mutation_candidates = Candidates(mutation_candidates)
mutation_candidates.update_resources(mutation_resources)

return mutation_candidates

def gen_crossover_candidates(self) -> List:
def gen_crossover_candidates(self):
"""Generate specofied number of crossover candicates."""
crossover_resources = []
crossover_candidates: List = []
crossover_iter = 0
max_crossover_iters = self.num_crossover * 10
Expand All @@ -214,8 +288,15 @@ def gen_crossover_candidates(self) -> List:

crossover_candidate = self._crossover()

if self._check_constraints(random_subnet=crossover_candidate):
is_pass, result = self._check_constraints(
random_subnet=crossover_candidate)
if is_pass:
crossover_candidates.append(crossover_candidate)
crossover_resources.append(result)

crossover_candidates = Candidates(crossover_candidates)
crossover_candidates.update_resources(crossover_resources)

return crossover_candidates

def _mutation(self) -> SupportRandomSubnet:
Expand All @@ -229,7 +310,7 @@ def _crossover(self) -> SupportRandomSubnet:
"""Crossover."""
candidate1 = random.choice(self.top_k_candidates.subnets)
candidate2 = random.choice(self.top_k_candidates.subnets)
candidate = crossover(candidate1, candidate2)
candidate = crossover(candidate1, candidate2, prob=self.crossover_prob)
return candidate

def _resume(self):
Expand Down Expand Up @@ -263,7 +344,7 @@ def _val_candidate(self) -> Dict:
self.runner.model.eval()
for data_batch in self.dataloader:
outputs = self.runner.model.val_step(data_batch)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.evaluator.process(outputs, data_batch)
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
return metrics

Expand Down Expand Up @@ -295,16 +376,17 @@ def _save_searcher_ckpt(self) -> None:
if osp.isfile(ckpt_path):
os.remove(ckpt_path)

def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
def _check_constraints(
self, random_subnet: SupportRandomSubnet) -> Tuple[bool, Dict]:
"""Check whether is beyond constraints.

Returns:
bool: The result of checking.
bool, result: The result of checking.
"""
is_pass = check_subnet_flops(
is_pass, results = check_subnet_resources(
model=self.model,
subnet=random_subnet,
estimator=self.estimator,
flops_range=self.flops_range)
constraints_range=self.constraints_range)

return is_pass
return is_pass, results
Loading