[Improvement] Update NasMutator to build search_space in NAS #426

Merged · 19 commits · Feb 1, 2023
30 changes: 10 additions & 20 deletions configs/_base_/settings/cifar10_darts_supernet.py
@@ -48,36 +48,26 @@

# optimizer
optim_wrapper = dict(
constructor='mmrazor.SeparateOptimWrapperConstructor',
architecture=dict(
-        type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
-    mutator=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3),
-    clip_grad=dict(max_norm=5, norm_type=2))
+        optimizer=dict(
+            type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
+        clip_grad=dict(max_norm=5, norm_type=2)),
+    search_params=dict(
+        optimizer=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3)))

+search_epochs = 50
# learning policy
# TODO support different optim use different scheduler (wait mmengine)
param_scheduler = [
dict(
type='mmcls.CosineAnnealingLR',
-        T_max=50,
+        T_max=search_epochs,
eta_min=1e-3,
begin=0,
-        end=50),
+        end=search_epochs),
]
-# param_scheduler = dict(
-#     architecture = dict(
-#         type='mmcls.CosineAnnealingLR',
-#         T_max=50,
-#         eta_min=1e-3,
-#         begin=0,
-#         end=50),
-#     mutator = dict(
-#         type='mmcls.ConstantLR',
-#         factor=1,
-#         begin=0,
-#         end=50))

# train, val, test setting
# TODO split cifar dataset
train_cfg = dict(
type='mmrazor.DartsEpochBasedTrainLoop',
mutator_dataloader=dict(
@@ -92,7 +82,7 @@
sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
persistent_workers=True,
),
-    max_epochs=50)
+    max_epochs=search_epochs)

val_cfg = dict() # validate each epoch
test_cfg = dict() # dataset settings
4 changes: 2 additions & 2 deletions configs/_base_/settings/imagenet_bs1024_dsnas.py
@@ -63,7 +63,7 @@
optimizer=dict(
type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
paramwise_cfg=paramwise_cfg),
-    mutator=dict(
+    search_params=dict(
         optimizer=dict(
             type='mmcls.Adam', lr=0.001, weight_decay=0.0, betas=(0.5, 0.999))))
@@ -94,7 +94,7 @@
by_epoch=True,
convert_to_iter_based=True)
],
-    mutator=[])
+    search_params=[])

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
18 changes: 4 additions & 14 deletions configs/nas/mmcls/darts/darts_supernet_unroll_1xb96_cifar10.py
@@ -4,9 +4,11 @@
'mmcls::_base_/default_runtime.py',
]

-# model
-mutator = dict(type='mmrazor.DiffModuleMutator')
+custom_hooks = [
+    dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True)
+]
+
+# model
model = dict(
type='mmrazor.Darts',
architecture=dict(
@@ -28,16 +30,4 @@
broadcast_buffers=False,
find_unused_parameters=False)

-# TRAINING
-optim_wrapper = dict(
-    _delete_=True,
-    constructor='mmrazor.SeparateOptimWrapperConstructor',
-    architecture=dict(
-        type='OptimWrapper',
-        optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
-        clip_grad=dict(max_norm=5, norm_type=2)),
-    mutator=dict(
-        type='OptimWrapper',
-        optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))

find_unused_parameter = False
@@ -6,7 +6,6 @@
type='sub_model',
cfg=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
-    fix_subnet='configs/pruning/mmcls/dcff/fix_subnet.json',

Reviewer comment (Contributor): revert

mode='mutator',
init_cfg=dict(
type='Pretrained',
1 change: 0 additions & 1 deletion configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py
@@ -76,7 +76,6 @@
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='BackwardTracer')),
-    fix_subnet=None,
data_preprocessor=None,
target_pruning_ratio=target_pruning_ratio,
step_freq=1,
@@ -5,7 +5,6 @@
_scope_='mmrazor',
type='sub_model',
cfg=_base_.architecture,
-    fix_subnet='configs/pruning/mmdet/dcff/fix_subnet.json',

Reviewer comment (Contributor): revert

mode='mutator',
init_cfg=dict(
type='Pretrained',
@@ -5,7 +5,6 @@
_scope_='mmrazor',
type='sub_model',
cfg=_base_.architecture,
-    fix_subnet='configs/pruning/mmpose/dcff/fix_subnet.json',

Reviewer comment (Contributor): revert

mode='mutator',
init_cfg=dict(
type='Pretrained',
@@ -5,7 +5,6 @@
_scope_='mmrazor',
type='sub_model',
cfg=_base_.architecture,
-    fix_subnet='configs/pruning/mmseg/dcff/fix_subnet.json',

Reviewer comment (Contributor): revert

mode='mutator',
init_cfg=dict(
type='Pretrained',
23 changes: 18 additions & 5 deletions mmrazor/engine/hooks/dump_subnet_hook.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import copy
import os.path as osp
from pathlib import Path
from typing import Optional, Sequence, Union
@@ -8,6 +9,9 @@
from mmengine.hooks import Hook
from mmengine.registry import HOOKS

+from mmrazor.models.mutables.base_mutable import BaseMutable
+from mmrazor.structures import convert_fix_subnet, export_fix_subnet

DATA_BATCH = Optional[Sequence[dict]]


@@ -103,16 +107,25 @@ def after_train_epoch(self, runner) -> None:

@master_only
def _save_subnet(self, runner) -> None:
"""Save the current subnet and delete outdated subnet.
"""Save the current best subnet.

Args:
runner (Runner): The runner of the training process.
"""
+        model = runner.model.module if runner.distributed else runner.model

-        if runner.distributed:
-            subnet_dict = runner.model.module.search_subnet()
-        else:
-            subnet_dict = runner.model.search_subnet()
+        # delete non-leaf tensor to get deepcopy(model).
+        # TODO solve the hard case.
+        for module in model.architecture.modules():
+            if isinstance(module, BaseMutable):
+                if hasattr(module, 'arch_weights'):
+                    delattr(module, 'arch_weights')
+
+        copied_model = copy.deepcopy(model)
+        copied_model.set_subnet(copied_model.sample_subnet())
+
+        subnet_dict = export_fix_subnet(copied_model)[0]
+        subnet_dict = convert_fix_subnet(subnet_dict)

if self.by_epoch:
subnet_filename = self.args.get(
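Taken together, the hook no longer calls a model-specific search_subnet(); it samples and fixes a subnet on a deep copy, then exports and converts it. Condensed into a standalone helper, the new flow looks like this (a sketch over the mmrazor API shown in this diff; algo stands for any supernet algorithm exposing sample_subnet/set_subnet):

import copy

from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.structures import convert_fix_subnet, export_fix_subnet

def dump_current_subnet(algo) -> dict:
    # Arch weights are non-leaf tensors; drop them so deepcopy succeeds.
    for module in algo.architecture.modules():
        if isinstance(module, BaseMutable) and hasattr(module, 'arch_weights'):
            delattr(module, 'arch_weights')

    copied = copy.deepcopy(algo)
    copied.set_subnet(copied.sample_subnet())   # fix every mutable choice

    subnet_dict = export_fix_subnet(copied)[0]  # returns (subnet, static model)
    return convert_fix_subnet(subnet_dict)      # DumpChosen -> plain dicts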
6 changes: 4 additions & 2 deletions mmrazor/engine/hooks/estimate_resources_hook.py
@@ -104,7 +104,7 @@ def export_subnet(self, model) -> torch.nn.Module:
"""
# Avoid circular import
from mmrazor.models.mutables.base_mutable import BaseMutable
-        from mmrazor.structures import load_fix_subnet
+        from mmrazor.structures import export_fix_subnet, load_fix_subnet

# delete non-leaf tensor to get deepcopy(model).
# TODO solve the hard case.
@@ -114,7 +114,9 @@
delattr(module, 'arch_weights')

        copied_model = copy.deepcopy(model)
-        fix_mutable = copied_model.search_subnet()
+        copied_model.set_subnet(copied_model.sample_subnet())
+
+        fix_mutable = export_fix_subnet(copied_model)[0]

Reviewer comment (Contributor): Unify fix_mutable to subnet_dict.


load_fix_subnet(copied_model, fix_mutable)

return copied_model
20 changes: 4 additions & 16 deletions mmrazor/engine/runner/autoslim_greedy_search_loop.py
@@ -11,7 +11,7 @@
from torch.utils.data import DataLoader

from mmrazor.registry import LOOPS, TASK_UTILS
-from mmrazor.structures import export_fix_subnet
+from mmrazor.structures import convert_fix_subnet, export_fix_subnet
from .utils import check_subnet_resources


@@ -106,8 +106,7 @@ def run(self) -> None:
continue

while self.current_flops > target:
-                best_score, best_subnet = None, None
-
+                best_score, best_subnet = 0., None
for unit_name in sorted(self.current_subnet.keys()):
if self.current_subnet[unit_name] == 1:
# The number of channel_bin has reached the minimum
@@ -124,7 +123,7 @@
self.runner.logger.info(
f'Slimming unit {unit_name}, {self.score_key}: {score}'
)
-                    if best_score is None or score > best_score:
+                    if score >= best_score:
best_score = score
best_subnet = pruned_subnet
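For context, the surrounding loop is a greedy slimming step: while the supernet exceeds the FLOPs target, it tentatively removes one channel bin from every prunable unit, scores each candidate, and keeps the best one; initializing best_score to 0. with >= guarantees a candidate is always selected. A compact sketch of the pattern (evaluate and flops_of are placeholder callables for this sketch, not mmrazor APIs):

def greedy_slim(subnet, flops, target, evaluate, flops_of):
    """Sketch of the greedy search step in AutoSlimGreedySearchLoop.run()."""
    while flops > target:
        best_score, best_subnet = 0., None
        for unit_name in sorted(subnet):
            if subnet[unit_name] == 1:
                continue  # unit already at its minimum number of channel bins
            candidate = dict(subnet)
            candidate[unit_name] -= 1  # prune one channel bin from this unit
            score = evaluate(candidate)
            if score >= best_score:   # ties favor the later unit, as in the PR
                best_score, best_subnet = score, candidate
        if best_subnet is None:
            break  # nothing left to prune
        subnet, flops = best_subnet, flops_of(best_subnet)
    return subnet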

@@ -195,17 +194,6 @@ def _save_searcher_ckpt(self) -> None:

def _save_searched_subnet(self):
"""Save the final searched subnet dict."""

-        def _convert_fix_subnet(fixed_subnet: Dict[str, Any]):
-            from mmrazor.utils.typing import DumpChosen
-
-            converted_fix_subnet = dict()
-            for key, val in fixed_subnet.items():
-                assert isinstance(val, DumpChosen)
-                converted_fix_subnet[key] = dict(val._asdict())
-
-            return converted_fix_subnet

if self.runner.rank != 0:
return
self.runner.logger.info('Search finished:')
@@ -215,7 +203,7 @@ def _convert_fix_subnet(fixed_subnet: Dict[str, Any]):
self.model.set_subnet(subnet_choice)
fixed_subnet, _ = export_fix_subnet(self.model)
save_name = 'FLOPS_{:.2f}M.yaml'.format(flops)
-            fixed_subnet = _convert_fix_subnet(fixed_subnet)
+            fixed_subnet = convert_fix_subnet(fixed_subnet)
fileio.dump(fixed_subnet, osp.join(self.runner.work_dir,
save_name))
self.runner.logger.info(
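Both ad-hoc converters, the local _convert_fix_subnet removed here and the method removed from EvolutionSearchLoop below, are consolidated into the shared mmrazor.structures.convert_fix_subnet. Based on the two removed copies, a sketch of what the shared helper needs to do (the implementation merged in the PR may differ in detail):

from typing import Any, Dict

from mmrazor.utils.typing import DumpChosen

def convert_fix_subnet(fixed_subnet: Dict[str, Any]) -> Dict[str, dict]:
    """Turn DumpChosen namedtuples into plain dicts so fileio.dump can
    serialize the subnet to YAML without custom representers."""
    converted = dict()
    for key, val in fixed_subnet.items():
        assert isinstance(val, DumpChosen)
        converted[key] = dict(val._asdict())
    return converted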
19 changes: 4 additions & 15 deletions mmrazor/engine/runner/evolution_search_loop.py
@@ -16,7 +16,8 @@
from torch.utils.data import DataLoader

from mmrazor.registry import LOOPS, TASK_UTILS
-from mmrazor.structures import Candidates, export_fix_subnet
+from mmrazor.structures import (Candidates, convert_fix_subnet,
+                                export_fix_subnet)
from mmrazor.utils import SupportRandomSubnet
from .utils import CalibrateBNMixin, check_subnet_resources, crossover

@@ -129,8 +130,7 @@ def __init__(self,
self.predictor_cfg = predictor_cfg
if self.predictor_cfg is not None:
self.predictor_cfg['score_key'] = self.score_key
-            self.predictor_cfg['search_groups'] = \
-                self.model.mutator.search_groups
+            self.predictor_cfg['search_groups'] = self.model.search_space

Reviewer comment (Contributor): revert

self.predictor = TASK_UTILS.build(self.predictor_cfg)

def run(self) -> None:
@@ -327,25 +327,14 @@ def _save_best_fix_subnet(self):
f'{self.runner.work_dir}')

save_name = 'best_fix_subnet.yaml'
-        best_fix_subnet = self._convert_fix_subnet(best_fix_subnet)
+        best_fix_subnet = convert_fix_subnet(best_fix_subnet)
fileio.dump(best_fix_subnet,
osp.join(self.runner.work_dir, save_name))
self.runner.logger.info(
f'Subnet config {save_name} saved in {self.runner.work_dir}.')

self.runner.logger.info('Search finished.')

-    def _convert_fix_subnet(self, fix_subnet: Dict[str, Any]):
-        """Convert the fixed subnet to avoid python typing error."""
-        from mmrazor.utils.typing import DumpChosen
-
-        converted_fix_subnet = dict()
-        for k, v in fix_subnet.items():
-            assert isinstance(v, DumpChosen)
-            converted_fix_subnet[k] = dict(chosen=v.chosen)
-
-        return converted_fix_subnet

@torch.no_grad()
def _val_candidate(self, use_predictor: bool = False) -> Dict:
"""Run validation.
2 changes: 1 addition & 1 deletion mmrazor/engine/runner/iteprune_val_loop.py
@@ -39,7 +39,7 @@ def run(self):

def _save_fix_subnet(self):
"""Save model subnet config."""
-        # TO DO: Modify export_fix_subnet's output. Might contain weight return
+        # TODO: Modify export_fix_subnet's output. Might contain weight return

Reviewer comment (Contributor): del this line (already done.)


fix_subnet, static_model = export_fix_subnet(
self.model, export_subnet_mode='mutator', slice_weight=True)
fix_subnet = json.dumps(fix_subnet, indent=4, separators=(',', ':'))
29 changes: 3 additions & 26 deletions mmrazor/models/algorithms/nas/autoformer.py
@@ -9,10 +9,11 @@
from mmrazor.registry import MODELS
from mmrazor.utils import ValidFixMutable
from ..base import BaseAlgorithm, LossResults
+from ..space_mixin import SpaceMixin


@MODELS.register_module()
-class Autoformer(BaseAlgorithm):
+class Autoformer(BaseAlgorithm, SpaceMixin):
"""Implementation of `Autoformer <https://arxiv.org/abs/2107.00651>`_

AutoFormer is dedicated to vision transformer search. AutoFormer
@@ -75,33 +76,9 @@ def __init__(self,
else:
raise TypeError('mutator should be a `dict` but got '
f'{type(mutator)}')

+        self._build_search_space()
self.is_supernet = True

-    def sample_subnet(self) -> Dict:
-        """Random sample subnet by mutator."""
-        value_subnet = dict()
-        channel_subnet = dict()
-        for name, mutator in self.mutators.items():
-            if name == 'value_mutator':
-                value_subnet.update(mutator.sample_choices())
-            elif name == 'channel_mutator':
-                channel_subnet.update(mutator.sample_choices())
-            else:
-                raise NotImplementedError
-        return dict(value_subnet=value_subnet, channel_subnet=channel_subnet)
-
-    def set_subnet(self, subnet: Dict[str, Dict[int, Union[int,
-                                                           list]]]) -> None:
-        """Set the subnet sampled by :meth:sample_subnet."""
-        for name, mutator in self.mutators.items():
-            if name == 'value_mutator':
-                mutator.set_choices(subnet['value_subnet'])
-            elif name == 'channel_mutator':
-                mutator.set_choices(subnet['channel_subnet'])
-            else:
-                raise NotImplementedError

def loss(
self,
batch_inputs: torch.Tensor,
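The removed Autoformer-specific sample_subnet/set_subnet pair is exactly what SpaceMixin now generalizes: _build_search_space() gathers every mutator's searchable units into self.search_space, and the mixin's shared sample/set methods operate on that dict. A rough sketch of the idea, inferred from the removed methods and the search_space usage in EvolutionSearchLoop above (names and structure here are assumptions, not mmrazor's exact implementation):

from typing import Dict

class SpaceMixin:
    """Sketch: shared subnet sampling/setting over a unified search space."""

    def _build_search_space(self) -> None:
        # Assumption: merge the search groups of every attached mutator
        # (e.g. value_mutator and channel_mutator in Autoformer).
        self.search_space = dict()
        mutators = getattr(self, 'mutators', None) or {'mutator': self.mutator}
        for name, mutator in mutators.items():
            for group_id, mutables in mutator.search_groups.items():
                self.search_space[f'{name}.{group_id}'] = mutables

    def sample_subnet(self) -> Dict:
        """Randomly pick one choice per search group."""
        return {
            name: mutables[0].sample_choice()
            for name, mutables in self.search_space.items()
        }

    def set_subnet(self, subnet: Dict) -> None:
        """Fix every mutable in each group to the chosen value."""
        for name, choice in subnet.items():
            for mutable in self.search_space[name]:
                mutable.current_choice = choice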