diff --git a/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml b/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml index d9faa6497..c56cfe46e 100644 --- a/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml +++ b/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml @@ -1,56 +1,80 @@ normal_n2: - - normal_n2_p0 - - normal_n2_p1 + chosen: + - normal_n2_p0 + - normal_n2_p1 normal_n2_p0: - - sep_conv_3x3 + chosen: + - sep_conv_3x3 normal_n2_p1: - - sep_conv_3x3 + chosen: + - sep_conv_3x3 normal_n3: - - normal_n3_p0 - - normal_n3_p1 + chosen: + - normal_n3_p0 + - normal_n3_p1 normal_n3_p0: - - skip_connect + chosen: + - skip_connect normal_n3_p1: - - sep_conv_5x5 + chosen: + - sep_conv_5x5 normal_n4: - - normal_n4_p0 - - normal_n4_p1 + chosen: + - normal_n4_p0 + - normal_n4_p1 normal_n4_p0: - - sep_conv_3x3 + chosen: + - sep_conv_3x3 normal_n4_p1: - - skip_connect + chosen: + - skip_connect normal_n5: - - normal_n5_p0 - - normal_n5_p1 + chosen: + - normal_n5_p0 + - normal_n5_p1 normal_n5_p0: - - skip_connect + chosen: + - skip_connect normal_n5_p1: - - skip_connect + chosen: + - skip_connect reduce_n2: - - reduce_n2_p0 - - reduce_n2_p1 + chosen: + - reduce_n2_p0 + - reduce_n2_p1 reduce_n2_p0: - - max_pool_3x3 + chosen: + - max_pool_3x3 reduce_n2_p1: - - sep_conv_3x3 + chosen: + - sep_conv_3x3 reduce_n3: - - reduce_n3_p0 - - reduce_n3_p2 + chosen: + - reduce_n3_p0 + - reduce_n3_p2 reduce_n3_p0: - - max_pool_3x3 + chosen: + - max_pool_3x3 reduce_n3_p2: - - dil_conv_5x5 + chosen: + - dil_conv_5x5 reduce_n4: - - reduce_n4_p0 - - reduce_n4_p2 + chosen: + - reduce_n4_p0 + - reduce_n4_p2 reduce_n4_p0: - - max_pool_3x3 + chosen: + - max_pool_3x3 reduce_n4_p2: - - skip_connect + chosen: + - skip_connect reduce_n5: - - reduce_n5_p0 - - reduce_n5_p2 + chosen: + - reduce_n5_p0 + - reduce_n5_p2 reduce_n5_p0: - - max_pool_3x3 + chosen: + - max_pool_3x3 reduce_n5_p2: - - skip_connect + chosen: + - skip_connect diff --git a/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml index d2fa294d3..0c35c01b5 100644 --- a/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml +++ b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml @@ -1,20 +1,40 @@ -backbone.layers.0.0: shuffle_3x3 -backbone.layers.0.1: shuffle_3x3 -backbone.layers.0.2: shuffle_xception -backbone.layers.0.3: shuffle_3x3 -backbone.layers.1.0: shuffle_xception -backbone.layers.1.1: shuffle_7x7 -backbone.layers.1.2: shuffle_3x3 -backbone.layers.1.3: shuffle_3x3 -backbone.layers.2.0: shuffle_xception -backbone.layers.2.1: shuffle_xception -backbone.layers.2.2: shuffle_7x7 -backbone.layers.2.3: shuffle_xception -backbone.layers.2.4: shuffle_xception -backbone.layers.2.5: shuffle_xception -backbone.layers.2.6: shuffle_7x7 -backbone.layers.2.7: shuffle_3x3 -backbone.layers.3.0: shuffle_3x3 -backbone.layers.3.1: shuffle_xception -backbone.layers.3.2: shuffle_xception -backbone.layers.3.3: shuffle_3x3 +backbone.layers.0.0: + chosen: shuffle_3x3 +backbone.layers.0.1: + chosen: shuffle_7x7 +backbone.layers.0.2: + chosen: shuffle_3x3 +backbone.layers.0.3: + chosen: shuffle_5x5 +backbone.layers.1.0: + chosen: shuffle_3x3 +backbone.layers.1.1: + chosen: shuffle_3x3 +backbone.layers.1.2: + chosen: shuffle_3x3 +backbone.layers.1.3: + chosen: shuffle_7x7 +backbone.layers.2.0: + chosen: shuffle_xception +backbone.layers.2.1: + chosen: shuffle_3x3 +backbone.layers.2.2: + chosen: shuffle_3x3 +backbone.layers.2.3: + chosen: shuffle_5x5 
+backbone.layers.2.4: + chosen: shuffle_3x3 +backbone.layers.2.5: + chosen: shuffle_5x5 +backbone.layers.2.6: + chosen: shuffle_7x7 +backbone.layers.2.7: + chosen: shuffle_7x7 +backbone.layers.3.0: + chosen: shuffle_xception +backbone.layers.3.1: + chosen: shuffle_3x3 +backbone.layers.3.2: + chosen: shuffle_7x7 +backbone.layers.3.3: + chosen: shuffle_3x3 diff --git a/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py index ca30a5946..a96c81f82 100644 --- a/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py +++ b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py @@ -1,28 +1,7 @@ _base_ = ['./dsnas_supernet_8xb128_in1k.py'] # NOTE: Replace this with the mutable_cfg searched by yourself. -fix_subnet = { - 'backbone.layers.0.0': 'shuffle_3x3', - 'backbone.layers.0.1': 'shuffle_7x7', - 'backbone.layers.0.2': 'shuffle_3x3', - 'backbone.layers.0.3': 'shuffle_5x5', - 'backbone.layers.1.0': 'shuffle_3x3', - 'backbone.layers.1.1': 'shuffle_3x3', - 'backbone.layers.1.2': 'shuffle_3x3', - 'backbone.layers.1.3': 'shuffle_7x7', - 'backbone.layers.2.0': 'shuffle_xception', - 'backbone.layers.2.1': 'shuffle_3x3', - 'backbone.layers.2.2': 'shuffle_3x3', - 'backbone.layers.2.3': 'shuffle_5x5', - 'backbone.layers.2.4': 'shuffle_3x3', - 'backbone.layers.2.5': 'shuffle_5x5', - 'backbone.layers.2.6': 'shuffle_7x7', - 'backbone.layers.2.7': 'shuffle_7x7', - 'backbone.layers.3.0': 'shuffle_xception', - 'backbone.layers.3.1': 'shuffle_3x3', - 'backbone.layers.3.2': 'shuffle_7x7', - 'backbone.layers.3.3': 'shuffle_3x3', -} +fix_subnet = 'configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml' model = dict(fix_subnet=fix_subnet) diff --git a/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py index ea821da40..50d11dee2 100644 --- a/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py +++ b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py @@ -6,7 +6,7 @@ # model model = dict( - type='mmrazor.Dsnas', + type='mmrazor.DSNAS', architecture=dict( type='ImageClassifier', data_preprocessor=_base_.data_preprocessor, @@ -29,7 +29,7 @@ ) model_wrapper_cfg = dict( - type='mmrazor.DsnasDDP', + type='mmrazor.DSNASDDP', broadcast_buffers=False, find_unused_parameters=True) diff --git a/configs/nas/mmcls/spos/SPOS_SUBNET.yaml b/configs/nas/mmcls/spos/SPOS_SUBNET.yaml new file mode 100644 index 000000000..ba809da1d --- /dev/null +++ b/configs/nas/mmcls/spos/SPOS_SUBNET.yaml @@ -0,0 +1,40 @@ +backbone.layers.0.0: + chosen: shuffle_7x7 +backbone.layers.0.1: + chosen: shuffle_3x3 +backbone.layers.0.2: + chosen: shuffle_7x7 +backbone.layers.0.3: + chosen: shuffle_3x3 +backbone.layers.1.0: + chosen: shuffle_xception +backbone.layers.1.1: + chosen: shuffle_5x5 +backbone.layers.1.2: + chosen: shuffle_5x5 +backbone.layers.1.3: + chosen: shuffle_3x3 +backbone.layers.2.0: + chosen: shuffle_3x3 +backbone.layers.2.1: + chosen: shuffle_5x5 +backbone.layers.2.2: + chosen: shuffle_3x3 +backbone.layers.2.3: + chosen: shuffle_5x5 +backbone.layers.2.4: + chosen: shuffle_3x3 +backbone.layers.2.5: + chosen: shuffle_xception +backbone.layers.2.6: + chosen: shuffle_5x5 +backbone.layers.2.7: + chosen: shuffle_7x7 +backbone.layers.3.0: + chosen: shuffle_7x7 +backbone.layers.3.1: + chosen: shuffle_3x3 +backbone.layers.3.2: + chosen: shuffle_5x5 +backbone.layers.3.3: + chosen: shuffle_xception diff --git a/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py 
index ff7c3bf8c..1243d16b2 100644 --- a/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py @@ -1,7 +1,7 @@ _base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py'] -# FIXME: you may replace this with the mutable_cfg searched by yourself -fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20220715-aa94d5ef_subnet_cfg_v1.yaml' # noqa: E501 +# FIXME: you may replace this with the subnet searched by yourself +fix_subnet = 'configs/nas/mmcls/spos/SPOS_SUBNET.yaml' model = dict(fix_subnet=fix_subnet) diff --git a/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml b/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml new file mode 100644 index 000000000..c7bcab916 --- /dev/null +++ b/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml @@ -0,0 +1,40 @@ +backbone.layers.0.0: + chosen: shuffle_5x5 +backbone.layers.0.1: + chosen: shuffle_3x3 +backbone.layers.0.2: + chosen: shuffle_3x3 +backbone.layers.0.3: + chosen: shuffle_3x3 +backbone.layers.1.0: + chosen: shuffle_xception +backbone.layers.1.1: + chosen: shuffle_3x3 +backbone.layers.1.2: + chosen: shuffle_xception +backbone.layers.1.3: + chosen: shuffle_7x7 +backbone.layers.2.0: + chosen: shuffle_7x7 +backbone.layers.2.1: + chosen: shuffle_7x7 +backbone.layers.2.2: + chosen: shuffle_xception +backbone.layers.2.3: + chosen: shuffle_xception +backbone.layers.2.4: + chosen: shuffle_3x3 +backbone.layers.2.5: + chosen: shuffle_7x7 +backbone.layers.2.6: + chosen: shuffle_5x5 +backbone.layers.2.7: + chosen: shuffle_xception +backbone.layers.3.0: + chosen: shuffle_7x7 +backbone.layers.3.1: + chosen: shuffle_7x7 +backbone.layers.3.2: + chosen: shuffle_7x7 +backbone.layers.3.3: + chosen: shuffle_5x5 diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py index 43d0f4983..8334c78b8 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py @@ -1,7 +1,7 @@ _base_ = ['./detnas_frcnn_shufflenet_supernet_coco_1x.py'] # FIXME: you may replace this with the subnet searched by yourself -fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_subnet_cfg_v1.yaml' # noqa: E501 +fix_subnet = 'configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml' model = dict(fix_subnet=fix_subnet) diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index e6258b012..a5129acb4 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -3,7 +3,7 @@ from .distill import (DAFLDataFreeDistillation, DataFreeDistillation, FpnTeacherDistill, OverhaulFeatureDistillation, SelfDistill, SingleTeacherDistill) -from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP +from .nas import DSNAS, DSNASDDP, SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP from .pruning import SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm @@ -23,6 +23,6 @@ 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'ItePruneAlgorithm', - 'Dsnas', - 'DsnasDDP', + 'DSNAS', + 'DSNASDDP', ] diff --git a/mmrazor/models/algorithms/nas/__init__.py b/mmrazor/models/algorithms/nas/__init__.py index 17eab7e86..b290afa0a 100644 --- a/mmrazor/models/algorithms/nas/__init__.py +++ 
b/mmrazor/models/algorithms/nas/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .autoslim import AutoSlim, AutoSlimDDP from .darts import Darts, DartsDDP -from .dsnas import Dsnas, DsnasDDP +from .dsnas import DSNAS, DSNASDDP from .spos import SPOS __all__ = [ - 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP' + 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS', 'DSNASDDP' ] diff --git a/mmrazor/models/algorithms/nas/dsnas.py b/mmrazor/models/algorithms/nas/dsnas.py index 62c2c7f04..5434ce0ac 100644 --- a/mmrazor/models/algorithms/nas/dsnas.py +++ b/mmrazor/models/algorithms/nas/dsnas.py @@ -23,7 +23,7 @@ @MODELS.register_module() -class Dsnas(BaseAlgorithm): +class DSNAS(BaseAlgorithm): """Implementation of `DSNAS `_ Args: @@ -272,7 +272,7 @@ def handle_grads(self): @MODEL_WRAPPERS.register_module() -class DsnasDDP(MMDistributedDataParallel): +class DSNASDDP(MMDistributedDataParallel): def __init__(self, *, diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py index cca03a71f..4b592740a 100644 --- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -7,6 +7,7 @@ from mmengine.model import BaseModel from mmengine.structures import BaseDataElement +from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutators import ChannelMutator from mmrazor.registry import MODELS from ..base import BaseAlgorithm @@ -107,7 +108,7 @@ def __init__(self, channel_unit_cfg=dict( type='SequentialMutableChannelUnit')), data_preprocessor: Optional[Union[Dict, nn.Module]] = None, - target_pruning_ratio={}, + target_pruning_ratio: Optional[Dict[str, float]] = None, step_epoch=1, prune_times=1, init_cfg: Optional[Dict] = None) -> None: @@ -118,15 +119,49 @@ def __init__(self, self.mutator: ChannelMutator = MODELS.build(mutator_cfg) self.mutator.prepare_from_supernet(self.architecture) + if target_pruning_ratio is None: + group_target_ratio = self.mutator.current_choices + else: + group_target_ratio = self.group_target_pruning_ratio( + target_pruning_ratio, self.mutator.search_groups) + # config_manager - self.check_prune_targe(target_pruning_ratio) self.prune_config_manager = ItePruneConfigManager( - target_pruning_ratio, - self.mutator.choice_template, + group_target_ratio, + self.mutator.current_choices, step_epoch, times=prune_times) - def check_prune_targe(self, config: Dict): + def group_target_pruning_ratio( + self, target: Dict[str, float], + search_groups: Dict[int, + List[MutableChannelUnit]]) -> Dict[int, float]: + """According to the target pruning ratio of each unit, set the target + ratio of each search group.""" + group_target: Dict[int, float] = dict() + for group_id, units in search_groups.items(): + for unit in units: + unit_name = unit.name + # The config of target pruning ratio does not + # contain all units. + if unit_name not in target: + continue + if group_id in group_target: + unit_target = target[unit_name] + if unit_target != group_target[group_id]: + group_names = [u.name for u in units] + raise ValueError( + f"'{unit_name}' target ratio is different from " + f'other units in the same group {group_names}. 
' + 'Please check your target pruning ratio config.') + else: + unit_target = target[unit_name] + assert isinstance(unit_target, (float, int)) + group_target[group_id] = unit_target + + return group_target + + def check_prune_target(self, config: Dict): """Check if the prune-target is supported.""" for value in config.values(): assert isinstance(value, int) or isinstance(value, float) @@ -141,7 +176,9 @@ def forward(self, self._iteration): config = self.prune_config_manager.prune_at(self._epoch) + self.mutator.set_choices(config) + logger = MMLogger.get_current_instance() logger.info(f'The model is pruned at {self._epoch}th epoch once.') diff --git a/mmrazor/models/mutables/base_mutable.py b/mmrazor/models/mutables/base_mutable.py index b0df98d4f..2b5972d9f 100644 --- a/mmrazor/models/mutables/base_mutable.py +++ b/mmrazor/models/mutables/base_mutable.py @@ -1,14 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod -from typing import Dict, Generic, Optional, TypeVar +from typing import Dict, Optional from mmengine.model import BaseModule -CHOICE_TYPE = TypeVar('CHOICE_TYPE') -CHOSEN_TYPE = TypeVar('CHOSEN_TYPE') +from mmrazor.utils.typing import DumpChosen -class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]): +class BaseMutable(BaseModule, ABC): """Base Class for mutables. Mutable means a searchable module widely used in Neural Architecture Search(NAS). @@ -17,13 +16,12 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]): All subclass should implement the following APIs: - - ``forward()`` - ``fix_chosen()`` - - ``choices()`` + - ``dump_chosen()`` + - ``current_choice.setter()`` + - ``current_choice.getter()`` Args: - module_kwargs (dict[str, dict], optional): Module initialization named - arguments. Defaults to None. alias (str, optional): alias of the `MUTABLE`. init_cfg (dict, optional): initialization configuration dict for ``BaseModule``. OpenMMLab has implement 5 initializer including @@ -38,19 +36,18 @@ def __init__(self, self.alias = alias self._is_fixed = False - self._current_choice: Optional[CHOICE_TYPE] = None - @property - def current_choice(self) -> Optional[CHOICE_TYPE]: + @property # type: ignore + @abstractmethod + def current_choice(self): """Current choice will affect :meth:`forward` and will be used in :func:`mmrazor.core.subnet.utils.export_fix_subnet` or mutator. """ - return self._current_choice - @current_choice.setter - def current_choice(self, choice: Optional[CHOICE_TYPE]) -> None: + @current_choice.setter # type: ignore + @abstractmethod + def current_choice(self, choice) -> None: """Current choice setter will be executed in mutator.""" - self._current_choice = choice @property def is_fixed(self) -> bool: @@ -76,22 +73,22 @@ def is_fixed(self, is_fixed: bool) -> None: self._is_fixed = is_fixed @abstractmethod - def fix_chosen(self, chosen: CHOSEN_TYPE) -> None: - """Fix mutable with choice. This function would fix the choice of - Mutable. The :attr:`is_fixed` will be set to True and only the selected + def fix_chosen(self, chosen) -> None: + """Fix mutable with chosen. This function would fix the chosen of + the mutable. The :attr:`is_fixed` will be set to True and only the selected operations can be retained. All subclasses must implement this method. Note: This operation is irreversible. """ + raise NotImplementedError() - # TODO - # type hint @abstractmethod - def dump_chosen(self) -> CHOSEN_TYPE: - ... 
+ def dump_chosen(self) -> DumpChosen: + """Save the current state of the mutable as a dictionary. - @property - @abstractmethod - def num_choices(self) -> int: - pass + ``DumpChosen`` has ``chosen`` and ``meta`` fields. ``chosen`` is + necessary, and ``fix_chosen`` will use it. ``meta`` is used to + store some non-essential information. + """ + raise NotImplementedError() diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 98f680ee9..ddbf6adeb 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -15,8 +15,9 @@ from mmengine.logging import print_log from torch import Tensor +from mmrazor.utils.typing import DumpChosen from ..utils import make_divisible -from .base_mutable import CHOICE_TYPE, BaseMutable +from .base_mutable import BaseMutable class MutableProtocol(Protocol): # pragma: no cover @@ -172,8 +173,7 @@ def derive_concat_mutable( return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) -class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE], - DerivedMethodMixin): +class DerivedMutable(BaseMutable, DerivedMethodMixin): """Class for derived mutable. A derived mutable is a mutable derived from other mutables that has @@ -242,7 +242,7 @@ def __init__(self, # TODO # has no effect - def fix_chosen(self, chosen: CHOICE_TYPE) -> None: + def fix_chosen(self, chosen) -> None: """Fix mutable with subnet config. Warning: @@ -253,7 +253,7 @@ def fix_chosen(self, chosen: CHOICE_TYPE) -> None: 'which will have no effect.', level=logging.WARNING) - def dump_chosen(self) -> CHOICE_TYPE: + def dump_chosen(self) -> DumpChosen: """Dump information of chosen. Returns: @@ -263,6 +263,9 @@ def dump_chosen(self) -> CHOICE_TYPE: 'Trying to dump chosen for derived mutable, ' 'but its value depend on the source mutables.', level=logging.WARNING) + return DumpChosen(chosen=self.export_chosen(), meta=None) + + def export_chosen(self): return self.current_choice @property @@ -314,12 +317,12 @@ def num_choices(self) -> int: return 1 @property - def current_choice(self) -> CHOICE_TYPE: + def current_choice(self): """Current choice of derived mutable.""" return self.choice_fn() @current_choice.setter - def current_choice(self, choice: CHOICE_TYPE) -> None: + def current_choice(self, choice) -> None: """Setter of current choice. 
Raises: diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index 28f1e4854..65d5a44d6 100644 --- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -4,6 +4,7 @@ import torch +from mmrazor.utils.typing import DumpChosen from ..base_mutable import BaseMutable from ..derived_mutable import DerivedMethodMixin @@ -20,9 +21,9 @@ class BaseMutableChannel(BaseMutable, DerivedMethodMixin): |mutable_out_channel(BaseMutableChannel)| |---------------------------------------| - All subclasses should implement the following APIs: + All subclasses should implement the following APIs and the other + abstract method in ``BaseMutable`` - - ``current_choice`` - ``current_mask`` Args: @@ -34,20 +35,6 @@ def __init__(self, num_channels: int, **kwargs): self.name = '' self.num_channels = num_channels - # choice - - @property # type: ignore - @abstractmethod - def current_choice(self): - """get current choice.""" - raise NotImplementedError() - - @current_choice.setter # type: ignore - @abstractmethod - def current_choice(self): - """set current choice.""" - raise NotImplementedError() - @property # type: ignore @abstractmethod def current_mask(self) -> torch.Tensor: @@ -73,9 +60,15 @@ def fix_chosen(self, chosen=None): self.is_fixed = True - def dump_chosen(self): - """dump current choice to a dict.""" - raise NotImplementedError() + def dump_chosen(self) -> DumpChosen: + """Dump chosen.""" + meta = dict(max_channels=self.mask.size(0)) + chosen = self.export_chosen() + + return DumpChosen(chosen=chosen, meta=meta) + + def export_chosen(self) -> int: + return self.activated_channels def num_choices(self) -> int: """Number of available choices.""" diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index eae559d41..9b891e349 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -69,10 +69,6 @@ def fix_chosen(self, chosen=...): self.current_choice = chosen self.is_fixed = True - def dump_chosen(self): - """Dump chosen.""" - return self.current_choice - def __rmul__(self, other) -> DerivedMutable: return self * other diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py index 576412ec0..e494b4018 100644 --- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py @@ -149,9 +149,10 @@ class ChannelUnit(BaseModule): def __init__(self, num_channels: int, **kwargs): super().__init__() + self.num_channels = num_channels - self.output_related: nn.ModuleList = nn.ModuleList() - self.input_related: nn.ModuleList = nn.ModuleList() + self.output_related: List[nn.Module] = list() + self.input_related: List[nn.Module] = list() self.init_args: Dict = { } # is used to generate new channel unit with same args @@ -208,14 +209,14 @@ def init_from_graph(cls, def init_from_base_channel_unit(base_channel_unit: BaseChannelUnit): unit = cls(len(base_channel_unit.channel_elems), **unit_args) - unit.input_related = nn.ModuleList([ + unit.input_related = [ Channel.init_from_base_channel(channel) for channel in base_channel_unit.input_related - ]) - unit.output_related = nn.ModuleList([ + ] + 
unit.output_related = [ Channel.init_from_base_channel(channel) for channel in base_channel_unit.output_related - ]) + ] return unit unit_graph = ChannelGraph.copy_from(graph, @@ -239,6 +240,11 @@ def name(self) -> str: name = f'{first_module_name}_{self.num_channels}' return name + @property + def alias(self) -> str: + """str: alias of the unit""" + return self.name + def config_template(self, with_init_args=False, with_channels=False) -> Dict: diff --git a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py index 5e44c330d..e524ec67c 100644 --- a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py +++ b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py @@ -9,13 +9,13 @@ from torch import Tensor from mmrazor.registry import MODELS -from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE +from mmrazor.utils.typing import DumpChosen from .mutable_module import MutableModule PartialType = Callable[[Any, Optional[nn.Parameter]], Any] -class DiffMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]): +class DiffMutableModule(MutableModule): """Base class for differentiable mutables. Args: @@ -34,9 +34,12 @@ class DiffMutableModule(MutableModule): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - def forward(self, - x: Any, - arch_param: Optional[nn.Parameter] = None) -> Any: + @abstractmethod + def sample_choice(self, arch_param: Tensor): + """Sample choice according to arch parameters.""" + raise NotImplementedError + + def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None): """Calls either :func:`forward_fixed` or :func:`forward_arch_param` depending on whether :func:`is_fixed` is ``True`` and whether :func:`arch_param` is None. @@ -60,27 +63,17 @@ def forward(self, if self.is_fixed: return self.forward_fixed(x) else: - return self.forward_arch_param(x, arch_param=arch_param) + if arch_param is None: + return self.forward_all(x) + else: + return self.forward_arch_param(x, arch_param=arch_param) def compute_arch_probs(self, arch_param: nn.Parameter) -> Tensor: """compute chosen probs according to architecture params.""" return F.softmax(arch_param, -1) @abstractmethod - def forward_fixed(self, x: Any) -> Any: - """Forward when the mutable is fixed. - - All subclasses must implement this method. - """ - - @abstractmethod - def forward_all(self, x: Any) -> Any: - """Forward all choices.""" - - @abstractmethod - def forward_arch_param(self, - x: Any, - arch_param: Optional[nn.Parameter] = None) -> Any: + def forward_arch_param(self, x, arch_param: nn.Parameter): """Forward when the mutable is not fixed. All subclasses must implement this method. @@ -94,7 +87,7 @@ def set_forward_args(self, arch_param: nn.Parameter) -> None: @MODELS.register_module() -class DiffMutableOP(DiffMutableModule[str, str]): +class DiffMutableOP(DiffMutableModule): """A type of ``MUTABLES`` for differentiable architecture search, such as DARTS. Search the best module by learnable parameters `arch_param`. @@ -159,7 +152,7 @@ def _build_ops(candidates: Dict[str, Dict], ops[name] = MODELS.build(op_cfg) return ops - def forward_fixed(self, x: Any) -> Tensor: + def forward_fixed(self, x) -> Tensor: """Forward when the mutable is in `fixed` mode. 
Args: @@ -171,10 +164,7 @@ def forward_fixed(self, x: Any) -> Tensor: """ return sum(self._candidates[choice](x) for choice in self._chosen) - def forward_arch_param(self, - x: Any, - arch_param: Optional[nn.Parameter] = None - ) -> Tensor: + def forward_arch_param(self, x, arch_param: nn.Parameter) -> Tensor: """Forward with architecture parameters. Args: @@ -187,21 +177,19 @@ def forward_arch_param(self, Returns: Tensor: the result of forward with ``arch_param``. """ - if arch_param is None: - return self.forward_all(x) - else: - # compute the probs of choice - probs = self.compute_arch_probs(arch_param=arch_param) - # forward based on probs - outputs = list() - for prob, module in zip(probs, self._candidates.values()): - if prob > 0.: - outputs.append(prob * module(x)) + # compute the probs of choice + probs = self.compute_arch_probs(arch_param=arch_param) - return sum(outputs) + # forward based on probs + outputs = list() + for prob, module in zip(probs, self._candidates.values()): + if prob > 0.: + outputs.append(prob * module(x)) - def forward_all(self, x: Any) -> Tensor: + return sum(outputs) + + def forward_all(self, x) -> Tensor: """Forward all choices. Used to calculate FLOPs. Args: @@ -240,12 +228,16 @@ def fix_chosen(self, chosen: Union[str, List[str]]) -> None: self._chosen = chosen self.is_fixed = True - def sample_choice(self, arch_param): + def sample_choice(self, arch_param: Tensor) -> str: """Sample choice based on arch_parameters.""" return self.choices[torch.argmax(arch_param).item()] - def dump_chosen(self): - """Dump current choice.""" + def dump_chosen(self) -> DumpChosen: + chosen = self.export_chosen() + meta = dict(all_choices=self.choices) + return DumpChosen(chosen=chosen, meta=meta) + + def export_chosen(self) -> str: assert self.current_choice is not None return self.current_choice @@ -297,10 +289,11 @@ def sample_weights(self, m = D.one_hot_categorical.OneHotCategorical(probs=probs) return m.sample() - def forward_arch_param(self, - x: Any, - arch_param: Optional[nn.Parameter] = None - ) -> Tensor: + def forward_arch_param( + self, + x: Any, + arch_param: nn.Parameter, + ) -> Tensor: """Forward with architecture parameters. Args: @@ -312,39 +305,35 @@ def forward_arch_param(self, Returns: Tensor: the result of forward with ``arch_param``. 
""" - if arch_param is None: - return self.forward_all(x) - else: - # compute the probs of choice - probs = self.compute_arch_probs(arch_param=arch_param) - - if not self.is_fixed: - self.arch_weights = self.sample_weights(arch_param, probs) - sorted_param = torch.topk(probs, 2) - index = ( - sorted_param[0][0] - sorted_param[0][1] >= - self.fix_threshold) - if index: - self.fix_chosen(self.choices[index]) - - if self.is_fixed: - index = self.choices.index(self._chosen[0]) - self.arch_weights.data.zero_() - self.arch_weights.data[index].fill_(1.0) - self.arch_weights.requires_grad_() - - # forward based on self.arch_weights - outputs = list() - for prob, module in zip(self.arch_weights, - self._candidates.values()): - if prob > 0.: - outputs.append(prob * module(x)) - - return sum(outputs) + + # compute the probs of choice + probs = self.compute_arch_probs(arch_param=arch_param) + + if not self.is_fixed: + self.arch_weights = self.sample_weights(arch_param, probs) + sorted_param = torch.topk(probs, 2) + index = ( + sorted_param[0][0] - sorted_param[0][1] >= self.fix_threshold) + if index: + self.fix_chosen(self.choices[index]) + + if self.is_fixed: + index = self.choices.index(self._chosen[0]) + self.arch_weights.data.zero_() + self.arch_weights.data[index].fill_(1.0) + self.arch_weights.requires_grad_() + + # forward based on self.arch_weights + outputs = list() + for prob, module in zip(self.arch_weights, self._candidates.values()): + if prob > 0.: + outputs.append(prob * module(x)) + + return sum(outputs) @MODELS.register_module() -class DiffChoiceRoute(DiffMutableModule[str, List[str]]): +class DiffChoiceRoute(DiffMutableModule): """A type of ``MUTABLES`` for Neural Architecture Search, which can select inputs from different edges in a differentiable or non-differentiable way. It is commonly used in DARTS. @@ -404,6 +393,35 @@ def __init__( self._candidates: nn.ModuleDict = edges self.num_chosen = num_chosen + def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None): + """Calls either :func:`forward_fixed` or :func:`forward_arch_param` + depending on whether :func:`is_fixed` is ``True`` and whether + :func:`arch_param` is None. + + To reduce the coupling between `Mutable` and `Mutator`, the + `arch_param` is generated by the `Mutator` and is passed to the + forward function as an argument. + + Note: + :meth:`forward_fixed` is called when in `fixed` mode. + :meth:`forward_arch_param` is called when in `unfixed` mode. + + Args: + x (Any): input data for forward computation. + arch_param (nn.Parameter, optional): the architecture parameters + for ``DiffMutableModule``. + + Returns: + Any: the result of forward + """ + if self.is_fixed: + return self.forward_fixed(x) + else: + if arch_param is not None and self._with_arch_param: + return self.forward_arch_param(x, arch_param=arch_param) + else: + return self.forward_all(x) + def forward_fixed(self, inputs: Union[List, Tuple]) -> Tensor: """Forward when the mutable is in `fixed` mode. @@ -424,10 +442,7 @@ def forward_fixed(self, inputs: Union[List, Tuple]) -> Tensor: outputs.append(self._candidates[choice](x)) return sum(outputs) - def forward_arch_param( - self, - x: Union[List[Any], Tuple[Any]], - arch_param: Optional[nn.Parameter] = None) -> Tensor: + def forward_arch_param(self, x, arch_param: nn.Parameter) -> Tensor: """Forward with architecture parameters. Args: @@ -443,21 +458,17 @@ def forward_arch_param( f'Length of `edges` {len(self._candidates)} should be ' \ f'same as the length of inputs {len(x)}.' 
- if self._with_arch_param: - probs = self.compute_arch_probs(arch_param=arch_param) + probs = self.compute_arch_probs(arch_param=arch_param) - outputs = list() - for prob, module, input in zip(probs, self._candidates.values(), - x): - if prob > 0: - # prob may equal to 0 in gumbel softmax. - outputs.append(prob * module(input)) + outputs = list() + for prob, module, input in zip(probs, self._candidates.values(), x): + if prob > 0: + # prob may equal to 0 in gumbel softmax. + outputs.append(prob * module(input)) - return sum(outputs) - else: - return self.forward_all(x) + return sum(outputs) - def forward_all(self, x: Any) -> Tensor: + def forward_all(self, x): """Forward all choices. Args: @@ -500,16 +511,20 @@ def fix_chosen(self, chosen: List[str]) -> None: self.is_fixed = True @property - def choices(self) -> List[CHOSEN_TYPE]: + def choices(self) -> List[str]: """list: all choices. """ return list(self._candidates.keys()) - def dump_chosen(self): - """dump current choice.""" + def dump_chosen(self) -> DumpChosen: + chosen = self.export_chosen() + meta = dict(all_choices=self.choices) + return DumpChosen(chosen=chosen, meta=meta) + + def export_chosen(self) -> str: assert self.current_choice is not None return self.current_choice - def sample_choice(self, arch_param): + def sample_choice(self, arch_param: Tensor) -> List[str]: """sample choice based on `arch_param`.""" sort_idx = torch.argsort(-arch_param).cpu().numpy().tolist() choice_idx = sort_idx[:self.num_chosen] diff --git a/mmrazor/models/mutables/mutable_module/mutable_module.py b/mmrazor/models/mutables/mutable_module/mutable_module.py index 8840fd783..c71f1a969 100644 --- a/mmrazor/models/mutables/mutable_module/mutable_module.py +++ b/mmrazor/models/mutables/mutable_module/mutable_module.py @@ -2,20 +2,22 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional -from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE, BaseMutable +from ..base_mutable import BaseMutable -class MutableModule(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]): +class MutableModule(BaseMutable): """Base Class for mutables. Mutable means a searchable module widely used in Neural Architecture Search(NAS). It mainly consists of some optional operations, and achieving searchable function by handling choice with ``MUTATOR``. - All subclass should implement the following APIs: + All subclass should implement the following APIs and the other + abstract method in ``BaseMutable``: - ``forward()`` - - ``fix_chosen()`` + - ``forward_all()`` + - ``forward_fixed()`` - ``choices()`` Args: @@ -30,20 +32,48 @@ class MutableModule(BaseMutable): def __init__(self, module_kwargs: Optional[Dict[str, Dict]] = None, - **kwargs) -> None: - super().__init__(**kwargs) + alias: Optional[str] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(alias, init_cfg) self.module_kwargs = module_kwargs + self._current_choice = None + + @property + def current_choice(self): + """Current choice will affect :meth:`forward` and will be used in + :func:`mmrazor.core.subnet.utils.export_fix_subnet` or mutator. + """ + return self._current_choice + + @current_choice.setter + def current_choice(self, choice) -> None: + """Current choice setter will be executed in mutator.""" + self._current_choice = choice @property @abstractmethod - def choices(self) -> List[CHOICE_TYPE]: + def choices(self) -> List[str]: """list: all choices. 
All subclasses must implement this method.""" @abstractmethod def forward(self, x: Any) -> Any: """Forward computation.""" + @abstractmethod + def forward_fixed(self, x): + """Forward with the fixed mutable. + + All subclasses must implement this method. + """ + + @abstractmethod + def forward_all(self, x): + """Forward all choices. + + All subclasses must implement this method. + """ + @property def num_choices(self) -> int: """Number of choices.""" diff --git a/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py b/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py index f04c61eb5..434b05079 100644 --- a/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py +++ b/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py @@ -8,30 +8,20 @@ from torch import Tensor from mmrazor.registry import MODELS -from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE +from mmrazor.utils.typing import DumpChosen from .mutable_module import MutableModule -class OneShotMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]): +class OneShotMutableModule(MutableModule): """Base class for one shot mutable module. A base type of ``MUTABLES`` for single path supernet such as Single Path One Shot. - All subclass should implement the following APIs: + All subclass should implement the following APIs and the other + abstract method in ``MutableModule``: - ``sample_choice()`` - - ``forward_fixed()`` - - ``forward_all()`` - ``forward_choice()`` - Args: - module_kwargs (dict[str, dict], optional): Module initialization named - arguments. Defaults to None. - alias (str, optional): alias of the `MUTABLE`. - init_cfg (dict, optional): initialization configuration dict for - ``BaseModule``. OpenMMLab has implement 5 initializer including - `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, - and `Pretrained`. - Note: :meth:`forward_all` is called when calculating FLOPs. """ @@ -63,29 +53,15 @@ def forward(self, x: Any) -> Any: return self.forward_choice(x, choice=self.current_choice) @abstractmethod - def sample_choice(self) -> CHOICE_TYPE: + def sample_choice(self) -> str: """Sample random choice. Returns: - CHOICE_TYPE: the chosen key in ``MUTABLE``. - """ - - @abstractmethod - def forward_fixed(self, x: Any) -> Any: - """Forward with the fixed mutable. - - All subclasses must implement this method. - """ - - @abstractmethod - def forward_all(self, x: Any) -> Any: - """Forward all choices. - - All subclasses must implement this method. + str: the chosen key in ``MUTABLE``. """ @abstractmethod - def forward_choice(self, x: Any, choice: CHOICE_TYPE) -> Any: + def forward_choice(self, x, choice: str): """Forward with the unfixed mutable and current_choice is not None. All subclasses must implement this method. @@ -93,7 +69,7 @@ def forward_choice(self, x: Any, choice: CHOICE_TYPE) -> Any: @MODELS.register_module() -class OneShotMutableOP(OneShotMutableModule[str, str]): +class OneShotMutableOP(OneShotMutableModule): """A type of ``MUTABLES`` for single path supernet, such as Single Path One Shot. In single path supernet, each choice block only has one choice invoked at the same time. A path is obtained by sampling all the choice @@ -117,7 +93,6 @@ class OneShotMutableOP(OneShotMutableModule[str, str]): >>> candidates = nn.ModuleDict({ ... 'conv3x3': nn.Conv2d(32, 32, 3, 1, 1), ... 'conv5x5': nn.Conv2d(32, 32, 5, 1, 2), - ... 
'conv7x7': nn.Conv2d(32, 32, 7, 1, 3)}) >>> input = torch.randn(1, 32, 64, 64) >>> op = OneShotMutableOP(candidates) @@ -214,7 +189,7 @@ def forward_fixed(self, x: Any) -> Tensor: """ return self._candidates[self._chosen](x) - def forward_choice(self, x: Any, choice: str) -> Tensor: + def forward_choice(self, x, choice: str) -> Tensor: """Forward with the `unfixed` mutable and current choice is not None. Args: @@ -228,7 +203,7 @@ def forward_choice(self, x: Any, choice: str) -> Tensor: assert isinstance(choice, str) and choice in self.choices return self._candidates[choice](x) - def forward_all(self, x: Any) -> Tensor: + def forward_all(self, x) -> Tensor: """Forward all choices. Used to calculate FLOPs. Args: @@ -263,9 +238,13 @@ def fix_chosen(self, chosen: str) -> None: self._chosen = chosen self.is_fixed = True - def dump_chosen(self) -> str: - assert self.current_choice is not None + def dump_chosen(self) -> DumpChosen: + chosen = self.export_chosen() + meta = dict(all_choices=self.choices) + return DumpChosen(chosen=chosen, meta=meta) + def export_chosen(self) -> str: + assert self.current_choice is not None return self.current_choice def sample_choice(self) -> str: @@ -277,10 +256,6 @@ def choices(self) -> List[str]: """list: all choices. """ return list(self._candidates.keys()) - @property - def num_choices(self): - return len(self.choices) - @MODELS.register_module() class OneShotProbMutableOP(OneShotMutableOP): diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py index 49a0c870f..20055287d 100644 --- a/mmrazor/models/mutables/mutable_value/mutable_value.py +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -3,12 +3,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union from mmrazor.registry import MODELS +from mmrazor.utils.typing import DumpChosen from ..base_mutable import BaseMutable from ..derived_mutable import DerivedMethodMixin, DerivedMutable +Value = Union[int, float] + @MODELS.register_module() -class MutableValue(BaseMutable[Any, Dict], DerivedMethodMixin): +class MutableValue(BaseMutable, DerivedMethodMixin): """Base class for mutable value. A mutable value is actually a mutable that adds some functionality to a @@ -26,7 +29,7 @@ class MutableValue(BaseMutable[Any, Dict], DerivedMethodMixin): """ def __init__(self, - value_list: List[Any], + value_list: List[Value], default_value: Optional[Any] = None, alias: Optional[str] = None, init_cfg: Optional[Dict] = None) -> None: @@ -59,7 +62,7 @@ def choices(self) -> List[Any]: """List of choices.""" return self._value_list - def fix_chosen(self, chosen: Dict[str, Any]) -> None: + def fix_chosen(self, chosen: Value) -> None: """Fix mutable value with subnet config. Args: @@ -68,24 +71,23 @@ def fix_chosen(self, chosen: Dict[str, Any]) -> None: if self.is_fixed: raise RuntimeError('MutableValue can not be fixed twice') - all_choices = chosen['all_choices'] - current_choice = chosen['current_choice'] + assert chosen in self.choices - assert all_choices == self.choices, \ - f'Expect choices to be: {self.choices}, but got: {all_choices}' - assert current_choice in self.choices - - self.current_choice = current_choice + self.current_choice = chosen self.is_fixed = True - def dump_chosen(self) -> Dict[str, Any]: + def dump_chosen(self) -> DumpChosen: """Dump information of chosen. Returns: Dict[str, Any]: Dumped information. 
""" - return dict( - current_choice=self.current_choice, all_choices=self.choices) + chosen = self.export_chosen() + meta = dict(all_choices=self.choices) + return DumpChosen(chosen=chosen, meta=meta) + + def export_chosen(self): + return self.current_choice @property def num_choices(self) -> int: diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb b/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb index 1d7aad669..307ffc669 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -175,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -258,19 +258,19 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'net.conv0_(0, 8)_8': 8, 'net.conv1_(0, 16)_16': 16}\n" + "{0: 8, 1: 16}\n" ] } ], "source": [ - "print(mutator.choice_template)" + "print(mutator.current_choices)" ] }, { @@ -282,7 +282,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -296,14 +296,14 @@ " 3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n", " (mutable_attrs): ModuleDict(\n", " (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)\n", - " (out_channels): MutableChannelContainer(num_channels=8, activated_channels=6)\n", + " (out_channels): MutableChannelContainer(num_channels=8, activated_channels=4)\n", " )\n", " )\n", " (relu): ReLU()\n", " (conv1): DynamicConv2d(\n", " 8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n", " (mutable_attrs): ModuleDict(\n", - " (in_channels): MutableChannelContainer(num_channels=8, activated_channels=6)\n", + " (in_channels): MutableChannelContainer(num_channels=8, activated_channels=4)\n", " (out_channels): MutableChannelContainer(num_channels=16, activated_channels=8)\n", " )\n", " )\n", @@ -322,7 +322,7 @@ ], "source": [ "mutator.set_choices(\n", - " {'net.conv0_(0, 8)_8': 0.75, 'net.conv1_(0, 16)_16': 0.5}\n", + " {0: 4, 1: 8}\n", ")\n", "print(model)" ] @@ -337,7 +337,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.12 ('mmlab')", + "display_name": "Python 3.9.13 ('lab2max')", "language": "python", "name": "python3" }, @@ -351,12 +351,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.9.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "feec882ee78c63cb8d4b485f1b52bbb873bb9a7b094435863200c7afba202382" + "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875" } } }, diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index c04aa0204..7a19f1c72 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import copy -from typing import Dict, Generic, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union from mmengine import fileio -from torch.nn import Module +from torch.nn import Module, ModuleList from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin from mmrazor.models.mutables import (ChannelUnitType, MutableChannelUnit, @@ -13,6 +13,7 @@ from mmrazor.registry import MODELS from mmrazor.structures.graph import ModuleGraph from ..base_mutator import BaseMutator +from ..group_mixin import GroupMixin def is_dynamic_op_for_fx_tracer(module, name): @@ -20,7 +21,7 @@ def is_dynamic_op_for_fx_tracer(module, name): @MODELS.register_module() -class ChannelMutator(BaseMutator, Generic[ChannelUnitType]): +class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin): """ChannelMutator manages the pruning structure of a model. Args: @@ -48,6 +49,10 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType]): dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss')). + custom_groups (list[list[str]], optional): User-defined search groups. + All searchable modules that are not in ``custom_group`` will be + grouped separately. + init_cfg (dict, optional): initialization configuration dict for BaseModule. @@ -70,6 +75,7 @@ def __init__(self, parse_cfg: Dict = dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss')), + custom_groups: Optional[List[List[str]]] = None, init_cfg: Optional[Dict] = None) -> None: super().__init__(init_cfg) @@ -83,7 +89,7 @@ def __init__(self, # units self._name2unit: Dict[str, ChannelUnitType] = {} - self.units: List[ChannelUnitType] = [] + self.units: ModuleList[ChannelUnitType] = ModuleList() # unit config self.channel_unit_cfg = channel_unit_cfg @@ -91,6 +97,10 @@ def __init__(self, self._parse_channel_unit_cfg( channel_unit_cfg) + if custom_groups is None: + custom_groups = [] + self._custom_groups = custom_groups + def prepare_from_supernet(self, supernet: Module) -> None: """Prepare from a model for pruning. @@ -113,7 +123,11 @@ def prepare_from_supernet(self, supernet: Module) -> None: for unit in units: unit.prepare_for_pruning(supernet) self._name2unit[unit.name] = unit - self.units = units + self.units = ModuleList(units) + + self._search_groups = self.build_search_groups( + ModuleList(self.mutable_units), self.mutable_class_type, + self._custom_groups) # ~ @@ -193,23 +207,40 @@ def fix_channel_mutables(self): @property def current_choices(self) -> Dict: """Get current choices.""" - config = self.choice_template - for unit in self.mutable_units: - config[unit.name] = unit.current_choice - return config - - def set_choices(self, config: Dict[str, Union[int, float]]): - """Set choices.""" - for name, choice in config.items(): - unit = self._name2unit[name] - unit.current_choice = choice - - def sample_choices(self) -> Dict[str, Union[int, float]]: - """Sample choices(pruning structure).""" - template = self.choice_template - for key in template: - template[key] = self._name2unit[key].sample_choice() - return template + current_choices = dict() + for group_id, modules in self.search_groups.items(): + current_choices[group_id] = modules[0].current_choice + + return current_choices + + def sample_choices(self) -> Dict[int, Any]: + """Sampling by search groups. + + The sampling result of the first mutable of each group is the sampling + result of this group. + + Returns: + Dict[int, Any]: Random choices dict. 
+ """ + random_choices = dict() + for group_id, modules in self.search_groups.items(): + random_choices[group_id] = modules[0].sample_choice() + + return random_choices + + def set_choices(self, choices: Dict[int, Any]) -> None: + """Set mutables' current choice according to choices sample by + :func:`sample_choices`. + + Args: + choices (Dict[int, Any]): Choices dict. The key is group_id in + search groups, and the value is the sampling results + corresponding to this group. + """ + for group_id, modules in self.search_groups.items(): + choice = choices[group_id] + for module in modules: + module.current_choice = choice @property def choice_template(self) -> Dict: @@ -226,12 +257,24 @@ def choice_template(self) -> Dict: template[unit.name] = unit.current_choice return template - # implementation of abstract functions + @property + def search_groups(self) -> Dict[int, List]: + """Search group of the supernet. + + Note: + Search group is different from search space. The key of search + group is called ``group_id``, and the value is corresponding + searchable modules. The searchable modules will have the same + search space if they are in the same group. - def search_groups(self) -> Dict: - return self._name2unit + Returns: + dict: Search group. + """ + return self._search_groups + @property def mutable_class_type(self) -> Type[ChannelUnitType]: + """Mutable class type supported by this mutator.""" return self.unit_class # private methods diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py index a5350ab2b..1f8e3496b 100644 --- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -28,14 +28,16 @@ def __init__(self, def min_choices(self) -> Dict: """Return the minimal pruning subnet(structure).""" - template = self.choice_template - for key in template: - template[key] = self._name2unit[key].min_choice - return template + min_choices = dict() + for group_id, modules in self.search_groups.items(): + min_choices[group_id] = modules[0].min_choice + + return min_choices def max_choices(self) -> Dict: """Return the maximal pruning subnet(structure).""" - template = self.choice_template - for key in template: - template[key] = self._name2unit[key].max_choice - return template + max_choices = dict() + for group_id, modules in self.search_groups.items(): + max_choices[group_id] = modules[0].max_choice + + return max_choices diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py index 7c0d24fa6..9f5eb0075 100644 --- a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py @@ -28,10 +28,20 @@ def __init__(self, loss_calculator=dict(type='ImageClassifierPseudoLoss')), init_cfg: Optional[Dict] = None) -> None: - super().__init__(channel_unit_cfg, parse_cfg, init_cfg) + super().__init__(channel_unit_cfg, parse_cfg, None, init_cfg) self.subnets = self._prepare_subnets(self.units_cfg) + def set_choices(self, config: Dict[str, float]): # type: ignore[override] + """Set choices.""" + for name, choice in config.items(): + unit = self._name2unit[name] + unit.current_choice = choice + + def sample_choices(self): + """Sample choices(pruning structure).""" + raise RuntimeError + # private methods def _prepare_subnets(self, unit_cfg: Dict) -> 
List[Dict[str, int]]: diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py new file mode 100644 index 000000000..7e735b263 --- /dev/null +++ b/mmrazor/models/mutators/group_mixin.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from collections import Counter +from typing import Dict, List, Type + +from torch.nn import Module + +from ..mutables import BaseMutable + + +class GroupMixin(): + """A mixin for :class:`BaseMutator`, which can group mutables by + ``custom_group`` and ``alias`` (see more information in + :class:`BaseMutable`). Grouping by alias and module name are both + supported. + + Note: + Apart from user-defined search group, all other searchable + modules (mutable) will be grouped separately. + + The main difference between using alias and module name for + grouping is that the alias is One-to-Many while the module + name is One-to-One. + + When using both alias and module name in `custom_group`, the + priority of alias is higher than that of module name. + + If alias is set in `custom_group`, then its corresponding module + name should not be in the `custom_group`. + + Moreover, there should be no duplicate keys in the `custom_group`. + + Example: + >>> import torch + >>> from mmrazor.models import DiffModuleMutator + + >>> # Assume that a toy model consists of three mutables + >>> # whose name are op1,op2,op3. The corresponding + >>> # alias names of the three mutables are a1, a1, a2. + >>> model = ToyModel() + + >>> # Using alias for grouping + >>> mutator = DiffModuleMutator(custom_groups=[['a1'], ['a2']]) + >>> mutator.prepare_from_supernet(model) + >>> mutator.search_groups + {0: [op1, op2], 1: [op3]} + + >>> # Using module name for grouping + >>> mutator = DiffModuleMutator(custom_groups=[['op1', 'op2'], ['op3']]) + >>> mutator.prepare_from_supernet(model) + >>> mutator.search_groups + {0: [op1, op2], 1: [op3]} + + >>> # Using both alias and module name for grouping + >>> mutator = DiffModuleMutator(custom_groups=[['a2'], ['op2']]) + >>> mutator.prepare_from_supernet(model) + >>> # The last operation would be grouped + >>> mutator.search_groups + {0: [op3], 1: [op2], 2: [op1]} + + """ + + def _build_name_mutable_mapping( + self, supernet: Module, + support_mutables: Type) -> Dict[str, BaseMutable]: + """Mapping module name to mutable.""" + name2mutable: Dict[str, BaseMutable] = dict() + for name, module in supernet.named_modules(): + if isinstance(module, support_mutables): + name2mutable[name] = module + self._name2mutable = name2mutable + + return name2mutable + + def _build_alias_names_mapping( + self, supernet: Module, + support_mutables: Type) -> Dict[str, List[str]]: + """Mapping alias to module names.""" + alias2mutable_names: Dict[str, List[str]] = dict() + for name, module in supernet.named_modules(): + if isinstance(module, support_mutables): + + if module.alias is not None: + if module.alias not in alias2mutable_names: + alias2mutable_names[module.alias] = [name] + else: + alias2mutable_names[module.alias].append(name) + + return alias2mutable_names + + def build_search_groups(self, supernet: Module, support_mutables: Type, + custom_groups: List[List[str]]) -> Dict[int, List]: + """Build search group with ``custom_group`` and ``alias`` (see more + information in :class:`BaseMutable`). Grouping by alias and module name + are both supported. + + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. 
+ support_mutables (Type): Mutable type that can be grouped. + custom_groups (list, optional): User-defined search groups. + All searchable modules that are not in ``custom_group`` will be + grouped separately. + """ + name2mutable: Dict[str, + BaseMutable] = self._build_name_mutable_mapping( + supernet, support_mutables) + alias2mutable_names = self._build_alias_names_mapping( + supernet, support_mutables) + + # Check whether the custom group is valid + if len(custom_groups) > 0: + self._check_valid_groups(alias2mutable_names, name2mutable, + custom_groups) + + # Construct search_groups based on user-defined group + search_groups: Dict[int, List[BaseMutable]] = dict() + + current_group_nums = 0 + grouped_mutable_names: List[str] = list() + grouped_alias: List[str] = list() + for group in custom_groups: + group_mutables = list() + for item in group: + if item in alias2mutable_names: + # if the item is from alias name + mutable_names: List[str] = alias2mutable_names[item] + grouped_alias.append(item) + group_mutables.extend( + [name2mutable[n] for n in mutable_names]) + grouped_mutable_names.extend(mutable_names) + else: + # if the item is in name2mutable + group_mutables.append(name2mutable[item]) + grouped_mutable_names.append(item) + + search_groups[current_group_nums] = group_mutables + current_group_nums += 1 + + # Construct search_groups based on alias + for alias, mutable_names in alias2mutable_names.items(): + if alias not in grouped_alias: + # Check whether all current names are already grouped + flag_all_grouped = True + for mutable_name in mutable_names: + if mutable_name not in grouped_mutable_names: + flag_all_grouped = False + + # If not all mutables are already grouped + if not flag_all_grouped: + search_groups[current_group_nums] = [] + for mutable_name in mutable_names: + if mutable_name not in grouped_mutable_names: + search_groups[current_group_nums].append( + name2mutable[mutable_name]) + grouped_mutable_names.append(mutable_name) + current_group_nums += 1 + + # check whether all the mutable objects are in the search_groups + for name, module in supernet.named_modules(): + if isinstance(module, support_mutables): + if name in grouped_mutable_names: + continue + else: + search_groups[current_group_nums] = [module] + current_group_nums += 1 + + grouped_counter = Counter(grouped_mutable_names) + + # find duplicate keys + duplicate_keys = list() + for key, count in grouped_counter.items(): + if count > 1: + duplicate_keys.append(key) + + assert len(grouped_mutable_names) == len( + list(set(grouped_mutable_names))), \ + 'There are duplicate keys in grouped mutable names. ' \ + f'The duplicate keys are {duplicate_keys}. ' \ + 'Please check if there are duplicate keys in the `custom_group`.' + + return search_groups + + def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], + name2mutable: Dict[str, BaseMutable], + custom_group: List[List[str]]) -> None: + + aliases = [*alias2mutable_names.keys()] + module_names = [*name2mutable.keys()] + + # check if all keys are legal + expanded_custom_group: List[str] = [ + _ for group in custom_group for _ in group + ] + legal_keys: List[str] = [*aliases, *module_names] + + for key in expanded_custom_group: + if key not in legal_keys: + raise AssertionError( + f'The key: {key} in `custom_group` is not legal. ' + f'Legal keys are: {legal_keys}. ' + 'Make sure that the keys are either alias or mutable name') + + # when the mutable has alias attribute, the corresponding module + # name should not be used in `custom_group`. 
+        used_aliases = list()
+        for group in custom_groups:
+            for key in group:
+                if key in aliases:
+                    used_aliases.append(key)
+
+        for alias_key in used_aliases:
+            mutable_names: List = alias2mutable_names[alias_key]
+            # check whether any module name is also in `custom_groups`
+            for mutable_name in mutable_names:
+                if mutable_name in expanded_custom_groups:
+                    raise AssertionError(
+                        f'When a mutable has the alias {alias_key}, '
+                        f'the corresponding module name {mutable_name} should '
+                        f'not be used in `custom_groups` {custom_groups}.')
diff --git a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
index ac6358049..1f639ed28 100644
--- a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
+++ b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
@@ -25,9 +25,9 @@ class DiffModuleMutator(ModuleMutator):
     """
 
     def __init__(self,
-                 custom_group: Optional[List[List[str]]] = None,
+                 custom_groups: Optional[List[List[str]]] = None,
                  init_cfg: Optional[Dict] = None) -> None:
-        super().__init__(custom_group=custom_group, init_cfg=init_cfg)
+        super().__init__(custom_groups=custom_groups, init_cfg=init_cfg)
 
     def build_arch_param(self, num_choices) -> nn.Parameter:
         """Build learnable architecture parameters."""
diff --git a/mmrazor/models/mutators/module_mutator/module_mutator.py b/mmrazor/models/mutators/module_mutator/module_mutator.py
index dc045932b..f30e933e0 100644
--- a/mmrazor/models/mutators/module_mutator/module_mutator.py
+++ b/mmrazor/models/mutators/module_mutator/module_mutator.py
@@ -1,14 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import abstractmethod
-from collections import Counter
 from typing import Dict, List, Optional, Type
 
 from torch.nn import Module
 
 from ..base_mutator import MUTABLE_TYPE, BaseMutator
+from ..group_mixin import GroupMixin
 
 
-class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
+class ModuleMutator(BaseMutator[MUTABLE_TYPE], GroupMixin):
     """The base class for mutable based mutator.
 
     All subclass should implement the following APIS:
@@ -16,19 +16,19 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
     - ``mutable_class_type``
 
     Args:
-        custom_group (list[list[str]], optional): User-defined search groups.
-            All searchable modules that are not in ``custom_group`` will be
-            grouped separately.
+        custom_groups (list[list[str]], optional): User-defined search groups.
+            All searchable modules that are not in ``custom_groups`` will be
+            grouped separately.
     """
 
     def __init__(self,
-                 custom_group: Optional[List[List[str]]] = None,
+                 custom_groups: Optional[List[List[str]]] = None,
                  init_cfg: Optional[Dict] = None) -> None:
         super().__init__(init_cfg)
 
-        if custom_group is None:
-            custom_group = []
-        self._custom_group = custom_group
+        if custom_groups is None:
+            custom_groups = []
+        self._custom_groups = custom_groups
         self._search_groups: Optional[Dict[int, List[MUTABLE_TYPE]]] = None
 
     # TODO
@@ -52,7 +52,9 @@ def prepare_from_supernet(self, supernet: Module) -> None:
             supernet (:obj:`torch.nn.Module`): The supernet to be searched
             in your algorithm.
""" - self._build_search_groups(supernet) + self._search_groups = self.build_search_groups(supernet, + self.mutable_class_type, + self._custom_groups) @property def name2mutable(self) -> Dict[str, MUTABLE_TYPE]: @@ -90,196 +92,3 @@ def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]: raise RuntimeError( 'Call `prepare_from_supernet` before access search group!') return self._search_groups - - def _build_name_mutable_mapping( - self, supernet: Module) -> Dict[str, MUTABLE_TYPE]: - """Mapping module name to mutable.""" - name2mutable: Dict[str, MUTABLE_TYPE] = dict() - for name, module in supernet.named_modules(): - if isinstance(module, self.mutable_class_type): - name2mutable[name] = module - self._name2mutable = name2mutable - - return name2mutable - - def _build_alias_names_mapping(self, - supernet: Module) -> Dict[str, List[str]]: - """Mapping alias to module names.""" - alias2mutable_names: Dict[str, List[str]] = dict() - for name, module in supernet.named_modules(): - if isinstance(module, self.mutable_class_type): - if module.alias is not None: - if module.alias not in alias2mutable_names: - alias2mutable_names[module.alias] = [name] - else: - alias2mutable_names[module.alias].append(name) - - return alias2mutable_names - - def _build_search_groups(self, supernet: Module) -> None: - """Build search group with ``custom_group`` and ``alias``(see more - information in :class:`BaseMutable`). Grouping by alias and module name - are both supported. - - Note: - Apart from user-defined search group, all other searchable - modules(mutable) will be grouped separately. - - The main difference between using alias and module name for - grouping is that the alias is One-to-Many while the module - name is One-to-One. - - When using both alias and module name in `custom_group`, the - priority of alias is higher than that of module name. - - If alias is set in `custom_group`, then its corresponding module - name should not be in the `custom_group`. - - Moreover, there should be no duplicate keys in the `custom_group`. - - Example: - >>> import torch - >>> from mmrazor.models.mutables.diff_mutable import DiffMutableOP - - >>> # Assume that a toy model consists of three mutables - >>> # whose name are op1,op2,op3. The corresponding - >>> # alias names of the three mutables are a1, a1, a2. - >>> model = ToyModel() - - >>> # Using alias for grouping - >>> mutator = DiffMutableOP(custom_group=[['a1'], ['a2']]) - >>> mutator.prepare_from_supernet(model) - >>> mutator.search_groups - {0: [op1, op2], 1: [op3]} - - >>> # Using module name for grouping - >>> mutator = DiffMutableOP(custom_group=[['op1', 'op2'], ['op3']]) - >>> mutator.prepare_from_supernet(model) - >>> mutator.search_groups - {0: [op1, op2], 1: [op3]} - - >>> # Using both alias and module name for grouping - >>> mutator = DiffMutableOP(custom_group=[['a2'], ['op2']]) - >>> mutator.prepare_from_supernet(model) - >>> # The last operation would be grouped - >>> mutator.search_groups - {0: [op3], 1: [op2], 2: [op1]} - - - Args: - supernet (:obj:`torch.nn.Module`): The supernet to be searched - in your algorithm. 
- """ - name2mutable = self._build_name_mutable_mapping(supernet) - alias2mutable_names = self._build_alias_names_mapping(supernet) - - # Check whether the custom group is valid - if len(self._custom_group) > 0: - self._check_valid_groups(alias2mutable_names, name2mutable, - self._custom_group) - - # Construct search_groups based on user-defined group - search_groups: Dict[int, List[MUTABLE_TYPE]] = dict() - - current_group_nums = 0 - grouped_mutable_names: List[str] = list() - grouped_alias: List[str] = list() - for group in self._custom_group: - group_mutables = list() - for item in group: - if item in alias2mutable_names: - # if the item is from alias name - mutable_names: List[str] = alias2mutable_names[item] - grouped_alias.append(item) - group_mutables.extend( - [name2mutable[n] for n in mutable_names]) - grouped_mutable_names.extend(mutable_names) - else: - # if the item is in name2mutable - group_mutables.append(name2mutable[item]) - grouped_mutable_names.append(item) - - search_groups[current_group_nums] = group_mutables - current_group_nums += 1 - - # Construct search_groups based on alias - for alias, mutable_names in alias2mutable_names.items(): - if alias not in grouped_alias: - # Check whether all current names are already grouped - flag_all_grouped = True - for mutable_name in mutable_names: - if mutable_name not in grouped_mutable_names: - flag_all_grouped = False - - # If not all mutables are already grouped - if not flag_all_grouped: - search_groups[current_group_nums] = [] - for mutable_name in mutable_names: - if mutable_name not in grouped_mutable_names: - search_groups[current_group_nums].append( - name2mutable[mutable_name]) - grouped_mutable_names.append(mutable_name) - current_group_nums += 1 - - # check whether all the mutable objects are in the search_groups - for name, module in supernet.named_modules(): - if isinstance(module, self.mutable_class_type): - if name in grouped_mutable_names: - continue - else: - search_groups[current_group_nums] = [module] - current_group_nums += 1 - - grouped_counter = Counter(grouped_mutable_names) - - # find duplicate keys - duplicate_keys = list() - for key, count in grouped_counter.items(): - if count > 1: - duplicate_keys.append(key) - - assert len(grouped_mutable_names) == len( - list(set(grouped_mutable_names))), \ - 'There are duplicate keys in grouped mutable names. ' \ - f'The duplicate keys are {duplicate_keys}. ' \ - 'Please check if there are duplicate keys in the `custom_group`.' - - self._search_groups = search_groups - - def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], - name2mutable: Dict[str, MUTABLE_TYPE], - custom_group: List[List[str]]) -> None: - - aliases = [*alias2mutable_names.keys()] - module_names = [*name2mutable.keys()] - - # check if all keys are legal - expanded_custom_group: List[str] = [ - _ for group in custom_group for _ in group - ] - legal_keys: List[str] = [*aliases, *module_names] - - for key in expanded_custom_group: - if key not in legal_keys: - raise AssertionError( - f'The key: {key} in `custom_group` is not legal. ' - f'Legal keys are: {legal_keys}. ' - 'Make sure that the keys are either alias or mutable name') - - # when the mutable has alias attribute, the corresponding module - # name should not be used in `custom_group`. 
-        used_aliases = list()
-        for group in custom_group:
-            for key in group:
-                if key in aliases:
-                    used_aliases.append(key)
-
-        for alias_key in used_aliases:
-            mutable_names: List = alias2mutable_names[alias_key]
-            # check whether module name is in custom group
-            for mutable_name in mutable_names:
-                if mutable_name in expanded_custom_group:
-                    raise AssertionError(
-                        f'When a mutable is set alias attribute :{alias_key},'
-                        f'the corresponding module name {mutable_name} should '
-                        f'not be used in `custom_group` {custom_group}.')
diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py
index 4eb515371..625e65025 100644
--- a/mmrazor/structures/subnet/fix_subnet.py
+++ b/mmrazor/structures/subnet/fix_subnet.py
@@ -6,6 +6,7 @@
 from torch import nn
 
 from mmrazor.utils import FixMutable, ValidFixMutable
+from mmrazor.utils.typing import DumpChosen
 
 
 def _dynamic_to_static(model: nn.Module) -> None:
@@ -56,6 +57,7 @@ def load_fix_subnet(model: nn.Module,
                 assert alias in fix_mutable, \
                     f'The alias {alias} is not in fix_modules, ' \
                     'please check your `fix_mutable`.'
+                # {chosen=xx, meta=xx}
                 chosen = fix_mutable.get(alias, None)
             else:
                 mutable_name = name.lstrip(prefix)
@@ -64,8 +66,12 @@
                     raise RuntimeError(
                         f'The module name {mutable_name} is not in '
                         'fix_mutable, please check your `fix_mutable`.')
+                # {chosen=xx, meta=xx}
                 chosen = fix_mutable.get(mutable_name, None)
-            module.fix_chosen(chosen)
+
+            if not isinstance(chosen, DumpChosen):
+                chosen = DumpChosen(**chosen)
+            module.fix_chosen(chosen.chosen)
 
     # convert dynamic op to static op
     _dynamic_to_static(model)
diff --git a/mmrazor/utils/typing.py b/mmrazor/utils/typing.py
index 1166d580f..0d1126f2a 100644
--- a/mmrazor/utils/typing.py
+++ b/mmrazor/utils/typing.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union
 
 FixMutable = Dict[str, Any]
 ValidFixMutable = Union[str, Path, FixMutable]
@@ -23,3 +23,11 @@
 
 SupportRandomSubnet = Union[SingleMutatorRandomSubnet,
                             MultiMutatorsRandomSubnet]
+
+Chosen = Union[str, float, List[str]]
+ChosenMeta = Optional[Dict[str, Any]]
+
+
+class DumpChosen(NamedTuple):
+    chosen: Chosen
+    meta: ChosenMeta = None
diff --git a/tests/data/models.py b/tests/data/models.py
index 60c8a7058..867adc0c9 100644
--- a/tests/data/models.py
+++ b/tests/data/models.py
@@ -513,12 +513,26 @@ def _expand_mask():
     def dump_chosen(self):
         return super().dump_chosen()
 
+    def export_chosen(self):
+        return super().export_chosen()
+
     def fix_chosen(self, chosen):
         return super().fix_chosen(chosen)
 
     def num_choices(self) -> int:
         return super().num_choices
 
+    @property
+    def current_choice(self):
+        return super().current_choice
+
+    @current_choice.setter
+    def current_choice(self, choice):
+        # ``super().current_choice = choice`` would bypass the property
+        # descriptor, so call the parent class's setter explicitly.
+        super(type(self), type(self)).current_choice.fset(self, choice)
+
+
 class DynamicLinearModel(nn.Module):
     """
diff --git a/tests/data/test_models/test_subnet/mockmodel_subnet.yaml b/tests/data/test_models/test_subnet/mockmodel_subnet.yaml
index 36e3a9ce0..8d92a99b5 100644
--- a/tests/data/test_models/test_subnet/mockmodel_subnet.yaml
+++ b/tests/data/test_models/test_subnet/mockmodel_subnet.yaml
@@ -1,2 +1,4 @@
-mutable1: conv1
-mutable2: conv2
+mutable1:
+  chosen: conv1
+mutable2:
+  chosen: conv2
diff --git a/tests/data/test_registry/registry_subnet_config.py b/tests/data/test_registry/registry_subnet_config.py
index 28ba3a0ea..539a1cdb1 100644
--- a/tests/data/test_registry/registry_subnet_config.py
+++ b/tests/data/test_registry/registry_subnet_config.py
@@ -7,7 +7,7 @@
     type='MockAlgorithm',
     architecture=supernet,
     _fix_subnet_ = {
-        'architecture.mutable1': 'conv1',
-        'architecture.mutable2': 'conv2',
+        'architecture.mutable1': {'chosen': 'conv1'},
+        'architecture.mutable2': {'chosen': 'conv2'},
     }
 )
diff --git a/tests/test_models/test_algorithms/test_darts.py b/tests/test_models/test_algorithms/test_darts.py
index 7d33fa047..52f5d10e6 100644
--- a/tests/test_models/test_algorithms/test_darts.py
+++ b/tests/test_models/test_algorithms/test_darts.py
@@ -104,7 +104,11 @@ def test_init(self) -> None:
         self.assertIsInstance(algo.mutator, DiffModuleMutator)
 
         # initiate darts when `fix_subnet` is not None
-        fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
+        fix_subnet = {
+            'normal': {
+                'chosen': ['torch_conv2d_3x3', 'torch_conv2d_7x7']
+            }
+        }
         algo = Darts(model, mutator, fix_subnet=fix_subnet)
         self.assertEqual(algo.architecture.mutable.num_choices, 2)
 
@@ -124,7 +128,11 @@ def test_forward_loss(self) -> None:
         self.assertIsInstance(loss, dict)
 
         # subnet
-        fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
+        fix_subnet = {
+            'normal': {
+                'chosen': ['torch_conv2d_3x3', 'torch_conv2d_7x7']
+            }
+        }
         algo = Darts(model, fix_subnet=fix_subnet)
         loss = algo(inputs, mode='loss')
         self.assertIsInstance(loss, dict)
diff --git a/tests/test_models/test_algorithms/test_dsnas.py b/tests/test_models/test_algorithms/test_dsnas.py
index 9f6dfc902..2b5bbfa49 100644
--- a/tests/test_models/test_algorithms/test_dsnas.py
+++ b/tests/test_models/test_algorithms/test_dsnas.py
@@ -14,8 +14,8 @@
 from torch import Tensor
 from torch.optim import SGD
 
-from mmrazor.models import DiffModuleMutator, Dsnas, OneHotMutableOP
-from mmrazor.models.algorithms.nas.dsnas import DsnasDDP
+from mmrazor.models import DSNAS, DiffModuleMutator, OneHotMutableOP
+from mmrazor.models.algorithms.nas.dsnas import DSNASDDP
 from mmrazor.registry import MODELS
 
 MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
@@ -81,29 +81,29 @@ def test_init(self) -> None:
         # initiate dsnas when `norm_training` is True.
         model = ToyDiffModule()
         mutator = DiffModuleMutator()
-        algo = Dsnas(architecture=model, mutator=mutator, norm_training=True)
+        algo = DSNAS(architecture=model, mutator=mutator, norm_training=True)
         algo.eval()
         self.assertTrue(model.bn.training)
 
         # initiate Dsnas with built mutator
         model = ToyDiffModule()
         mutator = DiffModuleMutator()
-        algo = Dsnas(model, mutator)
+        algo = DSNAS(model, mutator)
         self.assertIs(algo.mutator, mutator)
 
         # initiate Dsnas with unbuilt mutator
         mutator = dict(type='DiffModuleMutator')
-        algo = Dsnas(model, mutator)
+        algo = DSNAS(model, mutator)
         self.assertIsInstance(algo.mutator, DiffModuleMutator)
 
         # initiate Dsnas when `fix_subnet` is not None
-        fix_subnet = {'mutable': 'torch_conv2d_5x5'}
-        algo = Dsnas(model, mutator, fix_subnet=fix_subnet)
+        fix_subnet = {'mutable': {'chosen': 'torch_conv2d_5x5'}}
+        algo = DSNAS(model, mutator, fix_subnet=fix_subnet)
         self.assertEqual(algo.architecture.mutable.num_choices, 1)
 
         # initiate Dsnas with error type `mutator`
         with self.assertRaisesRegex(TypeError, 'mutator should be'):
-            Dsnas(model, model)
+            DSNAS(model, model)
 
     def test_forward_loss(self) -> None:
         inputs = torch.randn(1, 3, 8, 8)
@@ -112,13 +112,13 @@ def test_forward_loss(self) -> None:
         # supernet
         mutator = DiffModuleMutator()
         mutator.prepare_from_supernet(model)
-        algo = Dsnas(model, mutator)
+        algo = DSNAS(model, mutator)
         loss = algo(inputs, mode='loss')
         self.assertIsInstance(loss, dict)
 
         # subnet
-        fix_subnet = {'mutable': 'torch_conv2d_5x5'}
-        algo = Dsnas(model, fix_subnet=fix_subnet)
+        fix_subnet = {'mutable': {'chosen': 'torch_conv2d_5x5'}}
+        algo = DSNAS(model, fix_subnet=fix_subnet)
         loss = algo(inputs, mode='loss')
         self.assertIsInstance(loss, dict)
 
@@ -135,7 +135,7 @@ def test_search_subnet(self) -> None:
         mutator = DiffModuleMutator()
         mutator.prepare_from_supernet(model)
 
-        algo = Dsnas(model, mutator)
+        algo = DSNAS(model, mutator)
         subnet = algo.search_subnet()
         self.assertIsInstance(subnet, dict)
 
@@ -146,14 +146,14 @@ def test_dsnas_train_step(self, mock_get_info) -> None:
         mutator.prepare_from_supernet(model)
         mock_get_info.return_value = 2
 
-        algo = Dsnas(model, mutator)
+        algo = DSNAS(model, mutator)
         data = self._prepare_fake_data()
         optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG)
         loss = algo.train_step(data, optim_wrapper)
 
         self.assertTrue(isinstance(loss['loss'], Tensor))
 
-        algo = Dsnas(model, mutator)
+        algo = DSNAS(model, mutator)
         optim_wrapper_dict = OptimWrapperDict(
             architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
             mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
@@ -173,16 +173,16 @@ def setUpClass(cls) -> None:
         backend = 'nccl' if torch.cuda.is_available() else 'gloo'
         dist.init_process_group(backend, rank=0, world_size=1)
 
-    def prepare_model(self, device_ids=None) -> Dsnas:
+    def prepare_model(self, device_ids=None) -> DSNAS:
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
         model = ToyDiffModule()
         mutator = DiffModuleMutator()
         mutator.prepare_from_supernet(model)
 
-        algo = Dsnas(model, mutator).to(self.device)
+        algo = DSNAS(model, mutator).to(self.device)
 
-        return DsnasDDP(
+        return DSNASDDP(
             module=algo,
find_unused_parameters=True, device_ids=device_ids) @classmethod @@ -193,7 +193,7 @@ def tearDownClass(cls) -> None: not torch.cuda.is_available(), reason='cuda device is not avaliable') def test_init(self) -> None: ddp_model = self.prepare_model() - self.assertIsInstance(ddp_model, DsnasDDP) + self.assertIsInstance(ddp_model, DSNASDDP) @patch('mmengine.logging.message_hub.MessageHub.get_info') def test_dsnasddp_train_step(self, mock_get_info) -> None: diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py index 519407772..3a00e93b9 100644 --- a/tests/test_models/test_algorithms/test_prune_algorithm.py +++ b/tests/test_models/test_algorithms/test_prune_algorithm.py @@ -114,7 +114,8 @@ def test_iterative_prune_int(self): model = MODELS.build(MODEL_CFG) mutator = MODELS.build(MUTATOR_CONFIG_FLOAT) mutator.prepare_from_supernet(model) - prune_target = mutator.sample_choices() + mutator.set_choices(mutator.sample_choices()) + prune_target = mutator.choice_template epoch = 10 epoch_step = 2 @@ -135,9 +136,11 @@ def test_iterative_prune_int(self): data['inputs'], data['data_samples'], mode='loss') current_choices = algorithm.mutator.current_choices + group_prune_target = algorithm.group_target_pruning_ratio( + prune_target, mutator.search_groups) for key in current_choices: self.assertAlmostEqual( - current_choices[key], prune_target[key], delta=0.1) + current_choices[key], group_prune_target[key], delta=0.1) def test_load_pretrained(self): epoch_step = 2 @@ -158,7 +161,7 @@ def test_load_pretrained(self): algorithm = ItePruneAlgorithm( model_cfg, mutator_cfg=MUTATOR_CONFIG_NUM, - target_pruning_ratio={}, + target_pruning_ratio=None, step_epoch=epoch_step, prune_times=times, ).to(DEVICE) @@ -167,3 +170,43 @@ def test_load_pretrained(self): # delete checkpoint os.remove(checkpoint_path) + + def test_group_target_ratio(self): + + model = MODELS.build(MODEL_CFG) + mutator = MODELS.build(MUTATOR_CONFIG_FLOAT) + mutator.prepare_from_supernet(model) + mutator.set_choices(mutator.sample_choices()) + prune_target = mutator.choice_template + + custom_groups = [[ + 'backbone.layer1.0.conv1_(0, 64)_64', + 'backbone.layer1.1.conv1_(0, 64)_64' + ]] + mutator_cfg = copy.deepcopy(MUTATOR_CONFIG_FLOAT) + mutator_cfg['custom_groups'] = custom_groups + + epoch_step = 2 + times = 3 + + prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1 + prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.1 + + _ = ItePruneAlgorithm( + MODEL_CFG, + target_pruning_ratio=prune_target, + mutator_cfg=mutator_cfg, + step_epoch=epoch_step, + prune_times=times).to(DEVICE) + + prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1 + prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.2 + + with self.assertRaises(ValueError): + + _ = ItePruneAlgorithm( + MODEL_CFG, + target_pruning_ratio=prune_target, + mutator_cfg=mutator_cfg, + step_epoch=epoch_step, + prune_times=times).to(DEVICE) diff --git a/tests/test_models/test_algorithms/test_spos.py b/tests/test_models/test_algorithms/test_spos.py index 3392f469c..f73521111 100644 --- a/tests/test_models/test_algorithms/test_spos.py +++ b/tests/test_models/test_algorithms/test_spos.py @@ -56,7 +56,7 @@ def test_init(self): self.assertIsInstance(alg.mutator, OneShotModuleMutator) # initiate spos when `fix_subnet` is not None. 
- fix_subnet = {'mutable': 'conv1'} + fix_subnet = {'mutable': {'chosen': 'conv1'}} alg = SPOS(model, mutator, fix_subnet=fix_subnet) self.assertEqual(alg.architecture.mutable.num_choices, 1) @@ -75,7 +75,7 @@ def test_forward_loss(self): self.assertIsInstance(loss, dict) # subnet - fix_subnet = {'mutable': 'conv1'} + fix_subnet = {'mutable': {'chosen': 'conv1'}} alg = SPOS(model, fix_subnet=fix_subnet) loss = alg(inputs, mode='loss') self.assertIsInstance(loss, dict) diff --git a/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py b/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py index acaa84b2b..ba3f5955d 100644 --- a/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py +++ b/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py @@ -55,7 +55,7 @@ def setUp(self) -> None: self.mutator_cfg = dict( type='DiffModuleMutator', - custom_group=None, + custom_groups=None, ) def test_darts_backbone(self): @@ -81,7 +81,7 @@ def test_darts_backbone_with_auxliary(self): custom_group = self.generate_key(model) assert model is not None - self.mutable_cfg.update(custom_group=custom_group) + self.mutable_cfg.update(custom_groups=custom_group) mutator = MODELS.build(self.mutator_cfg) assert mutator is not None mutator.prepare_from_supernet(model) diff --git a/tests/test_models/test_architectures/test_dynamic_op/utils.py b/tests/test_models/test_architectures/test_dynamic_op/utils.py index ceb2a5d4f..e448f300e 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/utils.py +++ b/tests/test_models/test_architectures/test_dynamic_op/utils.py @@ -2,6 +2,7 @@ from typing import Dict, Optional from mmrazor.models.architectures.dynamic_ops import DynamicMixin +from mmrazor.utils.typing import DumpChosen def fix_dynamic_op(op: DynamicMixin, @@ -13,4 +14,7 @@ def fix_dynamic_op(op: DynamicMixin, else: chosen = mutable.dump_chosen() - mutable.fix_chosen(chosen) + if not isinstance(chosen, DumpChosen): + chosen = DumpChosen(**chosen) + + mutable.fix_chosen(chosen.chosen) diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py index 3e87b0654..0b5f55e88 100644 --- a/tests/test_models/test_mutables/test_derived_mutable.py +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -24,9 +24,9 @@ def test_is_fixed(self) -> None: with pytest.raises(RuntimeError): derived_mutable.is_fixed = True - mc.fix_chosen(mc.dump_chosen()) + mc.fix_chosen(mc.dump_chosen().chosen) assert not derived_mutable.is_fixed - mv.fix_chosen(mv.dump_chosen()) + mv.fix_chosen(mv.dump_chosen().chosen) assert derived_mutable.is_fixed def test_fix_dump_chosen(self) -> None: @@ -34,13 +34,13 @@ def test_fix_dump_chosen(self) -> None: mv.current_choice = 3 derived_mutable = mv * 2 - assert derived_mutable.dump_chosen() == 6 + assert derived_mutable.dump_chosen().chosen == 6 mv.current_choice = 4 - assert derived_mutable.dump_chosen() == 8 + assert derived_mutable.dump_chosen().chosen == 8 # nothing will happen - derived_mutable.fix_chosen(derived_mutable.dump_chosen()) + derived_mutable.fix_chosen(derived_mutable.dump_chosen().chosen) def test_derived_same_mutable(self) -> None: mc = SquentialMutableChannel(num_channels=3) diff --git a/tests/test_models/test_mutables/test_diffop.py b/tests/test_models/test_mutables/test_diffop.py index 702adf8e2..eab9fff2b 100644 --- a/tests/test_models/test_mutables/test_diffop.py +++ b/tests/test_models/test_mutables/test_diffop.py @@ -44,9 
+44,6 @@ def test_forward_arch_param(self): output = op.forward_arch_param(input, arch_param=arch_param) assert output is not None - output = op.forward_arch_param(input, arch_param=None) - assert output is not None - # test when some element of arch_param is 0 arch_param = nn.Parameter(torch.ones(op.num_choices)) output = op.forward_arch_param(input, arch_param=arch_param) diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py index c93a43842..79c552250 100644 --- a/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import unittest -import pytest import torch from mmrazor.models.mutables import (SimpleMutableChannel, @@ -31,5 +30,5 @@ def test_SimpleMutableChannel(self): channel.current_choice = torch.tensor([1, 0, 0, 0]).bool() self.assertEqual(channel.activated_channels, 1) channel.fix_chosen() - with pytest.raises(NotImplementedError): - channel.dump_chosen() + # with pytest.raises(NotImplementedError): + # channel.dump_chosen() diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py index d7d05b1d5..b33dfcc98 100644 --- a/tests/test_models/test_mutables/test_mutable_value.py +++ b/tests/test_models/test_mutables/test_mutable_value.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy from unittest import TestCase import pytest @@ -42,22 +41,13 @@ def test_init_one_shot_mutable_value(self) -> None: def test_fix_chosen(self) -> None: mv = MutableValue([2, 3, 4]) chosen = mv.dump_chosen() - assert chosen == { - 'current_choice': mv.current_choice, - 'all_choices': mv.choices - } + assert chosen.chosen == mv.current_choice + assert chosen.meta['all_choices'] == mv.choices - chosen['current_choice'] = 5 with pytest.raises(AssertionError): - mv.fix_chosen(chosen) - - chosen_copied = copy.deepcopy(chosen) - chosen_copied['all_choices'] = [1, 2, 3] - with pytest.raises(AssertionError): - mv.fix_chosen(chosen_copied) + mv.fix_chosen(5) - chosen['current_choice'] = 3 - mv.fix_chosen(chosen) + mv.fix_chosen(3) assert mv.current_choice == 3 with pytest.raises(RuntimeError): diff --git a/tests/test_models/test_mutables/test_onehotop.py b/tests/test_models/test_mutables/test_onehotop.py index 4ace5870d..a3b86d745 100644 --- a/tests/test_models/test_mutables/test_onehotop.py +++ b/tests/test_models/test_mutables/test_onehotop.py @@ -44,9 +44,6 @@ def test_forward_arch_param(self): output = op.forward_arch_param(input, arch_param=arch_param) assert output is not None - output = op.forward_arch_param(input, arch_param=None) - assert output is not None - # test when some element of arch_param is 0 arch_param = nn.Parameter(torch.ones(op.num_choices)) output = op.forward_arch_param(input, arch_param=arch_param) diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py index 96908d807..3d6ed7773 100644 --- a/tests/test_models/test_mutators/test_channel_mutator.py +++ b/tests/test_models/test_mutators/test_channel_mutator.py @@ -134,3 +134,34 @@ def test_models_with_predefined_dynamic_op(self): parse_cfg={'type': 'Predefined'}) mutator.prepare_from_supernet(model) self._test_a_mutator(mutator, model) + + def test_custom_group(self): + ARCHITECTURE_CFG = dict( + 
type='mmcls.ImageClassifier', + backbone=dict(type='mmcls.MobileNetV2', widen_factor=1.5), + neck=dict(type='mmcls.GlobalAveragePooling'), + head=dict( + type='mmcls.LinearClsHead', + num_classes=1000, + in_channels=1920, + loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5))) + model = MODELS.build(ARCHITECTURE_CFG) + + # generate config + model1 = copy.deepcopy(model) + mutator1 = ChannelMutator() + mutator1.prepare_from_supernet(model1) + + self.assertEqual(len(mutator1.search_groups), 25) + + custom_groups = [[ + 'backbone.layer2.1.conv.0.conv_(0, 240)_240', + 'backbone.layer3.0.conv.0.conv_(0, 240)_240' + ]] + + model2 = copy.deepcopy(model) + mutator2 = ChannelMutator(custom_groups=custom_groups) + mutator2.prepare_from_supernet(model2) + + self.assertEqual(len(mutator2.search_groups), 24) diff --git a/tests/test_models/test_mutators/test_diff_mutator.py b/tests/test_models/test_mutators/test_diff_mutator.py index 5e230a202..663637fc9 100644 --- a/tests/test_models/test_mutators/test_diff_mutator.py +++ b/tests/test_models/test_mutators/test_diff_mutator.py @@ -98,7 +98,8 @@ def setUp(self): module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) self.MUTATOR_CFG = dict( - type='DiffModuleMutator', custom_group=[['op1'], ['op2'], ['op3']]) + type='DiffModuleMutator', + custom_groups=[['op1'], ['op2'], ['op3']]) def test_diff_mutator_diffop_layer(self) -> None: model = SearchableLayer(self.MUTABLE_CFG) @@ -111,7 +112,7 @@ def test_diff_mutator_diffop_model(self) -> None: model = SearchableModel(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [ + mutator_cfg['custom_groups'] = [ ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], @@ -128,7 +129,7 @@ def test_diff_mutator_diffop_model_error(self) -> None: model = SearchableModel(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [ + mutator_cfg['custom_groups'] = [ ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], ['slayer1.op3', 'slayer2.op3', 'slayer3.op3_error_key'], @@ -142,7 +143,7 @@ def test_diff_mutator_diffop_alias(self) -> None: model = SearchableModelAlias(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [['op1'], ['op2'], ['op3']] + mutator_cfg['custom_groups'] = [['op1'], ['op2'], ['op3']] mutator: DiffModuleMutator = MODELS.build(mutator_cfg) mutator.prepare_from_supernet(model) @@ -157,11 +158,11 @@ def test_diff_mutator_alias_module_name(self) -> None: model = SearchableModelAlias(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [['op1'], - [ - 'slayer1.op2', 'slayer2.op2', - 'slayer3.op2' - ], ['slayer1.op3', 'slayer2.op3']] + mutator_cfg['custom_groups'] = [['op1'], + [ + 'slayer1.op2', 'slayer2.op2', + 'slayer3.op2' + ], ['slayer1.op3', 'slayer2.op3']] mutator: DiffModuleMutator = MODELS.build(mutator_cfg) mutator.prepare_from_supernet(model) @@ -175,7 +176,7 @@ def test_diff_mutator_duplicate_keys(self) -> None: model = SearchableModel(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [ + mutator_cfg['custom_groups'] = [ ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], ['slayer1.op3', 'slayer2.op3', 'slayer2.op3'], @@ -189,7 +190,7 @@ def test_diff_mutator_duplicate_key_alias(self) -> None: model = 
SearchableModelAlias(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [ + mutator_cfg['custom_groups'] = [ ['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], @@ -203,7 +204,7 @@ def test_diff_mutator_illegal_key(self) -> None: model = SearchableModel(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [ + mutator_cfg['custom_groups'] = [ ['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], @@ -217,7 +218,7 @@ def test_sample_and_set_choices(self): model = SearchableModel(self.MUTABLE_CFG) mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_group'] = [ + mutator_cfg['custom_groups'] = [ ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], diff --git a/tests/test_models/test_subnet/test_fix_subnet.py b/tests/test_models/test_subnet/test_fix_subnet.py index 010372212..0137a8274 100644 --- a/tests/test_models/test_subnet/test_fix_subnet.py +++ b/tests/test_models/test_subnet/test_fix_subnet.py @@ -57,8 +57,12 @@ def test_load_fix_subnet(self): # fix subnet is dict fix_subnet = { - 'mutable1': 'conv1', - 'mutable2': 'conv2', + 'mutable1': { + 'chosen': 'conv1' + }, + 'mutable2': { + 'chosen': 'conv2' + }, } model = MockModel() @@ -80,8 +84,12 @@ def test_load_fix_subnet(self): def test_export_fix_subnet(self): # get FixSubnet fix_subnet = { - 'mutable1': 'conv1', - 'mutable2': 'conv2', + 'mutable1': { + 'chosen': 'conv1' + }, + 'mutable2': { + 'chosen': 'conv2' + }, } model = MockModel() @@ -95,6 +103,14 @@ def test_export_fix_subnet(self): model.mutable2.current_choice = 'conv2' exported_fix_subnet = export_fix_subnet(model) + mutable1_dump_chosen = exported_fix_subnet['mutable1'] + mutable2_dump_chosen = exported_fix_subnet['mutable2'] + + mutable1_chosen_dict = dict(chosen=mutable1_dump_chosen.chosen) + mutable2_chosen_dict = dict(chosen=mutable2_dump_chosen.chosen) + + exported_fix_subnet['mutable1'] = mutable1_chosen_dict + exported_fix_subnet['mutable2'] = mutable2_chosen_dict self.assertDictEqual(fix_subnet, exported_fix_subnet) def test_export_fix_subnet_with_derived_mutable(self) -> None: @@ -102,7 +118,10 @@ def test_export_fix_subnet_with_derived_mutable(self) -> None: fix_subnet = export_fix_subnet(model) self.assertDictEqual( fix_subnet, {'source_mutable': model.source_mutable.dump_chosen()}) - fix_subnet['source_mutable']['current_choice'] = 4 + + fix_subnet['source_mutable'] = dict( + fix_subnet['source_mutable']._asdict()) + fix_subnet['source_mutable']['chosen'] = 4 load_fix_subnet(model, fix_subnet) assert model.source_mutable.current_choice == 4 assert model.derived_mutable.current_choice == 8 @@ -114,7 +133,10 @@ def test_export_fix_subnet_with_derived_mutable(self) -> None: 'source_mutable': model.source_mutable.dump_chosen(), 'derived_mutable': model.derived_mutable.dump_chosen() }) - fix_subnet['source_mutable']['current_choice'] = 2 + + fix_subnet['source_mutable'] = dict( + fix_subnet['source_mutable']._asdict()) + fix_subnet['source_mutable']['chosen'] = 2 load_fix_subnet(model, fix_subnet) assert model.source_mutable.current_choice == 2 assert model.derived_mutable.current_choice == 4
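
A minimal usage sketch of the grouping semantics that `GroupMixin.build_search_groups` implements, assuming a toy supernet whose mutables `op1` and `op2` share the alias `a1` while `op3` has the alias `a2` (the `ToyModel` below is illustrative, not part of this patch):

from mmrazor.models import DiffModuleMutator

model = ToyModel()  # hypothetical supernet: op1, op2 (alias 'a1'), op3 (alias 'a2')

# Grouping by alias: an alias is One-to-Many, so 'a1' pulls op1 and op2
# into a single search group.
mutator = DiffModuleMutator(custom_groups=[['a1'], ['a2']])
mutator.prepare_from_supernet(model)
# mutator.search_groups -> {0: [op1, op2], 1: [op3]}

# Grouping by module name: One-to-One, so every name is listed explicitly.
mutator = DiffModuleMutator(custom_groups=[['op1', 'op2'], ['op3']])
mutator.prepare_from_supernet(model)
# mutator.search_groups -> {0: [op1, op2], 1: [op3]}

# Mixing alias and module name: op1 is not covered by any group, so it is
# grouped separately after the user-defined groups.
mutator = DiffModuleMutator(custom_groups=[['a2'], ['op2']])
mutator.prepare_from_supernet(model)
# mutator.search_groups -> {0: [op3], 1: [op2], 2: [op1]}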
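
The subnet-format change that runs through the YAML files and tests above boils down to this: each entry of a fixed subnet is now a mapping with a `chosen` key (and an optional `meta` key), and `load_fix_subnet` normalizes plain dicts into the `DumpChosen` named tuple before fixing each mutable. A condensed sketch of that normalization, with the module lookup omitted:

from mmrazor.utils.typing import DumpChosen

# New-style fixed subnet: values are mappings, not bare choice strings.
fix_mutable = {
    'mutable1': {'chosen': 'conv1'},
    'mutable2': {'chosen': 'conv2', 'meta': None},
}

for name, chosen in fix_mutable.items():
    if not isinstance(chosen, DumpChosen):
        chosen = DumpChosen(**chosen)  # tolerates plain dicts loaded from YAML
    # The real code then calls module.fix_chosen(chosen.chosen) on the
    # mutable that matches ``name``.
    print(name, chosen.chosen, chosen.meta)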
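
On the export side, `dump_chosen` now returns a `DumpChosen` named tuple instead of a raw value or dict, and `fix_chosen` takes the bare chosen value. The `MutableValue` tests above reduce to roughly the following:

import pytest

from mmrazor.models.mutables import MutableValue

mv = MutableValue([2, 3, 4])
dumped = mv.dump_chosen()
assert dumped.chosen == mv.current_choice
assert dumped.meta['all_choices'] == [2, 3, 4]

# fix_chosen validates the bare value rather than a dumped record.
with pytest.raises(AssertionError):
    mv.fix_chosen(5)

mv.fix_chosen(3)
assert mv.current_choice == 3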
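
For channel mutators, `custom_groups` merges several prunable units into one search group, and `ItePruneAlgorithm` then requires every unit in a group to share the same target pruning ratio. A condensed restatement of what `test_group_target_ratio` exercises, reusing the test fixtures `MODEL_CFG`, `MUTATOR_CONFIG_FLOAT` and the `prune_target` template, all assumed to be in scope; the import path is an assumption as well:

import copy

from mmrazor.models.algorithms import ItePruneAlgorithm

custom_groups = [[
    'backbone.layer1.0.conv1_(0, 64)_64',
    'backbone.layer1.1.conv1_(0, 64)_64',
]]
mutator_cfg = copy.deepcopy(MUTATOR_CONFIG_FLOAT)
mutator_cfg['custom_groups'] = custom_groups

# Equal ratios for the two grouped units are accepted ...
prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1
prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.1
ItePruneAlgorithm(
    MODEL_CFG,
    target_pruning_ratio=prune_target,
    mutator_cfg=mutator_cfg,
    step_epoch=2,
    prune_times=3)

# ... while conflicting ratios inside one group raise ValueError.
prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.2  # mismatch -> ValueError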