-
Notifications
You must be signed in to change notification settings - Fork 221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Improvement] Update NasMutator to build search_space in NAS #426
Changes from 4 commits
b8aa9fc
e20da2f
a8d6e17
675cb37
255c3df
120faca
fc8e820
02de63d
7866460
e7ae4e9
f7902b6
477cf38
bba830e
671aedc
6b1966b
645c70e
bf567f9
8fdcf09
9a8b0d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
_scope_='mmrazor', | ||
type='sub_model', | ||
cfg=_base_.architecture, | ||
fix_subnet='configs/pruning/mmdet/dcff/fix_subnet.json', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert |
||
mode='mutator', | ||
init_cfg=dict( | ||
type='Pretrained', | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
_scope_='mmrazor', | ||
type='sub_model', | ||
cfg=_base_.architecture, | ||
fix_subnet='configs/pruning/mmpose/dcff/fix_subnet.json', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert |
||
mode='mutator', | ||
init_cfg=dict( | ||
type='Pretrained', | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
_scope_='mmrazor', | ||
type='sub_model', | ||
cfg=_base_.architecture, | ||
fix_subnet='configs/pruning/mmseg/dcff/fix_subnet.json', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert |
||
mode='mutator', | ||
init_cfg=dict( | ||
type='Pretrained', | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,7 +104,7 @@ def export_subnet(self, model) -> torch.nn.Module: | |
""" | ||
# Avoid circular import | ||
from mmrazor.models.mutables.base_mutable import BaseMutable | ||
from mmrazor.structures import load_fix_subnet | ||
from mmrazor.structures import export_fix_subnet, load_fix_subnet | ||
|
||
# delete non-leaf tensor to get deepcopy(model). | ||
# TODO solve the hard case. | ||
|
@@ -114,7 +114,9 @@ def export_subnet(self, model) -> torch.nn.Module: | |
delattr(module, 'arch_weights') | ||
|
||
copied_model = copy.deepcopy(model) | ||
fix_mutable = copied_model.search_subnet() | ||
copied_model.set_subnet(copied_model.sample_subnet()) | ||
|
||
fix_mutable = export_fix_subnet(copied_model)[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unify |
||
load_fix_subnet(copied_model, fix_mutable) | ||
|
||
return copied_model |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,8 @@ | |
from torch.utils.data import DataLoader | ||
|
||
from mmrazor.registry import LOOPS, TASK_UTILS | ||
from mmrazor.structures import Candidates, export_fix_subnet | ||
from mmrazor.structures import (Candidates, convert_fix_subnet, | ||
export_fix_subnet) | ||
from mmrazor.utils import SupportRandomSubnet | ||
from .utils import CalibrateBNMixin, check_subnet_resources, crossover | ||
|
||
|
@@ -129,8 +130,7 @@ def __init__(self, | |
self.predictor_cfg = predictor_cfg | ||
if self.predictor_cfg is not None: | ||
self.predictor_cfg['score_key'] = self.score_key | ||
self.predictor_cfg['search_groups'] = \ | ||
self.model.mutator.search_groups | ||
self.predictor_cfg['search_groups'] = self.model.search_space | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert |
||
self.predictor = TASK_UTILS.build(self.predictor_cfg) | ||
|
||
def run(self) -> None: | ||
|
@@ -327,25 +327,14 @@ def _save_best_fix_subnet(self): | |
f'{self.runner.work_dir}') | ||
|
||
save_name = 'best_fix_subnet.yaml' | ||
best_fix_subnet = self._convert_fix_subnet(best_fix_subnet) | ||
best_fix_subnet = convert_fix_subnet(best_fix_subnet) | ||
fileio.dump(best_fix_subnet, | ||
osp.join(self.runner.work_dir, save_name)) | ||
self.runner.logger.info( | ||
f'Subnet config {save_name} saved in {self.runner.work_dir}.') | ||
|
||
self.runner.logger.info('Search finished.') | ||
|
||
def _convert_fix_subnet(self, fix_subnet: Dict[str, Any]): | ||
"""Convert the fixed subnet to avoid python typing error.""" | ||
from mmrazor.utils.typing import DumpChosen | ||
|
||
converted_fix_subnet = dict() | ||
for k, v in fix_subnet.items(): | ||
assert isinstance(v, DumpChosen) | ||
converted_fix_subnet[k] = dict(chosen=v.chosen) | ||
|
||
return converted_fix_subnet | ||
|
||
@torch.no_grad() | ||
def _val_candidate(self, use_predictor: bool = False) -> Dict: | ||
"""Run validation. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,7 @@ def run(self): | |
|
||
def _save_fix_subnet(self): | ||
"""Save model subnet config.""" | ||
# TO DO: Modify export_fix_subnet's output. Might contain weight return | ||
# TODO: Modify export_fix_subnet's output. Might contain weight return | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. del this line (already done.) |
||
fix_subnet, static_model = export_fix_subnet( | ||
self.model, export_subnet_mode='mutator', slice_weight=True) | ||
fix_subnet = json.dumps(fix_subnet, indent=4, separators=(',', ':')) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert