[Improvement] Update NasMutator to build search_space in NAS #426

Merged: 19 commits, Feb 1, 2023
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -5,8 +5,8 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
-  - repo: https://github.com/timothycrosley/isort
-    rev: 5.10.1
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.11.5
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
30 changes: 10 additions & 20 deletions configs/_base_/settings/cifar10_darts_supernet.py
@@ -48,36 +48,26 @@
 
 # optimizer
 optim_wrapper = dict(
     constructor='mmrazor.SeparateOptimWrapperConstructor',
     architecture=dict(
-        type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
-    mutator=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3),
-    clip_grad=dict(max_norm=5, norm_type=2))
+        optimizer=dict(
+            type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
+        clip_grad=dict(max_norm=5, norm_type=2)),
+    mutator=dict(
+        optimizer=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3)))
 
+search_epochs = 50
 # learning policy
 # TODO: support a different scheduler per optimizer (waiting on mmengine)
 param_scheduler = [
     dict(
         type='mmcls.CosineAnnealingLR',
-        T_max=50,
+        T_max=search_epochs,
         eta_min=1e-3,
         begin=0,
-        end=50),
+        end=search_epochs),
 ]
-# param_scheduler = dict(
-#     architecture = dict(
-#         type='mmcls.CosineAnnealingLR',
-#         T_max=50,
-#         eta_min=1e-3,
-#         begin=0,
-#         end=50),
-#     mutator = dict(
-#         type='mmcls.ConstantLR',
-#         factor=1,
-#         begin=0,
-#         end=50))
 
 # train, val, test setting
 # TODO: split cifar dataset
 train_cfg = dict(
     type='mmrazor.DartsEpochBasedTrainLoop',
     mutator_dataloader=dict(
@@ -92,7 +82,7 @@
         sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
         persistent_workers=True,
     ),
-    max_epochs=50)
+    max_epochs=search_epochs)
 
 val_cfg = dict()  # validate each epoch
 test_cfg = dict()  # dataset settings
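With `SeparateOptimWrapperConstructor`, each nested `optimizer=` entry above becomes its own optim wrapper keyed by submodule name (presumably mmengine's `OptimWrapperDict`; treat the exact return type as an assumption). A minimal sketch of how a DARTS-style train step would then address the two optimizers, with `mutator_loss` and `supernet_loss` as stand-ins:

# `optim_wrapper` is assumed to behave like mmengine's OptimWrapperDict,
# as built by SeparateOptimWrapperConstructor from the config above.
arch_wrapper = optim_wrapper['architecture']  # SGD, with gradient clipping
mutator_wrapper = optim_wrapper['mutator']    # Adam, for architecture params

# Bi-level update: step the architecture parameters, then the weights.
mutator_wrapper.update_params(mutator_loss)   # `mutator_loss` is a stand-in
arch_wrapper.update_params(supernet_loss)     # `supernet_loss` is a stand-in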
12 changes: 1 addition & 11 deletions configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py
@@ -53,17 +53,7 @@
     type='mmrazor.Autoformer',
     architecture=supernet,
     fix_subnet=None,
-    mutators=dict(
-        channel_mutator=dict(
-            type='mmrazor.OneShotChannelMutator',
-            channel_unit_cfg={
-                'type': 'OneShotMutableChannelUnit',
-                'default_args': {
-                    'unit_predefined': True
-                }
-            },
-            parse_cfg={'type': 'Predefined'}),
-        value_mutator=dict(type='mmrazor.DynamicValueMutator')))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 # runtime setting
 custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
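A single `NasMutator` now subsumes the channel and value mutators deleted above: it scans the supernet once and registers every mutable it finds. A minimal sketch of driving it by hand; `prepare_from_supernet`, `sample_choices` and `set_choices` are the calls this PR standardizes on elsewhere in the diff, while `supernet_cfg` is a stand-in for an already-defined dynamic architecture:

from mmrazor.registry import MODELS

# Build the supernet and the unified mutator from configs.
supernet = MODELS.build(supernet_cfg)  # `supernet_cfg` is a stand-in
mutator = MODELS.build(dict(type='mmrazor.NasMutator'))

# One pass over the supernet builds the whole search space.
mutator.prepare_from_supernet(supernet)

# Sample a subnet and apply it to the supernet in place.
mutator.set_choices(mutator.sample_choices())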
@@ -44,17 +44,7 @@
             loss_kl=dict(
                 preds_S=dict(recorder='fc', from_student=True),
                 preds_T=dict(recorder='fc', from_student=False)))),
-    mutators=dict(
-        channel_mutator=dict(
-            type='mmrazor.OneShotChannelMutator',
-            channel_unit_cfg={
-                'type': 'OneShotMutableChannelUnit',
-                'default_args': {
-                    'unit_predefined': True
-                }
-            },
-            parse_cfg={'type': 'Predefined'}),
-        value_mutator=dict(type='mmrazor.DynamicValueMutator')))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 model_wrapper_cfg = dict(
     type='mmrazor.BigNASDDP',
20 changes: 5 additions & 15 deletions configs/nas/mmcls/darts/darts_supernet_unroll_1xb96_cifar10.py
@@ -4,9 +4,11 @@
     'mmcls::_base_/default_runtime.py',
 ]
 
-# model
-mutator = dict(type='mmrazor.DiffModuleMutator')
+custom_hooks = [
+    dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True)
+]
 
+# model
 model = dict(
     type='mmrazor.Darts',
     architecture=dict(
@@ -20,24 +22,12 @@
             loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
             topk=(1, 5),
             cal_acc=True)),
-    mutator=dict(type='mmrazor.DiffModuleMutator'),
+    mutator=dict(type='mmrazor.NasMutator'),
     unroll=True)
 
 model_wrapper_cfg = dict(
     type='mmrazor.DartsDDP',
     broadcast_buffers=False,
     find_unused_parameters=False)
 
-# TRAINING
-optim_wrapper = dict(
-    _delete_=True,
-    constructor='mmrazor.SeparateOptimWrapperConstructor',
-    architecture=dict(
-        type='OptimWrapper',
-        optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
-        clip_grad=dict(max_norm=5, norm_type=2)),
-    mutator=dict(
-        type='OptimWrapper',
-        optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))
-
 find_unused_parameters = False
2 changes: 1 addition & 1 deletion configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py
@@ -23,7 +23,7 @@
                 mode='original',
                 loss_weight=1.0),
             topk=(1, 5))),
-    mutator=dict(type='mmrazor.DiffModuleMutator'),
+    mutator=dict(type='mmrazor.NasMutator'),
     pretrain_epochs=15,
     finetune_epochs=_base_.search_epochs,
 )
@@ -43,17 +43,7 @@
             loss_kl=dict(
                 preds_S=dict(recorder='fc', from_student=True),
                 preds_T=dict(recorder='fc', from_student=False)))),
-    mutators=dict(
-        channel_mutator=dict(
-            type='mmrazor.OneShotChannelMutator',
-            channel_unit_cfg={
-                'type': 'OneShotMutableChannelUnit',
-                'default_args': {
-                    'unit_predefined': True
-                }
-            },
-            parse_cfg={'type': 'Predefined'}),
-        value_mutator=dict(type='DynamicValueMutator')))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 model_wrapper_cfg = dict(
     type='mmrazor.BigNASDDP',
@@ -25,6 +25,6 @@
 model = dict(
     type='mmrazor.SPOS',
     architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
@@ -25,6 +25,6 @@
 model = dict(
     type='mmrazor.SPOS',
     architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
@@ -25,6 +25,6 @@
     _delete_=True,
     type='mmrazor.SPOS',
     architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
@@ -22,6 +22,6 @@
     _delete_=True,
     type='mmrazor.SPOS',
     architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
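All four SPOS variants above take the same one-line change, which is the point of the PR: whether a search space is module-based (SPOS, DARTS, DSNAS) or channel/value-based (AutoFormer, BigNAS), a config now names a single mutator. Condensed from the hunks in this diff:

# Before: one mutator type per algorithm family.
mutator = dict(type='mmrazor.OneShotModuleMutator')  # SPOS
mutator = dict(type='mmrazor.DiffModuleMutator')     # DARTS, DSNAS

# After: the unified mutator everywhere.
mutator = dict(type='mmrazor.NasMutator')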
1 change: 0 additions & 1 deletion configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py
@@ -76,7 +76,6 @@
             type='ChannelAnalyzer',
             demo_input=(1, 3, 224, 224),
             tracer_type='BackwardTracer')),
-    fix_subnet=None,
     data_preprocessor=None,
     target_pruning_ratio=target_pruning_ratio,
     step_freq=1,
23 changes: 18 additions & 5 deletions mmrazor/engine/hooks/dump_subnet_hook.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import os.path as osp
 from pathlib import Path
 from typing import Optional, Sequence, Union
@@ -8,6 +9,9 @@
 from mmengine.hooks import Hook
 from mmengine.registry import HOOKS
 
+from mmrazor.models.mutables.base_mutable import BaseMutable
+from mmrazor.structures import convert_fix_subnet, export_fix_subnet
+
 DATA_BATCH = Optional[Sequence[dict]]
 
 
@@ -103,16 +107,25 @@ def after_train_epoch(self, runner) -> None:
 
     @master_only
     def _save_subnet(self, runner) -> None:
-        """Save the current subnet and delete outdated subnet.
+        """Save the current best subnet.
 
         Args:
             runner (Runner): The runner of the training process.
         """
+        model = runner.model.module if runner.distributed else runner.model
 
-        if runner.distributed:
-            subnet_dict = runner.model.module.search_subnet()
-        else:
-            subnet_dict = runner.model.search_subnet()
+        # Delete non-leaf tensors so the model can be deep-copied.
+        # TODO: solve the hard case.
+        for module in model.architecture.modules():
+            if isinstance(module, BaseMutable):
+                if hasattr(module, 'arch_weights'):
+                    delattr(module, 'arch_weights')
+
+        copied_model = copy.deepcopy(model)
+        copied_model.mutator.set_choices(copied_model.mutator.sample_choices())
+
+        subnet_dict = export_fix_subnet(copied_model)[0]
+        subnet_dict = convert_fix_subnet(subnet_dict)
 
         if self.by_epoch:
             subnet_filename = self.args.get(
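Outside a runner, the hook's new export path boils down to a few calls. A minimal sketch, assuming `algorithm` is an already-built NAS algorithm whose mutator is a `NasMutator` (the output file name is illustrative):

import copy

from mmengine import fileio

from mmrazor.structures import convert_fix_subnet, export_fix_subnet

# Fix a sampled subnet on a deep copy, leaving the supernet untouched.
copied = copy.deepcopy(algorithm)  # `algorithm` is a stand-in
copied.mutator.set_choices(copied.mutator.sample_choices())

# export_fix_subnet returns a tuple; the first element is the subnet dict.
subnet_dict = export_fix_subnet(copied)[0]

# convert_fix_subnet turns DumpChosen records into plain dicts for dumping.
fileio.dump(convert_fix_subnet(subnet_dict), 'subnet.yaml')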
8 changes: 5 additions & 3 deletions mmrazor/engine/hooks/estimate_resources_hook.py
@@ -104,7 +104,7 @@ def export_subnet(self, model) -> torch.nn.Module:
         """
         # Avoid circular import
         from mmrazor.models.mutables.base_mutable import BaseMutable
-        from mmrazor.structures import load_fix_subnet
+        from mmrazor.structures import export_fix_subnet, load_fix_subnet
 
         # Delete non-leaf tensors so the model can be deep-copied.
         # TODO: solve the hard case.
@@ -114,7 +114,9 @@
                     delattr(module, 'arch_weights')
 
         copied_model = copy.deepcopy(model)
-        fix_mutable = copied_model.search_subnet()
-        load_fix_subnet(copied_model, fix_mutable)
+        copied_model.mutator.set_choices(copied_model.mutator.sample_choices())
+
+        subnet_dict = export_fix_subnet(copied_model)[0]
+        load_fix_subnet(copied_model, subnet_dict)
 
         return copied_model
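The pattern mirrors DumpSubnetHook, with one extra step: `load_fix_subnet` mutates the copy in place so the estimator measures a static subnet rather than the full supernet. A sketch of how the exported copy might then be measured, assuming mmrazor's `ResourceEstimator` with default settings (exact result keys may vary):

from mmrazor.registry import TASK_UTILS

# Build the default resource estimator (flops/params from a dummy input).
estimator = TASK_UTILS.build(dict(type='mmrazor.ResourceEstimator'))

# `copied_model` is the static subnet returned by export_subnet() above.
results = estimator.estimate(model=copied_model)
print(results)  # e.g. {'flops': ..., 'params': ...}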
30 changes: 10 additions & 20 deletions mmrazor/engine/runner/autoslim_greedy_search_loop.py
@@ -11,7 +11,7 @@
 from torch.utils.data import DataLoader
 
 from mmrazor.registry import LOOPS, TASK_UTILS
-from mmrazor.structures import export_fix_subnet
+from mmrazor.structures import convert_fix_subnet, export_fix_subnet
 from .utils import check_subnet_resources
 
 
@@ -68,14 +68,15 @@ def __init__(self,
         self.model = runner.model
 
         assert hasattr(self.model, 'mutator')
-        search_groups = self.model.mutator.search_groups
+        units = self.model.mutator.mutable_units
 
         self.candidate_choices = {}
-        for group_id, modules in search_groups.items():
-            self.candidate_choices[group_id] = modules[0].candidate_choices
+        for unit in units:
+            self.candidate_choices[unit.alias] = unit.candidate_choices
 
         self.max_subnet = {}
-        for group_id, candidate_choices in self.candidate_choices.items():
-            self.max_subnet[group_id] = len(candidate_choices)
+        for name, candidate_choices in self.candidate_choices.items():
+            self.max_subnet[name] = len(candidate_choices)
         self.current_subnet = self.max_subnet
 
         current_subnet_choices = self._channel_bins2choices(
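The rewrite above also simplifies the lookup: `search_groups` required reaching into `modules[0]` of each group for its `candidate_choices`, whereas `mutable_units` exposes `alias` and `candidate_choices` directly on each unit. The resulting bookkeeping, with hypothetical unit aliases:

# Hypothetical aliases; real ones come from the supernet's channel units.
candidate_choices = {
    'backbone.conv1_out_channels': [32, 48, 64],
    'backbone.layer1_out_channels': [64, 128],
}
max_subnet = {name: len(c) for name, c in candidate_choices.items()}
# -> {'backbone.conv1_out_channels': 3, 'backbone.layer1_out_channels': 2}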
@@ -117,7 +118,7 @@ def run(self) -> None:
                     pruned_subnet[unit_name] -= 1
                     pruned_subnet_choices = self._channel_bins2choices(
                         pruned_subnet)
-                    self.model.set_subnet(pruned_subnet_choices)
+                    self.model.mutator.set_choices(pruned_subnet_choices)
                     metrics = self._val_subnet()
                     score = metrics[self.score_key] \
                         if len(metrics) != 0 else 0.
@@ -195,27 +196,16 @@ def _save_searcher_ckpt(self) -> None:
 
     def _save_searched_subnet(self):
         """Save the final searched subnet dict."""
-
-        def _convert_fix_subnet(fixed_subnet: Dict[str, Any]):
-            from mmrazor.utils.typing import DumpChosen
-
-            converted_fix_subnet = dict()
-            for key, val in fixed_subnet.items():
-                assert isinstance(val, DumpChosen)
-                converted_fix_subnet[key] = dict(val._asdict())
-
-            return converted_fix_subnet
-
         if self.runner.rank != 0:
             return
         self.runner.logger.info('Search finished:')
         for subnet, flops in zip(self.searched_subnet,
                                  self.searched_subnet_flops):
             subnet_choice = self._channel_bins2choices(subnet)
-            self.model.set_subnet(subnet_choice)
+            self.model.mutator.set_choices(subnet_choice)
             fixed_subnet, _ = export_fix_subnet(self.model)
             save_name = 'FLOPS_{:.2f}M.yaml'.format(flops)
-            fixed_subnet = _convert_fix_subnet(fixed_subnet)
+            fixed_subnet = convert_fix_subnet(fixed_subnet)
             fileio.dump(fixed_subnet, osp.join(self.runner.work_dir,
                                                save_name))
             self.runner.logger.info(
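The deleted local helper doubles as a specification for the shared `convert_fix_subnet` now imported from `mmrazor.structures`: exported values are `DumpChosen` namedtuples, and serialization only needs plain dicts. A sketch consistent with the removed code (the real utility lives in `mmrazor.structures`):

from typing import Any, Dict

from mmrazor.utils.typing import DumpChosen


def convert_fix_subnet(fixed_subnet: Dict[str, Any]) -> Dict[str, dict]:
    """Convert DumpChosen namedtuples to plain dicts for YAML dumping."""
    converted = dict()
    for key, val in fixed_subnet.items():
        assert isinstance(val, DumpChosen)
        converted[key] = dict(val._asdict())
    return converted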