diff --git a/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py b/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py new file mode 100644 index 000000000..f73c8b90e --- /dev/null +++ b/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py @@ -0,0 +1,28 @@ +norm_cfg = dict(type='BN', eps=0.01) + +_STAGE_MUTABLE = dict( + type='mmrazor.OneHotMutableOP', + fix_threshold=0.3, + candidates=dict( + shuffle_3x3=dict( + type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg), + shuffle_5x5=dict( + type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg), + shuffle_7x7=dict( + type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg), + shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg))) + +arch_setting = [ + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: channel, num_blocks, mutable_cfg. + [64, 4, _STAGE_MUTABLE], + [160, 4, _STAGE_MUTABLE], + [320, 8, _STAGE_MUTABLE], + [640, 4, _STAGE_MUTABLE] +] + +nas_backbone = dict( + type='mmrazor.SearchableShuffleNetV2', + widen_factor=1.0, + arch_setting=arch_setting, + norm_cfg=norm_cfg) diff --git a/configs/_base_/settings/imagenet_bs1024_dsnas.py b/configs/_base_/settings/imagenet_bs1024_dsnas.py new file mode 100644 index 000000000..bf266c51c --- /dev/null +++ b/configs/_base_/settings/imagenet_bs1024_dsnas.py @@ -0,0 +1,102 @@ +# dataset settings +dataset_type = 'mmcls.ImageNet' +data_preprocessor = dict( + type='mmcls.ClsDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type='mmcls.LoadImageFromFile'), + dict(type='mmcls.RandomResizedCrop', scale=224), + dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'), + dict(type='mmcls.PackClsInputs'), +] + +test_pipeline = [ + dict(type='mmcls.LoadImageFromFile'), + dict(type='mmcls.ResizeEdge', scale=256, edge='short'), + dict(type='mmcls.CenterCrop', crop_size=224), + dict(type='mmcls.PackClsInputs'), +] + +train_dataloader = dict( + batch_size=128, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=128, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator + +# optimizer +paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0) + +optim_wrapper = dict( + constructor='mmrazor.SeparateOptimWrapperConstructor', + architecture=dict( + optimizer=dict( + type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5), + paramwise_cfg=paramwise_cfg), + mutator=dict( + optimizer=dict( + type='mmcls.Adam', lr=0.001, weight_decay=0.0, betas=(0.5, + 0.999)))) + +search_epochs = 85 +# leanring policy +param_scheduler = dict( + architecture=[ + dict( + type='mmcls.LinearLR', + end=5, + start_factor=0.2, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='mmcls.CosineAnnealingLR', + T_max=240, + begin=5, + end=search_epochs, + by_epoch=True, + 
convert_to_iter_based=True), + dict( + type='mmcls.CosineAnnealingLR', + T_max=160, + begin=search_epochs, + end=240, + eta_min=0.0, + by_epoch=True, + convert_to_iter_based=True) + ], + mutator=[]) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=240) +val_cfg = dict() +test_cfg = dict() diff --git a/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml new file mode 100644 index 000000000..d2fa294d3 --- /dev/null +++ b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml @@ -0,0 +1,20 @@ +backbone.layers.0.0: shuffle_3x3 +backbone.layers.0.1: shuffle_3x3 +backbone.layers.0.2: shuffle_xception +backbone.layers.0.3: shuffle_3x3 +backbone.layers.1.0: shuffle_xception +backbone.layers.1.1: shuffle_7x7 +backbone.layers.1.2: shuffle_3x3 +backbone.layers.1.3: shuffle_3x3 +backbone.layers.2.0: shuffle_xception +backbone.layers.2.1: shuffle_xception +backbone.layers.2.2: shuffle_7x7 +backbone.layers.2.3: shuffle_xception +backbone.layers.2.4: shuffle_xception +backbone.layers.2.5: shuffle_xception +backbone.layers.2.6: shuffle_7x7 +backbone.layers.2.7: shuffle_3x3 +backbone.layers.3.0: shuffle_3x3 +backbone.layers.3.1: shuffle_xception +backbone.layers.3.2: shuffle_xception +backbone.layers.3.3: shuffle_3x3 diff --git a/configs/nas/mmcls/dsnas/README.md b/configs/nas/mmcls/dsnas/README.md new file mode 100644 index 000000000..6a085eb78 --- /dev/null +++ b/configs/nas/mmcls/dsnas/README.md @@ -0,0 +1,43 @@ +# DSNAS + +> [DSNAS: Direct Neural Architecture Search without Parameter Retraining](https://arxiv.org/abs/2002.09128.pdf) + + + +## Abstract + +Most existing NAS methods require two-stage parameter optimization. +However, performance of the same architecture in the two stages correlates poorly. +Based on this observation, DSNAS proposes a task-specific end-to-end differentiable NAS framework that simultaneously optimizes architecture and parameters with a low-biased Monte Carlo estimate. Child networks derived from DSNAS can be deployed directly without parameter retraining. + +![pipeline](/docs/en/imgs/model_zoo/dsnas/pipeline.jpg) + +## Results and models + +### Supernet + +| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks | +| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: | +| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched | + +**Note**: + +1. There **might be(not all the case)** some small differences in our experiment in order to be consistent with other repos in OpenMMLab. For example, + normalize images in data preprocessing; resize by cv2 rather than PIL in training; dropout is not used in network. **Please refer to corresponding config for details.** +2. We convert the official searched checkpoint DSNASsearch240.pth into mmrazor-style and evaluate with pytorch1.8_cuda11.0, Top-1 is 74.1 and Top-5 is 91.51. +3. The implementation of ShuffleNetV2 in official DSNAS is different from OpenMMLab's and we follow the structure design in OpenMMLab. Note that with the + origin ShuffleNetV2 design in official DSNAS, the Top-1 is 73.92 and Top-5 is 91.59. +4. The finetune stage in our implementation refers to the 'search-from-search' stage mentioned in official DSNAS. +5. 
We obtain params and FLOPs using `mmrazor.ResourceEstimator`, which may be different from the origin repo. + +## Citation + +```latex +@inproceedings{hu2020dsnas, + title={Dsnas: Direct neural architecture search without parameter retraining}, + author={Hu, Shoukang and Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Shi, Jianping and Liu, Xunying and Lin, Dahua}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={12084--12092}, + year={2020} +} +``` diff --git a/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py new file mode 100644 index 000000000..ca30a5946 --- /dev/null +++ b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py @@ -0,0 +1,29 @@ +_base_ = ['./dsnas_supernet_8xb128_in1k.py'] + +# NOTE: Replace this with the mutable_cfg searched by yourself. +fix_subnet = { + 'backbone.layers.0.0': 'shuffle_3x3', + 'backbone.layers.0.1': 'shuffle_7x7', + 'backbone.layers.0.2': 'shuffle_3x3', + 'backbone.layers.0.3': 'shuffle_5x5', + 'backbone.layers.1.0': 'shuffle_3x3', + 'backbone.layers.1.1': 'shuffle_3x3', + 'backbone.layers.1.2': 'shuffle_3x3', + 'backbone.layers.1.3': 'shuffle_7x7', + 'backbone.layers.2.0': 'shuffle_xception', + 'backbone.layers.2.1': 'shuffle_3x3', + 'backbone.layers.2.2': 'shuffle_3x3', + 'backbone.layers.2.3': 'shuffle_5x5', + 'backbone.layers.2.4': 'shuffle_3x3', + 'backbone.layers.2.5': 'shuffle_5x5', + 'backbone.layers.2.6': 'shuffle_7x7', + 'backbone.layers.2.7': 'shuffle_7x7', + 'backbone.layers.3.0': 'shuffle_xception', + 'backbone.layers.3.1': 'shuffle_3x3', + 'backbone.layers.3.2': 'shuffle_7x7', + 'backbone.layers.3.3': 'shuffle_3x3', +} + +model = dict(fix_subnet=fix_subnet) + +find_unused_parameters = False diff --git a/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py new file mode 100644 index 000000000..ea821da40 --- /dev/null +++ b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py @@ -0,0 +1,36 @@ +_base_ = [ + 'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py', + 'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py', + 'mmcls::_base_/default_runtime.py', +] + +# model +model = dict( + type='mmrazor.Dsnas', + architecture=dict( + type='ImageClassifier', + data_preprocessor=_base_.data_preprocessor, + backbone=_base_.nas_backbone, + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1024, + loss=dict( + type='LabelSmoothLoss', + num_classes=1000, + label_smooth_val=0.1, + mode='original', + loss_weight=1.0), + topk=(1, 5))), + mutator=dict(type='mmrazor.DiffModuleMutator'), + pretrain_epochs=15, + finetune_epochs=_base_.search_epochs, +) + +model_wrapper_cfg = dict( + type='mmrazor.DsnasDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +randomness = dict(seed=48, diff_rank_seed=True) diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index fd221b577..f2df86a83 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -10,6 +10,6 @@ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', - 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop', - 'EstimateResourcesHook' + 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook', + 'SelfDistillValLoop' ] diff --git a/mmrazor/models/algorithms/__init__.py 
b/mmrazor/models/algorithms/__init__.py index bbc0e5755..2d96f3a96 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -3,12 +3,13 @@ from .distill import (DAFLDataFreeDistillation, DataFreeDistillation, FpnTeacherDistill, OverhaulFeatureDistillation, SelfDistill, SingleTeacherDistill) -from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP +from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP from .pruning import SlimmableNetwork, SlimmableNetworkDDP __all__ = [ 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation', - 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation' + 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas', + 'DsnasDDP' ] diff --git a/mmrazor/models/algorithms/nas/__init__.py b/mmrazor/models/algorithms/nas/__init__.py index b18fd339d..17eab7e86 100644 --- a/mmrazor/models/algorithms/nas/__init__.py +++ b/mmrazor/models/algorithms/nas/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .autoslim import AutoSlim, AutoSlimDDP from .darts import Darts, DartsDDP +from .dsnas import Dsnas, DsnasDDP from .spos import SPOS -__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP'] +__all__ = [ + 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP' +] diff --git a/mmrazor/models/algorithms/nas/dsnas.py b/mmrazor/models/algorithms/nas/dsnas.py new file mode 100644 index 000000000..62c2c7f04 --- /dev/null +++ b/mmrazor/models/algorithms/nas/dsnas.py @@ -0,0 +1,347 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from mmengine.dist import get_dist_info +from mmengine.logging import MessageHub +from mmengine.model import BaseModel, MMDistributedDataParallel +from mmengine.optim import OptimWrapper, OptimWrapperDict +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.mutables.base_mutable import BaseMutable +from mmrazor.models.mutators import DiffModuleMutator +from mmrazor.models.utils import add_prefix +from mmrazor.registry import MODEL_WRAPPERS, MODELS, TASK_UTILS +from mmrazor.structures import load_fix_subnet +from mmrazor.utils import FixMutable +from ..base import BaseAlgorithm + + +@MODELS.register_module() +class Dsnas(BaseAlgorithm): + """Implementation of `DSNAS `_ + + Args: + architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel` + or built model. Corresponding to supernet in NAS algorithm. + mutator (dict|:obj:`DiffModuleMutator`): The config of + :class:`DiffModuleMutator` or built mutator. + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. + pretrain_epochs (int): Num of epochs for supernet pretraining. + finetune_epochs (int): Num of epochs for subnet finetuning. + flops_constraints (float): Flops constraints for judging whether to + backward flops loss or not. Default to 300.0(M). + estimator_cfg (Dict[str, Any]): Used for building a resource estimator. + Default to None. + norm_training (bool): Whether to set norm layers to training mode, + namely, not freeze running stats (mean and var). Note: Effect on + Batch Norm and its variants only. Defaults to False. 
+ data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. Defaults to None. + init_cfg (dict): Init config for ``BaseModule``. + + Note: + Dsnas doesn't require retraining. It has 3 stages in searching: + 1. `cur_epoch` < `pretrain_epochs` refers to supernet pretraining. + 2. `pretrain_epochs` <= `cur_epoch` < `finetune_epochs` refers to + normal supernet training while mutator is updated. + 3. `cur_epoch` >= `finetune_epochs` refers to subnet finetuning. + """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator: Optional[Union[DiffModuleMutator, Dict]] = None, + fix_subnet: Optional[FixMutable] = None, + pretrain_epochs: int = 0, + finetune_epochs: int = 80, + flops_constraints: float = 300.0, + estimator_cfg: Dict[str, Any] = None, + norm_training: bool = False, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None, + **kwargs): + super().__init__(architecture, data_preprocessor, **kwargs) + + if estimator_cfg is None: + estimator_cfg = dict(type='mmrazor.ResourceEstimator') + self.estimator = TASK_UTILS.build(estimator_cfg) + if fix_subnet: + # Avoid circular import + from mmrazor.structures import load_fix_subnet + + # According to fix_subnet, delete the unchosen part of supernet + load_fix_subnet(self.architecture, fix_subnet) + self.is_supernet = False + else: + assert mutator is not None, \ + 'mutator cannot be None when fix_subnet is None.' + if isinstance(mutator, DiffModuleMutator): + self.mutator = mutator + elif isinstance(mutator, dict): + self.mutator = MODELS.build(mutator) + else: + raise TypeError('mutator should be a `dict` or ' + f'`DiffModuleMutator` instance, but got ' + f'{type(mutator)}') + + self.mutable_module_resources = self._get_module_resources() + # Mutator is an essential component of the NAS algorithm. It + # provides some APIs commonly used by NAS. + # Before using it, you must do some preparations according to + # the supernet. + self.mutator.prepare_from_supernet(self.architecture) + self.is_supernet = True + self.search_space_name_list = list( + self.mutator.name2mutable.keys()) + + self.norm_training = norm_training + self.pretrain_epochs = pretrain_epochs + self.finetune_epochs = finetune_epochs + if pretrain_epochs >= finetune_epochs: + raise ValueError(f'Pretrain stage (optional) must be done before ' + f'finetuning stage. 
Got `{pretrain_epochs}` >= ' + f'`{finetune_epochs}`.') + + self.flops_loss_coef = 1e-2 + self.flops_constraints = flops_constraints + _, self.world_size = get_dist_info() + + def search_subnet(self): + """Search subnet by mutator.""" + + # Avoid circular import + from mmrazor.structures import export_fix_subnet + + subnet = self.mutator.sample_choices() + self.mutator.set_choices(subnet) + return export_fix_subnet(self) + + def fix_subnet(self): + """Fix subnet when finetuning.""" + subnet = self.mutator.sample_choices() + self.mutator.set_choices(subnet) + for module in self.architecture.modules(): + if isinstance(module, BaseMutable): + if not module.is_fixed: + module.fix_chosen(module.current_choice) + self.is_supernet = False + + def train(self, mode=True): + """Convert the model into eval mode while keep normalization layer + unfreezed.""" + + super().train(mode) + if self.norm_training and not mode: + for module in self.architecture.modules(): + if isinstance(module, _BatchNorm): + module.training = True + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """The iteration step during training. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. + """ + if isinstance(optim_wrapper, OptimWrapperDict): + log_vars = dict() + self.message_hub = MessageHub.get_current_instance() + cur_epoch = self.message_hub.get_info('epoch') + need_update_mutator = self.need_update_mutator(cur_epoch) + + # TODO process the input + if cur_epoch == self.finetune_epochs and self.is_supernet: + # synchronize arch params to start the finetune stage. + for k, v in self.mutator.arch_params.items(): + dist.broadcast(v, src=0) + self.fix_subnet() + + # 1. update architecture + with optim_wrapper['architecture'].optim_context(self): + pseudo_data = self.data_preprocessor(data, True) + supernet_batch_inputs = pseudo_data['inputs'] + supernet_data_samples = pseudo_data['data_samples'] + supernet_loss = self( + supernet_batch_inputs, supernet_data_samples, mode='loss') + + supernet_losses, supernet_log_vars = self.parse_losses( + supernet_loss) + optim_wrapper['architecture'].backward( + supernet_losses, retain_graph=need_update_mutator) + optim_wrapper['architecture'].step() + optim_wrapper['architecture'].zero_grad() + log_vars.update(add_prefix(supernet_log_vars, 'supernet')) + + # 2. update mutator + if need_update_mutator: + with optim_wrapper['mutator'].optim_context(self): + mutator_loss = self.compute_mutator_loss() + mutator_losses, mutator_log_vars = \ + self.parse_losses(mutator_loss) + optim_wrapper['mutator'].update_params(mutator_losses) + log_vars.update(add_prefix(mutator_log_vars, 'mutator')) + # handle the grad of arch params & weights + self.handle_grads() + + else: + # Enable automatic mixed precision training context. + with optim_wrapper.optim_context(self): + pseudo_data = self.data_preprocessor(data, True) + batch_inputs = pseudo_data['inputs'] + data_samples = pseudo_data['data_samples'] + losses = self(batch_inputs, data_samples, mode='loss') + parsed_losses, log_vars = self.parse_losses(losses) + optim_wrapper.update_params(parsed_losses) + + return log_vars + + def _get_module_resources(self): + """Get resources of spec modules.""" + + spec_modules = [] + for name, module in self.architecture.named_modules(): + if isinstance(module, BaseMutable): + for choice in module.choices: + spec_modules.append(name + '._candidates.' 
+ choice) + + mutable_module_resources = self.estimator.estimate_separation_modules( + self.architecture, dict(spec_modules=spec_modules)) + + return mutable_module_resources + + def need_update_mutator(self, cur_epoch: int) -> bool: + """Whether to update mutator.""" + if cur_epoch >= self.pretrain_epochs and \ + cur_epoch < self.finetune_epochs: + return True + return False + + def compute_mutator_loss(self) -> Dict[str, torch.Tensor]: + """Compute mutator loss. + + In this method, arch_loss & flops_loss[optional] are computed + by traversing arch_weights & probs in search groups. + + Returns: + Dict: Loss of the mutator. + """ + arch_loss = 0.0 + flops_loss = 0.0 + for name, module in self.architecture.named_modules(): + if isinstance(module, BaseMutable): + k = str(self.search_space_name_list.index(name)) + probs = F.softmax(self.mutator.arch_params[k], -1) + arch_loss += torch.log( + (module.arch_weights * probs).sum(-1)).sum() + + # get the index of op with max arch weights. + index = (module.arch_weights == 1).nonzero().item() + _module_key = name + '._candidates.' + module.choices[index] + flops_loss += probs[index] * \ + self.mutable_module_resources[_module_key]['flops'] + + mutator_loss = dict(arch_loss=arch_loss / self.world_size) + + copied_model = copy.deepcopy(self) + fix_mutable = copied_model.search_subnet() + load_fix_subnet(copied_model, fix_mutable) + + subnet_flops = self.estimator.estimate(copied_model)['flops'] + if subnet_flops >= self.flops_constraints: + mutator_loss['flops_loss'] = \ + (flops_loss * self.flops_loss_coef) / self.world_size + + return mutator_loss + + def handle_grads(self): + """Handle grads of arch params & arch weights.""" + for name, module in self.architecture.named_modules(): + if isinstance(module, BaseMutable): + k = str(self.search_space_name_list.index(name)) + self.mutator.arch_params[k].grad.data.mul_( + module.arch_weights.grad.data.sum()) + module.arch_weights.grad.zero_() + + +@MODEL_WRAPPERS.register_module() +class DsnasDDP(MMDistributedDataParallel): + + def __init__(self, + *, + device_ids: Optional[Union[List, int, torch.device]] = None, + **kwargs) -> None: + if device_ids is None: + if os.environ.get('LOCAL_RANK') is not None: + device_ids = [int(os.environ['LOCAL_RANK'])] + super().__init__(device_ids=device_ids, **kwargs) + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating are also defined in + this method, such as GAN. + """ + if isinstance(optim_wrapper, OptimWrapperDict): + log_vars = dict() + self.message_hub = MessageHub.get_current_instance() + cur_epoch = self.message_hub.get_info('epoch') + need_update_mutator = self.module.need_update_mutator(cur_epoch) + + # TODO process the input + if cur_epoch == self.module.finetune_epochs and \ + self.module.is_supernet: + # synchronize arch params to start the finetune stage. + for k, v in self.module.mutator.arch_params.items(): + dist.broadcast(v, src=0) + self.module.fix_subnet() + + # 1. 
update architecture + with optim_wrapper['architecture'].optim_context(self): + pseudo_data = self.module.data_preprocessor(data, True) + supernet_batch_inputs = pseudo_data['inputs'] + supernet_data_samples = pseudo_data['data_samples'] + supernet_loss = self( + supernet_batch_inputs, supernet_data_samples, mode='loss') + + supernet_losses, supernet_log_vars = self.module.parse_losses( + supernet_loss) + optim_wrapper['architecture'].backward( + supernet_losses, retain_graph=need_update_mutator) + optim_wrapper['architecture'].step() + optim_wrapper['architecture'].zero_grad() + log_vars.update(add_prefix(supernet_log_vars, 'supernet')) + + # 2. update mutator + if need_update_mutator: + with optim_wrapper['mutator'].optim_context(self): + mutator_loss = self.module.compute_mutator_loss() + mutator_losses, mutator_log_vars = \ + self.module.parse_losses(mutator_loss) + optim_wrapper['mutator'].update_params(mutator_losses) + log_vars.update(add_prefix(mutator_log_vars, 'mutator')) + # handle the grad of arch params & weights + self.module.handle_grads() + + else: + # Enable automatic mixed precision training context. + with optim_wrapper.optim_context(self): + pseudo_data = self.module.data_preprocessor(data, True) + batch_inputs = pseudo_data['inputs'] + data_samples = pseudo_data['data_samples'] + losses = self(batch_inputs, data_samples, mode='loss') + parsed_losses, log_vars = self.module.parse_losses(losses) + optim_wrapper.update_params(parsed_losses) + + return log_vars diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 917364607..074eda445 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -3,12 +3,13 @@ from .mutable_channel import (MutableChannel, OneShotMutableChannel, SlimmableMutableChannel) from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, - OneShotMutableModule, OneShotMutableOP) + OneHotMutableOP, OneShotMutableModule, + OneShotMutableOP) from .mutable_value import MutableValue, OneShotMutableValue __all__ = [ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable', - 'MutableValue', 'OneShotMutableValue' + 'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP' ] diff --git a/mmrazor/models/mutables/mutable_module/__init__.py b/mmrazor/models/mutables/mutable_module/__init__.py index d1904e8c8..bcf10c3a8 100644 --- a/mmrazor/models/mutables/mutable_module/__init__.py +++ b/mmrazor/models/mutables/mutable_module/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .diff_mutable_module import (DiffChoiceRoute, DiffMutableModule, - DiffMutableOP) + DiffMutableOP, OneHotMutableOP) from .mutable_module import MutableModule from .one_shot_mutable_module import OneShotMutableModule, OneShotMutableOP __all__ = [ 'DiffMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', - 'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule' + 'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule', + 'OneHotMutableOP' ] diff --git a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py index 9269d4a2a..379c59235 100644 --- a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py +++ b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py @@ -37,8 +37,9 @@ def __init__(self, **kwargs) -> None: def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None) -> Any: - """Calls either :func:`forward_fixed` or :func:`forward_choice` - depending on whether :func:`is_fixed` is ``True``. + """Calls either :func:`forward_fixed` or :func:`forward_arch_param` + depending on whether :func:`is_fixed` is ``True`` and whether + :func:`arch_param` is None. To reduce the coupling between `Mutable` and `Mutator`, the `arch_param` is generated by the `Mutator` and is passed to the @@ -52,6 +53,9 @@ def forward(self, x (Any): input data for forward computation. arch_param (nn.Parameter, optional): the architecture parameters for ``DiffMutableModule``. + + Returns: + Any: the result of forward """ if self.is_fixed: return self.forward_fixed(x) @@ -97,6 +101,10 @@ class DiffMutableOP(DiffMutableModule[str, str]): Args: candidates (dict[str, dict]): the configs for the candidate operations. + fix_threshold (float): The threshold that determines whether to fix + the choice of current module as the op with the maximum `probs`. + It happens when the maximum prob is `fix_threshold` or more higher + then all the other probs. Default to 1.0. module_kwargs (dict[str, dict], optional): Module initialization named arguments. Defaults to None. alias (str, optional): alias of the `MUTABLE`. @@ -109,6 +117,7 @@ class DiffMutableOP(DiffMutableModule[str, str]): def __init__( self, candidates: Dict[str, Dict], + fix_threshold: float = 1.0, module_kwargs: Optional[Dict[str, Dict]] = None, alias: Optional[str] = None, init_cfg: Optional[Dict] = None, @@ -120,6 +129,10 @@ def __init__( f'but got: {len(candidates)}' self._is_fixed = False + if fix_threshold < 0 or fix_threshold > 1.0: + raise ValueError( + f'The fix_threshold should be in [0, 1]. Got {fix_threshold}.') + self.fix_threshold = fix_threshold self._candidates = self._build_ops(candidates, self.module_kwargs) @staticmethod @@ -242,6 +255,94 @@ def choices(self) -> List[str]: return list(self._candidates.keys()) +@MODELS.register_module() +class OneHotMutableOP(DiffMutableOP): + """A type of ``MUTABLES`` for one-hot sample based architecture search, + such as DSNAS. Search the best module by learnable parameters `arch_param`. + + Args: + candidates (dict[str, dict]): the configs for the candidate + operations. + module_kwargs (dict[str, dict], optional): Module initialization named + arguments. Defaults to None. + alias (str, optional): alias of the `MUTABLE`. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. 
+ """ + + def sample_weights(self, + arch_param: nn.Parameter, + probs: torch.Tensor, + random_sample: bool = False) -> Tensor: + """Use one-hot distributions to sample the arch weights based on the + arch params. + + Args: + arch_param (nn.Parameter): architecture parameters for + `DiffMutableModule`. + probs (Tensor): the probs of choice. + random_sample (bool): Whether to random sample arch weights or not + Defaults to False. + + Returns: + Tensor: Sampled one-hot arch weights. + """ + import torch.distributions as D + if random_sample: + uni = torch.ones_like(arch_param) + m = D.one_hot_categorical.OneHotCategorical(uni) + else: + m = D.one_hot_categorical.OneHotCategorical(probs=probs) + return m.sample() + + def forward_arch_param(self, + x: Any, + arch_param: Optional[nn.Parameter] = None + ) -> Tensor: + """Forward with architecture parameters. + + Args: + x (Any): x could be a Torch.tensor or a tuple of + Torch.tensor, containing input data for forward computation. + arch_param (str, optional): architecture parameters for + `DiffMutableModule`. + + Returns: + Tensor: the result of forward with ``arch_param``. + """ + if arch_param is None: + return self.forward_all(x) + else: + # compute the probs of choice + probs = self.compute_arch_probs(arch_param=arch_param) + + if not self.is_fixed: + self.arch_weights = self.sample_weights(arch_param, probs) + sorted_param = torch.topk(probs, 2) + index = ( + sorted_param[0][0] - sorted_param[0][1] >= + self.fix_threshold) + if index: + self.fix_chosen(self.choices[index]) + + if self.is_fixed: + index = self.choices.index(self._chosen[0]) + self.arch_weights.data.zero_() + self.arch_weights.data[index].fill_(1.0) + self.arch_weights.requires_grad_() + + # forward based on self.arch_weights + outputs = list() + for prob, module in zip(self.arch_weights, + self._candidates.values()): + if prob > 0.: + outputs.append(prob * module(x)) + + return sum(outputs) + + @MODELS.register_module() class DiffChoiceRoute(DiffMutableModule[str, List[str]]): """A type of ``MUTABLES`` for Neural Architecture Search, which can select diff --git a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py index e50314d5f..ac6358049 100644 --- a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py +++ b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py @@ -88,8 +88,8 @@ def sample_choices(self): choices = dict() for group_id, mutables in self.search_groups.items(): - arch_parm = self.arch_params[str(group_id)] - choice = mutables[0].sample_choice(arch_parm) + arch_param = self.arch_params[str(group_id)] + choice = mutables[0].sample_choice(arch_param) choices[group_id] = choice return choices diff --git a/mmrazor/models/mutators/module_mutator/module_mutator.py b/mmrazor/models/mutators/module_mutator/module_mutator.py index 7fa9e31d5..dc045932b 100644 --- a/mmrazor/models/mutators/module_mutator/module_mutator.py +++ b/mmrazor/models/mutators/module_mutator/module_mutator.py @@ -54,6 +54,24 @@ def prepare_from_supernet(self, supernet: Module) -> None: """ self._build_search_groups(supernet) + @property + def name2mutable(self) -> Dict[str, MUTABLE_TYPE]: + """Search space of supernet. + + Note: + To get the mapping: module name to mutable. + + Raises: + RuntimeError: Called before search space has been parsed. + + Returns: + Dict[str, MUTABLE_TYPE]: The name2mutable dict. 
+ """ + if self._name2mutable is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access name2mutable!') + return self._name2mutable + @property def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]: """Search group of supernet. @@ -80,6 +98,8 @@ def _build_name_mutable_mapping( for name, module in supernet.named_modules(): if isinstance(module, self.mutable_class_type): name2mutable[name] = module + self._name2mutable = name2mutable + return name2mutable def _build_alias_names_mapping(self, @@ -121,7 +141,7 @@ def _build_search_groups(self, supernet: Module) -> None: >>> import torch >>> from mmrazor.models.mutables.diff_mutable import DiffMutableOP - >>> # Assume that a toy model consists of three mutabels + >>> # Assume that a toy model consists of three mutables >>> # whose name are op1,op2,op3. The corresponding >>> # alias names of the three mutables are a1, a1, a2. >>> model = ToyModel() diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index c29a0b181..2b142f6ea 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -50,21 +50,22 @@ def load_fix_subnet(model: nn.Module, # In the corresponding mutable, it will check whether the `chosen` # format is correct. if isinstance(module, BaseMutable): - if getattr(module, 'alias', None): - alias = module.alias - assert alias in fix_mutable, \ - f'The alias {alias} is not in fix_modules, ' \ - 'please check your `fix_mutable`.' - chosen = fix_mutable.get(alias, None) - else: - mutable_name = name.lstrip(prefix) - if mutable_name not in fix_mutable and \ - not isinstance(module, DerivedMutable): - raise RuntimeError( - f'The module name {mutable_name} is not in ' - 'fix_mutable, please check your `fix_mutable`.') - chosen = fix_mutable.get(mutable_name, None) - module.fix_chosen(chosen) + if not module.is_fixed: + if getattr(module, 'alias', None): + alias = module.alias + assert alias in fix_mutable, \ + f'The alias {alias} is not in fix_modules, ' \ + 'please check your `fix_mutable`.' + chosen = fix_mutable.get(alias, None) + else: + mutable_name = name.lstrip(prefix) + if mutable_name not in fix_mutable and \ + not isinstance(module, DerivedMutable): + raise RuntimeError( + f'The module name {mutable_name} is not in ' + 'fix_mutable, please check your `fix_mutable`.') + chosen = fix_mutable.get(mutable_name, None) + module.fix_chosen(chosen) # convert dynamic op to static op _dynamic_to_static(model) @@ -89,7 +90,6 @@ def export_fix_subnet(model: nn.Module, if isinstance(module, DerivedMutable) and not dump_derived_mutable: continue - assert not module.is_fixed if module.alias: fix_subnet[module.alias] = module.dump_chosen() else: diff --git a/tests/test_models/test_algorithms/test_dsnas.py b/tests/test_models/test_algorithms/test_dsnas.py new file mode 100644 index 000000000..929840148 --- /dev/null +++ b/tests/test_models/test_algorithms/test_dsnas.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +from unittest import TestCase +from unittest.mock import patch + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from mmcls.structures import ClsDataSample +from mmengine.model import BaseModel +from mmengine.optim import build_optim_wrapper +from mmengine.optim.optimizer import OptimWrapper, OptimWrapperDict +from torch import Tensor +from torch.optim import SGD + +from mmrazor.models import DiffModuleMutator, Dsnas, OneHotMutableOP +from mmrazor.models.algorithms.nas.dsnas import DsnasDDP +from mmrazor.registry import MODELS + +MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) +MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) +MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) + + +@MODELS.register_module() +class ToyDiffModule(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor, init_cfg=None) + self.candidates = dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ) + module_kwargs = dict(in_channels=3, out_channels=8, stride=1) + + self.mutable = OneHotMutableOP( + candidates=self.candidates, module_kwargs=module_kwargs) + self.bn = nn.BatchNorm2d(8) + + def forward(self, batch_inputs, data_samples=None, mode='tensor'): + if mode == 'loss': + out = self.bn(self.mutable(batch_inputs)) + return dict(loss=out) + elif mode == 'predict': + out = self.bn(self.mutable(batch_inputs)) + 1 + return out + elif mode == 'tensor': + out = self.bn(self.mutable(batch_inputs)) + 2 + return out + + +class TestDsnas(TestCase): + + def setUp(self) -> None: + self.device: str = 'cpu' + + OPTIMIZER_CFG = dict( + type='SGD', + lr=0.5, + momentum=0.9, + nesterov=True, + weight_decay=0.0001) + + self.OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG) + + def test_init(self) -> None: + # initiate dsnas when `norm_training` is True. 
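        # With `norm_training=True`, calling `algo.eval()` should keep
        # BatchNorm layers in training mode (running stats stay unfrozen),
        # which is what the `model.bn.training` assertion below checks.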
+ model = ToyDiffModule() + mutator = DiffModuleMutator() + algo = Dsnas(architecture=model, mutator=mutator, norm_training=True) + algo.eval() + self.assertTrue(model.bn.training) + + # initiate Dsnas with built mutator + model = ToyDiffModule() + mutator = DiffModuleMutator() + algo = Dsnas(model, mutator) + self.assertIs(algo.mutator, mutator) + + # initiate Dsnas with unbuilt mutator + mutator = dict(type='DiffModuleMutator') + algo = Dsnas(model, mutator) + self.assertIsInstance(algo.mutator, DiffModuleMutator) + + # initiate Dsnas when `fix_subnet` is not None + fix_subnet = {'mutable': 'torch_conv2d_5x5'} + algo = Dsnas(model, mutator, fix_subnet=fix_subnet) + self.assertEqual(algo.architecture.mutable.num_choices, 1) + + # initiate Dsnas with error type `mutator` + with self.assertRaisesRegex(TypeError, 'mutator should be'): + Dsnas(model, model) + + def test_forward_loss(self) -> None: + inputs = torch.randn(1, 3, 8, 8) + model = ToyDiffModule() + + # supernet + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + algo = Dsnas(model, mutator) + loss = algo(inputs, mode='loss') + self.assertIsInstance(loss, dict) + + # subnet + fix_subnet = {'mutable': 'torch_conv2d_5x5'} + algo = Dsnas(model, fix_subnet=fix_subnet) + loss = algo(inputs, mode='loss') + self.assertIsInstance(loss, dict) + + def _prepare_fake_data(self): + imgs = torch.randn(16, 3, 224, 224).to(self.device) + data_samples = [ + ClsDataSample().set_gt_label(torch.randint(0, 1000, + (16, ))).to(self.device) + ] + return {'inputs': imgs, 'data_samples': data_samples} + + def test_search_subnet(self) -> None: + model = ToyDiffModule() + + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + algo = Dsnas(model, mutator) + subnet = algo.search_subnet() + self.assertIsInstance(subnet, dict) + + @patch('mmengine.logging.message_hub.MessageHub.get_info') + def test_dsnas_train_step(self, mock_get_info) -> None: + model = ToyDiffModule() + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + mock_get_info.return_value = 2 + + algo = Dsnas(model, mutator) + data = self._prepare_fake_data() + optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG) + loss = algo.train_step(data, optim_wrapper) + + self.assertTrue(isinstance(loss['loss'], Tensor)) + + algo = Dsnas(model, mutator) + optim_wrapper_dict = OptimWrapperDict( + architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)), + mutator=OptimWrapper(SGD(model.parameters(), lr=0.01))) + loss = algo.train_step(data, optim_wrapper_dict) + + self.assertIsNotNone(loss) + + +class TestDsnasDDP(TestDsnas): + + @classmethod + def setUpClass(cls) -> None: + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12345' + + # initialize the process group + if torch.cuda.is_available(): + backend = 'nccl' + cls.device = 'cuda' + else: + backend = 'gloo' + dist.init_process_group(backend, rank=0, world_size=1) + + def prepare_model(self, device_ids=None) -> Dsnas: + model = ToyDiffModule().to(self.device) + mutator = DiffModuleMutator().to(self.device) + mutator.prepare_from_supernet(model) + + algo = Dsnas(model, mutator) + + return DsnasDDP( + module=algo, find_unused_parameters=True, device_ids=device_ids) + + @classmethod + def tearDownClass(cls) -> None: + dist.destroy_process_group() + + @pytest.mark.skipif( + not torch.cuda.is_available(), reason='cuda device is not avaliable') + def test_init(self) -> None: + ddp_model = self.prepare_model() + self.assertIsInstance(ddp_model, DsnasDDP) + + 
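    # `DsnasDDP.train_step` mirrors `Dsnas.train_step` but reaches the wrapped
    # algorithm through `self.module` and broadcasts the mutator's arch_params
    # from rank 0 once the finetune stage starts, so every rank fixes the same
    # subnet. The test below only exercises the single-process (world_size=1)
    # path set up in `setUpClass`.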
@patch('mmengine.logging.message_hub.MessageHub.get_info') + def test_dsnasddp_train_step(self, mock_get_info) -> None: + model = ToyDiffModule() + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + mock_get_info.return_value = 2 + + algo = Dsnas(model, mutator) + ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) + data = self._prepare_fake_data() + optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG) + loss = ddp_model.train_step(data, optim_wrapper) + + self.assertIsNotNone(loss) + + algo = Dsnas(model, mutator) + ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) + optim_wrapper_dict = OptimWrapperDict( + architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)), + mutator=OptimWrapper(SGD(model.parameters(), lr=0.01))) + loss = ddp_model.train_step(data, optim_wrapper_dict) + + self.assertIsNotNone(loss) diff --git a/tests/test_models/test_mutables/test_onehotop.py b/tests/test_models/test_mutables/test_onehotop.py new file mode 100644 index 000000000..4ace5870d --- /dev/null +++ b/tests/test_models/test_mutables/test_onehotop.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch +import torch.nn as nn + +from mmrazor.models import * # noqa:F403,F401 +from mmrazor.registry import MODELS + +MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) +MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) +MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) + + +class TestOneHotOP(TestCase): + + def test_forward_arch_param(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates']))) + output = op.forward_arch_param(input, arch_param=arch_param) + assert output is not None + + output = op.forward_arch_param(input, arch_param=None) + assert output is not None + + # test when some element of arch_param is 0 + arch_param = nn.Parameter(torch.ones(op.num_choices)) + output = op.forward_arch_param(input, arch_param=arch_param) + assert output is not None + + def test_forward_fixed(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + op.fix_chosen('torch_conv2d_7x7') + output = op.forward_fixed(input) + + assert output is not None + assert op.is_fixed is True + + def test_forward(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + 
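        # `MODELS.build(op_cfg)` below instantiates one module per candidate
        # via `_build_ops`; `set_forward_args(arch_param=...)` then lets a
        # plain `forward(input)` dispatch to `forward_arch_param`.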
+ op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + # test set_forward_args + arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates']))) + op.set_forward_args(arch_param=arch_param) + output = op.forward(input) + assert output is not None + + # test dump_chosen + with pytest.raises(AssertionError): + op.dump_chosen() + + # test forward when is_fixed is True + op.fix_chosen('torch_conv2d_7x7') + output = op.forward(input) + + def test_property(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + op = MODELS.build(op_cfg) + + assert len(op.choices) == 3 + + # test is_fixed propty + assert op.is_fixed is False + + # test is_fixed setting + op.fix_chosen('torch_conv2d_5x5') + + with pytest.raises(AttributeError): + op.is_fixed = True + + # test fix choice when is_fixed is True + with pytest.raises(AttributeError): + op.fix_chosen('torch_conv2d_3x3') + + def test_module_kwargs(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + in_channels=32, + out_channels=32, + stride=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + in_channels=32, + out_channels=32, + stride=1, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + in_channels=32, + out_channels=32, + stride=1, + ), + torch_maxpool_3x3=dict( + type='torchMaxPool2d', + kernel_size=3, + stride=1, + ), + torch_avgpool_3x3=dict( + type='torchAvgPool2d', + kernel_size=3, + stride=1, + ), + ), + ) + op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + op.fix_chosen('torch_avgpool_3x3') + output = op.forward(input) + assert output is not None
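# A minimal illustrative sketch, assuming a single search group with three
# conv candidates (not taken from the repo): it reproduces the core
# quantities used above, namely sampling one-hot arch weights from
# softmax(arch_param), mixing the candidate outputs, accumulating the
# `log((arch_weights * probs).sum(-1))` surrogate loss, and rescaling the
# arch_param gradient as `handle_grads` does.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.one_hot_categorical import OneHotCategorical

candidates = nn.ModuleList(
    [nn.Conv2d(3, 8, kernel_size=k, padding=k // 2) for k in (3, 5, 7)])
arch_param = nn.Parameter(torch.zeros(len(candidates)))

x = torch.randn(2, 3, 16, 16)
probs = F.softmax(arch_param, dim=-1)
arch_weights = OneHotCategorical(probs=probs).sample()  # one-hot, e.g. [0, 1, 0]
arch_weights.requires_grad_()

# Only the sampled candidate contributes to the mixed output.
output = sum(w * op(x) for w, op in zip(arch_weights, candidates) if w > 0)
arch_loss = torch.log((arch_weights * probs).sum(-1)).sum()
(output.mean() + arch_loss).backward()

# Rescale the arch_param gradient by the summed arch_weights gradient,
# mirroring `handle_grads` above.
arch_param.grad.data.mul_(arch_weights.grad.data.sum())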