From fb42405af87e0da664ddea2db3cbd6776cd4c9d0 Mon Sep 17 00:00:00 2001 From: Yue Sun Date: Mon, 14 Nov 2022 13:01:04 +0800 Subject: [PATCH] [Feature] Add Autoformer algorithm (#315) * update candidates * update subnet_sampler_loop * update candidate * add readme * rename variable * rename variable * clean * update * add doc string * Revert "[Improvement] Support for candidate multiple dimensional search constraints." * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * [Feature] Autoformer architecture and dynamicOPs (#327) * add DynamicSequential * dynamiclayernorm * add dynamic_pathchembed * add DynamicMultiheadAttention and DynamicRelativePosition2D * add channel-level dynamicOP * add autoformer algo * clean notes * adapt channel_mutator * vit fly * fix import * mutable init * remove annotation * add DynamicInputResizer * add unittest for mutables * add OneShotMutableChannelUnit_VIT * clean code * reset unit for vit * remove attr * add autoformer backbone UT * add valuemutator UT * clean code * add autoformer algo UT * update classifier UT * fix test error * ignore * make lint * update * fix lint * mutable_attrs * fix test * fix error * remove DynamicInputResizer * fix test ci * remove InputResizer * rename variables * modify type * Continued improvements of ChannelUnit * fix lint * fix lint * remove OneShotMutableChannelUnit * adjust derived type * combination mixins * clean code * fix sample subnet * search loop fly * more annotations * avoid counter warning and modify batch_augment cfg by gy * restore * source_value_mutables restriction * simply arch_setting api * update * clean * fix ut --- .../_base_/settings/imagenet_bs2048_AdamW.py | 180 +++++++++ configs/nas/mmcls/autoformer/README.md | 66 ++++ .../autoformer_search_8xb128_in1k.py | 17 + .../autoformer_supernet_32xb256_in1k.py | 79 ++++ .../spos/spos_mobilenet_search_8xb128_in1k.py | 2 +- .../spos_shufflenet_search_8xb128_in1k.py | 2 +- .../detnas_frcnn_shufflenet_search_coco_1x.py | 2 +- .../engine/runner/evolution_search_loop.py | 153 +++++-- mmrazor/engine/runner/subnet_sampler_loop.py | 99 +++-- mmrazor/engine/runner/utils/__init__.py | 4 +- mmrazor/engine/runner/utils/check.py | 37 +- mmrazor/models/algorithms/__init__.py | 4 +- mmrazor/models/algorithms/nas/__init__.py | 4 +- mmrazor/models/algorithms/nas/autoformer.py | 115 ++++++ mmrazor/models/architectures/__init__.py | 1 + .../architectures/backbones/__init__.py | 3 +- .../backbones/searchable_autoformer.py | 374 ++++++++++++++++++ .../architectures/classifiers/__init__.py | 4 + .../models/architectures/classifiers/image.py | 53 +++ .../architectures/dynamic_ops/__init__.py | 17 +- .../dynamic_ops/bricks/__init__.py | 17 + .../dynamic_ops/bricks/dynamic_container.py | 109 +++++ .../dynamic_ops/bricks/dynamic_embed.py | 142 +++++++ .../bricks/dynamic_multi_head_attention.py | 280 +++++++++++++ .../dynamic_ops/bricks/dynamic_norm.py | 64 ++- .../bricks/dynamic_relative_position.py | 154 ++++++++ .../dynamic_ops/head/__init__.py | 4 + .../dynamic_ops/head/dynamic_linear_head.py | 80 ++++ .../dynamic_ops/mixins/__init__.py | 9 +- .../mixins/dynamic_layernorm_mixins.py | 147 +++++++ .../dynamic_ops/mixins/dynamic_mixins.py | 9 + 
mmrazor/models/architectures/ops/__init__.py | 4 +- .../architectures/ops/transformer_series.py | 192 +++++++++ mmrazor/models/mutables/derived_mutable.py | 62 ++- .../mutable_channel_container.py | 4 +- .../sequential_mutable_channel.py | 14 +- .../mutable_channel/units/channel_unit.py | 6 +- .../units/mutable_channel_unit.py | 60 ++- .../units/one_shot_mutable_channel_unit.py | 6 +- .../mutables/mutable_value/mutable_value.py | 9 +- mmrazor/models/mutators/__init__.py | 4 +- .../channel_mutator/channel_mutator.py | 3 + mmrazor/models/mutators/group_mixin.py | 53 ++- .../models/mutators/value_mutator/__init__.py | 5 + .../value_mutator/dynamic_value_mutator.py | 13 + .../mutators/value_mutator/value_mutator.py | 73 ++++ .../counters/flops_params_counter.py | 14 +- mmrazor/structures/subnet/candidate.py | 137 +++++-- mmrazor/structures/subnet/fix_subnet.py | 14 +- tests/data/models.py | 96 ++++- .../test_algorithms/test_autoformer.py | 116 ++++++ .../test_backbones/test_autoformerbackbone.py | 60 +++ .../test_bricks/test_dynamic_attention.py | 50 +++ .../test_bricks/test_dynamic_container.py | 46 +++ .../test_bricks/test_dynamic_embed.py | 53 +++ .../test_bricks/test_dynamic_layernorm.py | 45 +++ .../test_dynamic_relative_position.py | 55 +++ .../test_classifier/test_imageclassifier.py | 45 +++ .../test_mutables/test_derived_mutable.py | 84 ++++ .../test_sequential_mutable_channel.py | 16 +- .../test_one_shot_mutable_channel_unit.py | 19 + .../test_mutables/test_mutable_value.py | 6 - .../test_mutators/test_channel_mutator.py | 26 +- .../test_mutators/test_value_mutator.py | 33 ++ .../test_models/test_subnet/test_candidate.py | 107 ++++- .../test_evolution_search_loop.py | 53 ++- .../test_runners/test_subnet_sampler_loop.py | 6 +- tests/test_runners/test_utils/test_check.py | 38 +- 68 files changed, 3598 insertions(+), 260 deletions(-) create mode 100644 configs/_base_/settings/imagenet_bs2048_AdamW.py create mode 100644 configs/nas/mmcls/autoformer/README.md create mode 100644 configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py create mode 100644 configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py create mode 100644 mmrazor/models/algorithms/nas/autoformer.py create mode 100644 mmrazor/models/architectures/backbones/searchable_autoformer.py create mode 100644 mmrazor/models/architectures/classifiers/__init__.py create mode 100644 mmrazor/models/architectures/classifiers/image.py create mode 100644 mmrazor/models/architectures/dynamic_ops/bricks/dynamic_container.py create mode 100644 mmrazor/models/architectures/dynamic_ops/bricks/dynamic_embed.py create mode 100644 mmrazor/models/architectures/dynamic_ops/bricks/dynamic_multi_head_attention.py create mode 100644 mmrazor/models/architectures/dynamic_ops/bricks/dynamic_relative_position.py create mode 100644 mmrazor/models/architectures/dynamic_ops/head/__init__.py create mode 100644 mmrazor/models/architectures/dynamic_ops/head/dynamic_linear_head.py create mode 100644 mmrazor/models/architectures/dynamic_ops/mixins/dynamic_layernorm_mixins.py create mode 100644 mmrazor/models/architectures/ops/transformer_series.py create mode 100644 mmrazor/models/mutators/value_mutator/__init__.py create mode 100644 mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py create mode 100644 mmrazor/models/mutators/value_mutator/value_mutator.py create mode 100644 tests/test_models/test_algorithms/test_autoformer.py create mode 100644 tests/test_models/test_architectures/test_backbones/test_autoformerbackbone.py create mode 
100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_attention.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_container.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_embed.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_layernorm.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_relative_position.py create mode 100644 tests/test_models/test_classifier/test_imageclassifier.py create mode 100644 tests/test_models/test_mutators/test_value_mutator.py diff --git a/configs/_base_/settings/imagenet_bs2048_AdamW.py b/configs/_base_/settings/imagenet_bs2048_AdamW.py new file mode 100644 index 000000000..7b7b29097 --- /dev/null +++ b/configs/_base_/settings/imagenet_bs2048_AdamW.py @@ -0,0 +1,180 @@ +# dataset settings +dataset_type = 'mmcls.ImageNet' +preprocess_cfg = dict( + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = preprocess_cfg['mean'][::-1] +bgr_std = preprocess_cfg['std'][::-1] + +# Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models +rand_increasing_policies = [ + dict(type='mmcls.AutoContrast'), + dict(type='mmcls.Equalize'), + dict(type='mmcls.Invert'), + dict(type='mmcls.Rotate', magnitude_key='angle', magnitude_range=(0, 30)), + dict(type='mmcls.Posterize', magnitude_key='bits', magnitude_range=(4, 0)), + dict(type='mmcls.Solarize', magnitude_key='thr', magnitude_range=(256, 0)), + dict( + type='mmcls.SolarizeAdd', + magnitude_key='magnitude', + magnitude_range=(0, 110)), + dict( + type='mmcls.ColorTransform', + magnitude_key='magnitude', + magnitude_range=(0, 0.9)), + dict( + type='mmcls.Contrast', + magnitude_key='magnitude', + magnitude_range=(0, 0.9)), + dict( + type='mmcls.Brightness', + magnitude_key='magnitude', + magnitude_range=(0, 0.9)), + dict( + type='mmcls.Sharpness', + magnitude_key='magnitude', + magnitude_range=(0, 0.9)), + dict( + type='mmcls.Shear', + magnitude_key='magnitude', + magnitude_range=(0, 0.3), + direction='horizontal'), + dict( + type='mmcls.Shear', + magnitude_key='magnitude', + magnitude_range=(0, 0.3), + direction='vertical'), + dict( + type='mmcls.Translate', + magnitude_key='magnitude', + magnitude_range=(0, 0.45), + direction='horizontal'), + dict( + type='mmcls.Translate', + magnitude_key='magnitude', + magnitude_range=(0, 0.45), + direction='vertical') +] + +train_pipeline = [ + dict(type='mmcls.LoadImageFromFile'), + dict( + type='mmcls.RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='mmcls.RandAugment', + policies=rand_increasing_policies, + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type='mmcls.RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type='mmcls.PackClsInputs'), +] + +test_pipeline = [ + dict(type='mmcls.LoadImageFromFile'), + dict( + type='mmcls.ResizeEdge', + scale=248, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='mmcls.CenterCrop', crop_size=224), + dict(type='mmcls.PackClsInputs') +] + 
+train_dataloader = dict( + batch_size=64, + num_workers=6, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type='mmcls.RepeatAugSampler'), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=256, + num_workers=6, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator + +# optimizer +paramwise_cfg = dict( + bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0) + +optim_wrapper = dict( + optimizer=dict( + type='AdamW', + lr=0.002, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + # specific to vit pretrain + paramwise_cfg=dict(custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=1e-3, + by_epoch=True, + begin=0, + # about 10000 iterations for ImageNet-1k + end=20, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type='CosineAnnealingLR', + T_max=500, + eta_min=1e-5, + by_epoch=True, + begin=20, + end=500, + convert_to_iter_based=True), +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=500) +val_cfg = dict() +test_cfg = dict() + +auto_scale_lr = dict(base_batch_size=2048) diff --git a/configs/nas/mmcls/autoformer/README.md b/configs/nas/mmcls/autoformer/README.md new file mode 100644 index 000000000..cab373e1d --- /dev/null +++ b/configs/nas/mmcls/autoformer/README.md @@ -0,0 +1,66 @@ +# AutoFormer + +> [Searching Transformers for Visual Recognition](https://arxiv.org/abs/2107.00651) + + + +## Abstract + +Recently, pure transformer-based models have shown great potentials for vision tasks such as image classification and detection. However, the design of transformer networks is challenging. It has been observed that the depth, embedding dimension, and number of heads can largely affect the performance of vision transformers. Previous models configure these dimensions based upon manual crafting. In this work, we propose a new one-shot architecture search framework, namely AutoFormer, dedicated to vision transformer search. AutoFormer entangles the weights of different blocks in the same layers during supernet training. Benefiting from the strategy, the trained supernet allows thousands of subnets to be very well-trained. Specifically, the performance of these subnets with weights inherited from the supernet is comparable to those retrained from scratch. Besides, the searched models, which we refer to AutoFormers, surpass the recent state-of-the-arts such as ViT and DeiT. In particular, AutoFormer-tiny/small/base achieve 74.7%/81.7%/82.4% top-1 accuracy on ImageNet with 5.7M/22.9M/53.7M parameters, respectively. Lastly, we verify the transferability of AutoFormer by providing the performance on downstream benchmarks and distillation experiments.
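The weight entanglement mentioned in the abstract is the key mechanism: subnets of different widths do not own separate weights, they all slice into the weight tensors of the widest operator. Below is a minimal, self-contained sketch of the idea (our illustration only, not the `DynamicLinear` implementation added by this PR):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SliceableLinear(nn.Linear):
    """Toy dynamic linear layer: subnets of every width share (entangle)
    the weights of the widest layer by slicing them."""

    def __init__(self, max_in: int, max_out: int) -> None:
        super().__init__(max_in, max_out)
        self.active_in, self.active_out = max_in, max_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every sampled subnet reuses the leading rows/columns of the
        # shared weight matrix, so training one subnet also updates the
        # weights of every larger subnet.
        weight = self.weight[:self.active_out, :self.active_in]
        bias = self.bias[:self.active_out]
        return F.linear(x[..., :self.active_in], weight, bias)


layer = SliceableLinear(max_in=624, max_out=624)
layer.active_in, layer.active_out = 528, 576  # sample a narrower subnet
out = layer(torch.randn(2, 197, 624))         # -> torch.Size([2, 197, 576])
```

Because every subnet trains a sub-matrix of the same shared tensors, weights inherited from the supernet are already well-trained, which is what lets the evolutionary search below evaluate thousands of subnets without retraining any of them.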
+ +![pipeline](/docs/en/imgs/model_zoo/autoformer/pipeline.png) + +## Introduction + +### Supernet pre-training on ImageNet + +```bash +python ./tools/train.py \ + configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py \ + --work-dir $WORK_DIR +``` + +### Search for subnet on the trained supernet + +```bash +python ./tools/train.py \ + configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py \ + --cfg-options load_from=$STEP1_CKPT \ + --work-dir $WORK_DIR +``` + +## Results and models + +| Dataset | Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks | | :------: | :------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------: | :------: | :-------: | :-------: | :---------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------: | | ImageNet | vit | [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx) | 52.472 | 10.2 | 82.48 | 95.99 | [config](./autoformer_supernet_32xb256_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/x.pth) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/x.log.json) | MMRazor searched | + +**Note**: + +1. There are some small differences between our experiment and the original repo so that the implementation stays consistent with MMRazor. For example, we set the maximum value of embed_channels to 624 while the original repo sets it to 640. However, the original repo only searches over 528, 576 and 624 embed_channels, so setting the maximum to 624 still reproduces the result of the original paper. +2. The original paper reports 82.4 top-1 accuracy with 53.7M Params, while we get 82.48 top-1 accuracy with 52.47M Params.
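A practical note on the search step: every sampled, mutated, or crossed-over subnet is measured by the resource estimator and kept only if each entry of `constraints_range` is satisfied (a bare number is treated as an upper bound `(0, v)`). To search under a different parameter budget you can override that dict in the search config. A hypothetical override (file name and budget are illustrative; units follow the estimator's defaults):

```python
# my_autoformer_search.py -- hypothetical config, values for illustration only
_base_ = ['./autoformer_search_8xb128_in1k.py']

# Keep only candidates with at most 45M parameters instead of 55M.
train_cfg = dict(constraints_range=dict(params=(0, 45)))
```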
+ +## Citation + +```latex +@inproceedings{chen2021autoformer, + title={AutoFormer: Searching Transformers for Visual Recognition}, + author={Chen, Minghao and Peng, Houwen and Fu, Jianlong and Ling, Haibin}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + pages={12270--12280}, + year={2021} +} +``` + diff --git a/configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py b/configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py new file mode 100644 index 000000000..be1d4660d --- /dev/null +++ b/configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py @@ -0,0 +1,17 @@ +_base_ = ['./autoformer_supernet_32xb256_in1k.py'] + +custom_hooks = None + +train_cfg = dict( + _delete_=True, + type='mmrazor.EvolutionSearchLoop', + dataloader=_base_.val_dataloader, + evaluator=_base_.val_evaluator, + max_epochs=20, + num_candidates=20, + top_k=10, + num_mutation=5, + num_crossover=5, + mutate_prob=0.2, + constraints_range=dict(params=(0, 55)), + score_key='accuracy/top1') diff --git a/configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py b/configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py new file mode 100644 index 000000000..24639a545 --- /dev/null +++ b/configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py @@ -0,0 +1,79 @@ +_base_ = [ + 'mmrazor::_base_/settings/imagenet_bs2048_AdamW.py', + 'mmcls::_base_/default_runtime.py', +] + +# data preprocessor +data_preprocessor = dict( + _scope_='mmcls', + type='ClsDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, + num_classes=1000, + batch_augments=dict( + augments=[ + dict(type='Mixup', alpha=0.2), + dict(type='CutMix', alpha=1.0) + ], + probs=[0.5, 0.5])) + +arch_setting = dict( + mlp_ratios=[3.0, 3.5, 4.0], + num_heads=[8, 9, 10], + depth=[14, 15, 16], + embed_dims=[528, 576, 624]) + +supernet = dict( + _scope_='mmrazor', + type='SearchableImageClassifier', + data_preprocessor=data_preprocessor, + backbone=dict( + _scope_='mmrazor', + type='AutoformerBackbone', + arch_setting=arch_setting), + neck=None, + head=dict( + type='DynamicLinearClsHead', + num_classes=1000, + in_channels=624, + loss=dict( + type='mmcls.LabelSmoothLoss', + mode='original', + num_classes=1000, + label_smooth_val=0.1, + loss_weight=1.0), + topk=(1, 5)), + connect_head=dict(connect_with_backbone='backbone.last_mutable'), +) + +model = dict( + type='mmrazor.Autoformer', + architecture=supernet, + fix_subnet=None, + mutators=dict( + channel_mutator=dict( + type='mmrazor.OneShotChannelMutator', + channel_unit_cfg={ + 'type': 'OneShotMutableChannelUnit', + 'default_args': { + 'unit_predefined': True + } + }, + parse_cfg={'type': 'Predefined'}), + value_mutator=dict(type='mmrazor.DynamicValueMutator'))) + +# runtime setting +custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')] + +# checkpoint saving +_base_.default_hooks.checkpoint = dict( + type='CheckpointHook', + interval=2, + by_epoch=True, + save_best='accuracy/top1', + max_keep_ckpts=3) + +find_unused_parameters = True diff --git a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py index 4f5edb316..87553ec39 100644 --- a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, -
flops_range=(0., 465.), + constraints_range=dict(flops=(0., 465.)), score_key='accuracy/top1') diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py index f3f963e40..f5a5e88f4 100644 --- a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 330.), + constraints_range=dict(flops=(0, 330)), score_key='accuracy/top1') diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py index d1dd1637a..689618362 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py @@ -13,5 +13,5 @@ num_mutation=20, num_crossover=20, mutate_prob=0.1, - flops_range=(0., 300.), + constraints_range=dict(flops=(0, 330)), score_key='coco/bbox_mAP') diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index a9a76b383..d85c0ed30 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os import os.path as osp import random import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from mmengine import fileio @@ -14,10 +15,10 @@ from torch.utils.data import DataLoader from mmrazor.models.task_modules import ResourceEstimator -from mmrazor.registry import LOOPS +from mmrazor.registry import LOOPS, TASK_UTILS from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet -from .utils import check_subnet_flops, crossover +from .utils import check_subnet_resources, crossover @LOOPS.register_module() @@ -41,10 +42,11 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): num_crossover (int): The number of candidates got by crossover. Defaults to 25. mutate_prob (float): The probability of mutation. Defaults to 0.1. - flops_range (tuple, optional): It is used for screening candidates. - resource_estimator_cfg (dict): The config for building estimator, which - is be used to estimate the flops of sampled subnet. Defaults to - None, which means default config is used. + crossover_prob (float): The probability of crossover. Defaults to 0.5. + constraints_range (Dict[str, Any]): Constraints to be used for + screening candidates. Defaults to dict(flops=(0, 330)). + resource_estimator_cfg (dict, Optional): Used for building a + resource estimator. Defaults to None. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. 
init_candidates (str, optional): The candidates file path, which is @@ -64,8 +66,9 @@ def __init__(self, num_mutation: int = 25, num_crossover: int = 25, mutate_prob: float = 0.1, - flops_range: Optional[Tuple[float, float]] = (0., 330.), - resource_estimator_cfg: Optional[dict] = None, + crossover_prob: float = 0.5, + constraints_range: Dict[str, Any] = dict(flops=(0., 330.)), + resource_estimator_cfg: Optional[Dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -83,11 +86,12 @@ def __init__(self, self.num_candidates = num_candidates self.top_k = top_k - self.flops_range = flops_range + self.constraints_range = constraints_range self.score_key = score_key self.num_mutation = num_mutation self.num_crossover = num_crossover self.mutate_prob = mutate_prob + self.crossover_prob = crossover_prob self.max_keep_ckpts = max_keep_ckpts self.resume_from = resume_from @@ -99,16 +103,58 @@ def __init__(self, correct init candidates file' self.top_k_candidates = Candidates() - if resource_estimator_cfg is None: - self.estimator = ResourceEstimator() - else: - self.estimator = ResourceEstimator(**resource_estimator_cfg) if self.runner.distributed: self.model = runner.model.module else: self.model = runner.model + # Build resource estimator. + resource_estimator_cfg = dict( + ) if resource_estimator_cfg is None else resource_estimator_cfg + self.estimator = self.build_resource_estimator(resource_estimator_cfg) + + def build_resource_estimator( + self, resource_estimator: Union[ResourceEstimator, + Dict]) -> ResourceEstimator: + """Build resource estimator for search loop. + + Examples of ``resource_estimator``: + + # `ResourceEstimator` will be used + resource_estimator = dict() + + # custom resource_estimator + resource_estimator = dict(type='mmrazor.ResourceEstimator') + + Args: + resource_estimator (ResourceEstimator or dict): A + resource_estimator or a dict to build resource estimator. + If ``resource_estimator`` is a resource estimator object, + just returns itself. + + Returns: + :obj:`ResourceEstimator`: Resource estimator object built from + ``resource_estimator``.
+ """ + if isinstance(resource_estimator, ResourceEstimator): + return resource_estimator + elif not isinstance(resource_estimator, dict): + raise TypeError( + 'resource estimator should be a ResourceEstimator object or' + f'dict, but got {resource_estimator}') + + resource_estimator_cfg = copy.deepcopy( + resource_estimator) # type: ignore + + if 'type' in resource_estimator_cfg: + estimator = TASK_UTILS.build(resource_estimator_cfg) + else: + estimator = ResourceEstimator( + **resource_estimator_cfg) # type: ignore + + return estimator # type: ignore + def run(self) -> None: """Launch searching.""" self.runner.call_hook('before_train') @@ -144,31 +190,48 @@ def run_epoch(self) -> None: f'{scores_before}') self.candidates.extend(self.top_k_candidates) - self.candidates.sort(key=lambda x: x[1], reverse=True) - self.top_k_candidates = Candidates(self.candidates[:self.top_k]) + self.candidates.sort_by(key_indicator='score', reverse=True) + self.top_k_candidates = Candidates(self.candidates.data[:self.top_k]) scores_after = self.top_k_candidates.scores self.runner.logger.info(f'top k scores after update: ' f'{scores_after}') mutation_candidates = self.gen_mutation_candidates() + self.candidates_mutator_crossover = Candidates(mutation_candidates) crossover_candidates = self.gen_crossover_candidates() - candidates = mutation_candidates + crossover_candidates - assert len(candidates) <= self.num_candidates, 'Total of mutation and \ - crossover should be no more than the number of candidates.' + self.candidates_mutator_crossover.extend(crossover_candidates) + + assert len(self.candidates_mutator_crossover + ) <= self.num_candidates, 'Total of mutation and \ + crossover should be less than the number of candidates.' - self.candidates = Candidates(candidates) + self.candidates = self.candidates_mutator_crossover self._epoch += 1 def sample_candidates(self) -> None: """Update candidate pool contains specified number of candicates.""" + candidates_resources = [] + init_candidates = len(self.candidates) if self.runner.rank == 0: while len(self.candidates) < self.num_candidates: candidate = self.model.sample_subnet() - if self._check_constraints(random_subnet=candidate): + is_pass, result = self._check_constraints( + random_subnet=candidate) + if is_pass: self.candidates.append(candidate) + candidates_resources.append(result) + self.candidates = Candidates(self.candidates.data) else: - self.candidates = Candidates([None] * self.num_candidates) + self.candidates = Candidates([dict(a=0)] * self.num_candidates) + + if len(candidates_resources) > 0: + self.candidates.update_resources( + candidates_resources, + start=len(self.candidates.data) - len(candidates_resources)) + assert init_candidates + len( + candidates_resources) == self.num_candidates + # broadcast candidates to val with multi-GPUs. broadcast_object_list(self.candidates.data) @@ -180,14 +243,18 @@ def update_candidates_scores(self) -> None: metrics = self._val_candidate() score = metrics[self.score_key] \ if len(metrics) != 0 else 0. 
- self.candidates.set_score(i, score) + self.candidates.set_resource(i, score, 'score') self.runner.logger.info( f'Epoch:[{self._epoch}/{self._max_epochs}] ' f'Candidate:[{i + 1}/{self.num_candidates}] ' - f'Score:{score}') + f'Flops: {self.candidates.resources("flops")[i]} ' + f'Params: {self.candidates.resources("params")[i]} ' + f'Latency: {self.candidates.resources("latency")[i]} ' + f'Score: {self.candidates.scores} ') - def gen_mutation_candidates(self) -> List: + def gen_mutation_candidates(self): """Generate specified number of mutation candidates.""" + mutation_resources = [] mutation_candidates: List = [] max_mutate_iters = self.num_mutation * 10 mutate_iter = 0 @@ -198,12 +265,20 @@ def run_epoch(self) -> None: mutation_candidate = self._mutation() - if self._check_constraints(random_subnet=mutation_candidate): + is_pass, result = self._check_constraints( + random_subnet=mutation_candidate) + if is_pass: mutation_candidates.append(mutation_candidate) + mutation_resources.append(result) + + mutation_candidates = Candidates(mutation_candidates) + mutation_candidates.update_resources(mutation_resources) + return mutation_candidates - def gen_crossover_candidates(self) -> List: + def gen_crossover_candidates(self): """Generate specified number of crossover candidates.""" + crossover_resources = [] crossover_candidates: List = [] crossover_iter = 0 max_crossover_iters = self.num_crossover * 10 @@ -214,8 +289,15 @@ def run_epoch(self) -> None: crossover_candidate = self._crossover() - if self._check_constraints(random_subnet=crossover_candidate): + is_pass, result = self._check_constraints( + random_subnet=crossover_candidate) + if is_pass: crossover_candidates.append(crossover_candidate) + crossover_resources.append(result) + + crossover_candidates = Candidates(crossover_candidates) + crossover_candidates.update_resources(crossover_resources) + return crossover_candidates def _mutation(self) -> SupportRandomSubnet: @@ -229,7 +311,7 @@ def _crossover(self) -> SupportRandomSubnet: """Crossover.""" candidate1 = random.choice(self.top_k_candidates.subnets) candidate2 = random.choice(self.top_k_candidates.subnets) - candidate = crossover(candidate1, candidate2) + candidate = crossover(candidate1, candidate2, prob=self.crossover_prob) return candidate def _resume(self): @@ -263,7 +345,7 @@ def _val_candidate(self) -> Dict: self.runner.model.eval() for data_batch in self.dataloader: outputs = self.runner.model.val_step(data_batch) - self.evaluator.process(data_samples=outputs, data_batch=data_batch) + self.evaluator.process(outputs, data_batch) metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) return metrics @@ -295,16 +377,17 @@ def _save_searcher_ckpt(self) -> None: if osp.isfile(ckpt_path): os.remove(ckpt_path) - def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: + def _check_constraints( + self, random_subnet: SupportRandomSubnet) -> Tuple[bool, Dict]: """Check whether the subnet is beyond the constraints. Returns: - bool: The result of checking. + bool, dict: Whether the subnet passes the constraints, and the estimated resources.
""" - is_pass = check_subnet_flops( + is_pass, results = check_subnet_resources( model=self.model, subnet=random_subnet, estimator=self.estimator, - flops_range=self.flops_range) + constraints_range=self.constraints_range) - return is_pass + return is_pass, results diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index 1127aab21..273561568 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import math import os import random from abc import abstractmethod -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch from mmengine import fileio @@ -13,10 +14,10 @@ from torch.utils.data import DataLoader from mmrazor.models.task_modules import ResourceEstimator -from mmrazor.registry import LOOPS +from mmrazor.registry import LOOPS, TASK_UTILS from mmrazor.structures import Candidates from mmrazor.utils import SupportRandomSubnet -from .utils import check_subnet_flops +from .utils import check_subnet_resources class BaseSamplerTrainLoop(IterBasedTrainLoop): @@ -77,18 +78,15 @@ def run_iter(self, data_batch: Sequence[dict]) -> None: @LOOPS.register_module() class GreedySamplerTrainLoop(BaseSamplerTrainLoop): """IterBasedTrainLoop for greedy sampler. - In GreedySamplerTrainLoop, `Greedy` means that only use some top sampled candidates to train the supernet. So GreedySamplerTrainLoop mainly picks the top candidates based on their val socres, then use them to train the supernet one by one. - Steps: 1. Sample from the supernet and the candidates. 2. Validate these sampled candidates to get each candidate's score. 3. Get top-k candidates based on their scores, then use them to train the supernet one by one. - Args: runner (Runner): A reference of runner. dataloader (Dataloader or dict): A dataloader object or a dict to @@ -102,10 +100,10 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): val_interval (int): Validation interval. Defaults to 1000. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. - flops_range (dict): Constraints to be used for screening candidates. - resource_estimator_cfg (dict): The config for building estimator, which - is be used to estimate the flops of sampled subnet. Defaults to - None, which means default config is used. + constraints_range (Dict[str, Any]): Constraints to be used for + screening candidates. Defaults to dict(flops=(0, 330)). + resource_estimator_cfg (dict, Optional): Used for building a + resource estimator. Defaults to None. num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. num_samples (int): The number of sample in each sampling subnet. 
@@ -139,8 +137,8 @@ def __init__(self, val_begin: int = 1, val_interval: int = 1000, score_key: str = 'accuracy/top1', - flops_range: Optional[Tuple[float, float]] = (0., 330), - resource_estimator_cfg: Optional[dict] = None, + constraints_range: Dict[str, Any] = dict(flops=(0, 330)), + resource_estimator_cfg: Optional[Dict] = None, num_candidates: int = 1000, num_samples: int = 10, top_k: int = 5, @@ -163,7 +161,7 @@ def __init__(self, self.evaluator = evaluator self.score_key = score_key - self.flops_range = flops_range + self.constraints_range = constraints_range self.num_candidates = num_candidates self.num_samples = num_samples self.top_k = top_k @@ -177,10 +175,52 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() - if resource_estimator_cfg is None: - self.estimator = ResourceEstimator() + + # Build resource estimator. + resource_estimator_cfg = dict( + ) if resource_estimator_cfg is None else resource_estimator_cfg + self.estimator = self.build_resource_estimator(resource_estimator_cfg) + + def build_resource_estimator( + self, resource_estimator: Union[ResourceEstimator, + Dict]) -> ResourceEstimator: + """Build resource estimator for search loop. + + Examples of ``resource_estimator``: + + # `ResourceEstimator` will be used + resource_estimator = dict() + + # custom resource_estimator + resource_estimator = dict(type='mmrazor.ResourceEstimator') + + Args: + resource_estimator (ResourceEstimator or dict): + A resource_estimator or a dict to build resource estimator. + If ``resource_estimator`` is a resource estimator object, + just returns itself. + + Returns: + :obj:`ResourceEstimator`: Resource estimator object built from + ``resource_estimator``. + """ + if isinstance(resource_estimator, ResourceEstimator): + return resource_estimator + elif not isinstance(resource_estimator, dict): + raise TypeError( + 'resource estimator should be a ResourceEstimator object or ' + f'dict, but got {resource_estimator}') + + resource_estimator_cfg = copy.deepcopy( + resource_estimator) # type: ignore + + if 'type' in resource_estimator_cfg: + estimator = TASK_UTILS.build(resource_estimator_cfg) else: - self.estimator = ResourceEstimator(**resource_estimator_cfg) + estimator = ResourceEstimator( + **resource_estimator_cfg) # type: ignore + + return estimator # type: ignore def run(self) -> None: """Launch training.""" @@ -230,9 +270,11 @@ def sample_subnet(self) -> SupportRandomSubnet: self.update_candidates_scores() - self.candidates.sort(key=lambda x: x[1], reverse=True) - self.candidates = Candidates(self.candidates[:self.num_candidates]) - self.top_k_candidates = Candidates(self.candidates[:self.top_k]) + self.candidates.sort_by(key_indicator='score', reverse=True) + self.candidates = Candidates( + self.candidates.data[:self.num_candidates]) + self.top_k_candidates = Candidates( + self.candidates.data[:self.top_k]) top1_score = self.top_k_candidates.scores[0] if (self._iter % self.val_interval) < self.top_k: @@ -243,7 +285,7 @@ def sample_subnet(self) -> SupportRandomSubnet: f'{num_sample_from_supernet}/{self.num_samples} ' f'top1_score {top1_score:.3f} ' f'cur_num_candidates: {len(self.candidates)}') - return self.top_k_candidates.pop(0)[0] + return self.top_k_candidates.subnets[0] def update_cur_prob(self, cur_iter: int) -> None: """Update the current probability of sampling from the candidates, which is @@ -278,7 +320,8 @@ def get_candidates_with_sample(self, for _ in range(num_samples): if random.random() >= self.cur_prob or len(self.candidates) == 0:
subnet = self._sample_from_supernet() - if self._check_constraints(subnet): + is_pass, _ = self._check_constraints(subnet) + if is_pass: sampled_candidates.append(subnet) num_sample_from_supernet += 1 else: @@ -292,7 +335,7 @@ def update_candidates_scores(self) -> None: self.model.set_subnet(candidate) metrics = self._val_candidate() score = metrics[self.score_key] if len(metrics) != 0 else 0. - self.candidates.set_score(i, score) + self.candidates.set_resource(i, score, 'score') @torch.no_grad() def _val_candidate(self) -> Dict: @@ -312,22 +355,22 @@ def _sample_from_supernet(self) -> SupportRandomSubnet: def _sample_from_candidates(self) -> SupportRandomSubnet: """Sample from the candidates.""" assert len(self.candidates) > 0 - subnet = random.choice(self.candidates) + subnet = random.choice(self.candidates.data) return subnet - def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: + def _check_constraints(self, random_subnet: SupportRandomSubnet): """Check whether the subnet is beyond the constraints. Returns: - bool: The result of checking. + bool, dict: Whether the subnet passes the constraints, and the estimated resources. """ - is_pass = check_subnet_flops( + is_pass, results = check_subnet_resources( model=self.model, subnet=random_subnet, estimator=self.estimator, - flops_range=self.flops_range) + constraints_range=self.constraints_range) - return is_pass + return is_pass, results def _save_candidates(self) -> None: """Save the candidates to init the next searching.""" diff --git a/mmrazor/engine/runner/utils/__init__.py b/mmrazor/engine/runner/utils/__init__.py index ec2f2cb29..557002e2c 100644 --- a/mmrazor/engine/runner/utils/__init__.py +++ b/mmrazor/engine/runner/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .check import check_subnet_flops +from .check import check_subnet_resources from .genetic import crossover -__all__ = ['crossover', 'check_subnet_flops'] +__all__ = ['crossover', 'check_subnet_resources'] diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index e2fdcfcc6..7e4c5d66b 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import Optional, Tuple +from typing import Any, Dict, Tuple -import torch.nn as nn +import torch from mmrazor.models import ResourceEstimator from mmrazor.structures import export_fix_subnet, load_fix_subnet @@ -15,18 +15,20 @@ BaseDetector = get_placeholder('mmdet') -def check_subnet_flops( - model: nn.Module, - subnet: SupportRandomSubnet, - estimator: ResourceEstimator, - flops_range: Optional[Tuple[float, float]] = None) -> bool: - """Check whether is beyond flops constraints. +@torch.no_grad() +def check_subnet_resources( + model, + subnet: SupportRandomSubnet, + estimator: ResourceEstimator, + constraints_range: Dict[str, Any] = dict(flops=(0, 330)) +) -> Tuple[bool, Dict]: + """Check whether the subnet is beyond the resource constraints. Returns: - bool: The result of checking. + bool, dict: Whether the subnet passes the constraints, and the estimated resources.
""" - if flops_range is None: - return True + if constraints_range is None: + return True, dict() assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture') model.set_subnet(subnet) @@ -40,9 +42,10 @@ def check_subnet_flops( else: results = estimator.estimate(model=model_to_check) - flops = results['flops'] - flops_mix, flops_max = flops_range - if flops_mix <= flops <= flops_max: # type: ignore - return True - else: - return False + for k, v in constraints_range.items(): + if not isinstance(v, (list, tuple)): + v = (0, v) + if results[k] < v[0] or results[k] > v[1]: + return False, results + + return True, results diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index a5129acb4..d7e4dc3af 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -3,7 +3,8 @@ from .distill import (DAFLDataFreeDistillation, DataFreeDistillation, FpnTeacherDistill, OverhaulFeatureDistillation, SelfDistill, SingleTeacherDistill) -from .nas import DSNAS, DSNASDDP, SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP +from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP, + Darts, DartsDDP) from .pruning import SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm @@ -25,4 +26,5 @@ 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', + 'Autoformer', ] diff --git a/mmrazor/models/algorithms/nas/__init__.py b/mmrazor/models/algorithms/nas/__init__.py index b290afa0a..3757fe455 100644 --- a/mmrazor/models/algorithms/nas/__init__.py +++ b/mmrazor/models/algorithms/nas/__init__.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .autoformer import Autoformer from .autoslim import AutoSlim, AutoSlimDDP from .darts import Darts, DartsDDP from .dsnas import DSNAS, DSNASDDP from .spos import SPOS __all__ = [ - 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS', 'DSNASDDP' + 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS', + 'DSNASDDP', 'Autoformer' ] diff --git a/mmrazor/models/algorithms/nas/autoformer.py b/mmrazor/models/algorithms/nas/autoformer.py new file mode 100644 index 000000000..ba4baf389 --- /dev/null +++ b/mmrazor/models/algorithms/nas/autoformer.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement +from torch import nn + +from mmrazor.registry import MODELS +from mmrazor.utils import ValidFixMutable +from ..base import BaseAlgorithm, LossResults + + +@MODELS.register_module() +class Autoformer(BaseAlgorithm): + """Implementation of `Autoformer `_ + + AutoFormer is dedicated to vision transformer search. AutoFormer + entangles the weights of different blocks in the same layers during + supernet training. + The logic of the search part is implemented in + :class:`mmrazor.engine.EvolutionSearchLoop` + Args: + architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel` + or built model. Corresponding to supernet in NAS algorithm. + mutators (Optional[dict]): The dict of different Mutators config. + Defaults to None. + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. Defaults to None. + data_preprocessor (Optional[Union[dict, nn.Module]]): The pre-process + config of :class:`BaseDataPreprocessor`. Defaults to None. + init_cfg (Optional[dict]): Init config for ``BaseModule``. 
+ Defaults to None. + Note: + Autoformer uses two mutators which are ``DynamicValueMutator`` and + ``ChannelMutator``. ``DynamicValueMutator`` handles the mutable object + ``OneShotMutableValue`` in Autoformer while ``ChannelMutator`` handles + the mutable object ``OneShotMutableChannel`` in Autoformer. + """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutators: Optional[Dict] = None, + fix_subnet: Optional[ValidFixMutable] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + super().__init__(architecture, data_preprocessor, init_cfg) + + # Autoformer supports supernet training and subnet retraining. + # If fix_subnet is not None, it means subnet retraining. + if fix_subnet: + # Avoid circular import + from mmrazor.structures import load_fix_subnet + + # According to fix_subnet, delete the unchosen part of supernet + load_fix_subnet(self.architecture, fix_subnet) + self.is_supernet = False + else: + assert mutators is not None, \ + 'mutators cannot be None when fix_subnet is None.' + if isinstance(mutators, dict): + built_mutators: Dict = dict() + for name, mutator_cfg in mutators.items(): + if 'parse_cfg' in mutator_cfg and isinstance( + mutator_cfg['parse_cfg'], dict): + assert mutator_cfg['parse_cfg'][ + 'type'] == 'Predefined', \ + 'autoformer only supports predefined.' + mutator = MODELS.build(mutator_cfg) + built_mutators[name] = mutator + mutator.prepare_from_supernet(self.architecture) + self.mutators = built_mutators + else: + raise TypeError('mutators should be a `dict` but got ' + f'{type(mutators)}') + + self.is_supernet = True + + def sample_subnet(self) -> Dict: + """Randomly sample a subnet with the mutators.""" + subnet_dict = dict() + for name, mutator in self.mutators.items(): + if name == 'value_mutator': + subnet_dict.update( + dict((str(group_id), value) for group_id, value in + mutator.sample_choices().items())) + else: + subnet_dict.update(mutator.sample_choices()) + return subnet_dict + + def set_subnet(self, subnet_dict: Dict) -> None: + """Set the subnet sampled by :meth:`sample_subnet`.""" + for name, mutator in self.mutators.items(): + if name == 'value_mutator': + value_subnet = dict((int(group_id), value) + for group_id, value in subnet_dict.items() + if isinstance(group_id, str)) + mutator.set_choices(value_subnet) + else: + channel_subnet = dict( + (group_id, value) + for group_id, value in subnet_dict.items() + if isinstance(group_id, int)) + mutator.set_choices(channel_subnet) + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + if self.is_supernet: + random_subnet = self.sample_subnet() + self.set_subnet(random_subnet) + return self.architecture(batch_inputs, data_samples, mode='loss') diff --git a/mmrazor/models/architectures/__init__.py b/mmrazor/models/architectures/__init__.py index b3ec197ad..e011ab5b9 100644 --- a/mmrazor/models/architectures/__init__.py +++ b/mmrazor/models/architectures/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403 +from .classifiers import * # noqa: F401,F403 from .connectors import * # noqa: F401,F403 from .dynamic_ops import * # noqa: F401,F403 from .generators import * # noqa: F401,F403 diff --git a/mmrazor/models/architectures/backbones/__init__.py b/mmrazor/models/architectures/backbones/__init__.py index 0fcce8b61..4e99a6746 100644 --- a/mmrazor/models/architectures/backbones/__init__.py +++ b/mmrazor/models/architectures/backbones/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .darts_backbone import DartsBackbone +from .searchable_autoformer import AutoformerBackbone from .searchable_mobilenet import SearchableMobileNet from .searchable_shufflenet_v2 import SearchableShuffleNetV2 from .wideresnet import WideResNet __all__ = [ 'SearchableMobileNet', 'SearchableShuffleNetV2', 'DartsBackbone', - 'WideResNet' + 'WideResNet', 'AutoformerBackbone' ] diff --git a/mmrazor/models/architectures/backbones/searchable_autoformer.py b/mmrazor/models/architectures/backbones/searchable_autoformer.py new file mode 100644 index 000000000..ffcccab0e --- /dev/null +++ b/mmrazor/models/architectures/backbones/searchable_autoformer.py @@ -0,0 +1,374 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer + +from mmrazor.models.architectures.dynamic_ops.bricks import ( + DynamicLinear, DynamicMultiheadAttention, DynamicPatchEmbed, + DynamicSequential) +from mmrazor.models.mutables import (BaseMutable, BaseMutableChannel, + MutableChannelContainer, + OneShotMutableChannel, + OneShotMutableValue) +from mmrazor.models.mutables.mutable_channel import OneShotMutableChannelUnit +from mmrazor.registry import MODELS + +try: + from mmcls.models.backbones.base_backbone import BaseBackbone +except ImportError: + from mmrazor.utils import get_placeholder + BaseBackbone = get_placeholder('mmcls') + + +class TransformerEncoderLayer(BaseBackbone): + """Autoformer block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + mlp_ratio (float): The expansion ratio of the feed forward network. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The dropout rate of the attention weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + qkv_bias (bool, optional): Whether to keep bias of qkv. + Defaults to True. + act_cfg (Dict, optional): The config for activation function. + Defaults to dict(type='GELU'). + norm_cfg (Dict, optional): The config for normalization. + Defaults to dict(type='mmrazor.DynamicLayerNorm'). + init_cfg (Dict, optional): The config for initialization. + Defaults to None.
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + mlp_ratio: float, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + qkv_bias: bool = True, + act_cfg: Dict = dict(type='GELU'), + norm_cfg: Dict = dict(type='mmrazor.DynamicLayerNorm'), + init_cfg: Dict = None) -> None: + super().__init__(init_cfg) + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + self.attn = DynamicMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop_rate=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + + middle_channels = int(embed_dims * mlp_ratio) + self.fc1 = DynamicLinear(embed_dims, middle_channels) + self.fc2 = DynamicLinear(middle_channels, embed_dims) + self.act = build_activation_layer(act_cfg) + + @property + def norm1(self): + """The first normalization.""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """The second normalization.""" + return getattr(self, self.norm2_name) + + def register_mutables(self, mutable_num_heads: BaseMutable, + mutable_mlp_ratios: BaseMutable, + mutable_q_embed_dims: BaseMutable, + mutable_head_dims: BaseMutable, + mutable_embed_dims: BaseMutable): + """Mutate the mutables of encoder layer.""" + # record the mutables + self.mutable_num_heads = mutable_num_heads + self.mutable_mlp_ratios = mutable_mlp_ratios + self.mutable_q_embed_dims = mutable_q_embed_dims + self.mutable_embed_dims = mutable_embed_dims + self.mutable_head_dims = mutable_head_dims + # handle the mutable of FFN + self.middle_channels = mutable_mlp_ratios * mutable_embed_dims + + self.attn.register_mutable_attr('num_heads', mutable_num_heads) + + # handle the mutable of the first dynamic LN + MutableChannelContainer.register_mutable_channel_to_module( + self.norm1, self.mutable_embed_dims, True) + # handle the mutable of the second dynamic LN + MutableChannelContainer.register_mutable_channel_to_module( + self.norm2, self.mutable_embed_dims, True) + + # handle the mutable of attn + MutableChannelContainer.register_mutable_channel_to_module( + self.attn, self.mutable_embed_dims, False) + MutableChannelContainer.register_mutable_channel_to_module( + self.attn, + self.mutable_q_embed_dims, + True, + end=self.mutable_q_embed_dims.current_choice) + MutableChannelContainer.register_mutable_channel_to_module( + self.attn.rel_pos_embed_k, self.mutable_head_dims, False) + MutableChannelContainer.register_mutable_channel_to_module( + self.attn.rel_pos_embed_v, self.mutable_head_dims, False) + + # handle the mutable of fc + MutableChannelContainer.register_mutable_channel_to_module( + self.fc1, mutable_embed_dims, False) + MutableChannelContainer.register_mutable_channel_to_module( + self.fc1, + self.middle_channels, + True, + start=0, + end=self.middle_channels.current_choice) + MutableChannelContainer.register_mutable_channel_to_module( + self.fc2, + self.middle_channels, + False, + start=0, + end=self.middle_channels.current_choice) + MutableChannelContainer.register_mutable_channel_to_module( + self.fc2, mutable_embed_dims, True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward of Transformer Encode Layer.""" + residual = x + x = self.norm1(x) + x = self.attn(x) + x = residual + x + residual = x + x = self.norm2(x) + x = 
self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return residual + x + + +@MODELS.register_module() +class AutoformerBackbone(BaseBackbone): + """Autoformer backbone. + + A PyTorch implementation of Autoformer introduced by: + `AutoFormer: Searching Transformers for Visual Recognition + <https://arxiv.org/abs/2107.00651>`_ + + Modified from the official repo. + + Args: + arch_setting (Dict[str, List]): Architecture settings. + img_size (int, optional): The image size of input. + Defaults to 224. + patch_size (int, optional): The patch size of autoformer. + Defaults to 16. + in_channels (int, optional): The input channel dimension. + Defaults to 3. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + qkv_bias (bool, optional): Whether to keep bias of qkv. + Defaults to True. + norm_cfg (Dict, optional): The config of normalization. + Defaults to dict(type='mmrazor.DynamicLayerNorm'). + act_cfg (Dict, optional): The config of activation functions. + Defaults to dict(type='GELU'). + use_final_norm (bool, optional): Whether to use final normalization. + Defaults to True. + init_cfg (Dict, optional): The config for initialization. + Defaults to None. + + Examples: + >>> arch_setting = dict( + ... mlp_ratios=[3.0, 3.5, 4.0], + ... num_heads=[8, 9, 10], + ... depth=[14, 15, 16], + ... embed_dims=[528, 576, 624] + ... ) + >>> model = AutoformerBackbone(arch_setting=arch_setting) + """ + + def __init__(self, + arch_setting: Dict[str, List], + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + drop_rate: float = 0., + drop_path_rate: float = 0., + qkv_bias: bool = True, + norm_cfg: Dict = dict(type='mmrazor.DynamicLayerNorm'), + act_cfg: Dict = dict(type='GELU'), + use_final_norm: bool = True, + init_cfg: Optional[Dict] = None) -> None: + + super().__init__(init_cfg) + + self.arch_setting = arch_setting + self.img_size = img_size + self.patch_size = patch_size + self.qkv_bias = qkv_bias + self.in_channels = in_channels + self.drop_rate = drop_rate + self.use_final_norm = use_final_norm + self.act_cfg = act_cfg + + # adapt mutable settings + self.mlp_ratio_range: List = self.arch_setting['mlp_ratios'] + self.num_head_range: List = self.arch_setting['num_heads'] + self.depth_range: List = self.arch_setting['depth'] + self.embed_dim_range: List = self.arch_setting['embed_dims'] + + # mutable variables of autoformer + self.mutable_depth = OneShotMutableValue( + value_list=self.depth_range, default_value=self.depth_range[-1]) + + self.mutable_embed_dims = OneShotMutableChannel( + num_channels=self.embed_dim_range[-1], + candidate_choices=self.embed_dim_range) + + # handle the mutable in multihead attention + self.base_embed_dims = OneShotMutableChannel( + num_channels=64, candidate_choices=[64]) + + self.mutable_num_heads = [ + OneShotMutableValue( + value_list=self.num_head_range, + default_value=self.num_head_range[-1]) + for _ in range(self.depth_range[-1]) + ] + self.mutable_mlp_ratios = [ + OneShotMutableValue( + value_list=self.mlp_ratio_range, + default_value=self.mlp_ratio_range[-1]) + for _ in range(self.depth_range[-1]) + ] + + self.mutable_q_embed_dims = [ + i * self.base_embed_dims for i in self.mutable_num_heads + ] + + # patch embeddings + self.patch_embed = DynamicPatchEmbed( + img_size=self.img_size, + in_channels=self.in_channels, + embed_dims=self.mutable_embed_dims.num_channels) + + # num of patches + self.patch_resolution = [ + img_size // patch_size, img_size // patch_size + ] + num_patches =
self.patch_resolution[0] * self.patch_resolution[1] + + # cls token and pos embed + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, + self.mutable_embed_dims.num_channels)) + + self.cls_token = nn.Parameter( + torch.zeros(1, 1, self.mutable_embed_dims.num_channels)) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + self.dpr = np.linspace(0, drop_path_rate, + self.mutable_depth.max_choice) + + # main body + self.blocks = self.make_layers( + embed_dims=self.mutable_embed_dims.num_channels, + depth=self.mutable_depth.max_choice) + + # final norm + if self.use_final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mutable_embed_dims.num_channels) + self.add_module(self.norm1_name, norm1) + + self.last_mutable = self.mutable_embed_dims + + self.register_mutables() + + @property + def norm1(self): + """The first normalization.""" + return getattr(self, self.norm1_name) + + def make_layers(self, embed_dims, depth): + """Build multiple TransformerEncoderLayers.""" + layers = [] + for i in range(depth): + layer = TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=self.mutable_num_heads[i].max_choice, + mlp_ratio=self.mutable_mlp_ratios[i].max_choice, + drop_rate=self.drop_rate, + drop_path_rate=self.dpr[i], + qkv_bias=self.qkv_bias, + act_cfg=self.act_cfg) + layers.append(layer) + return DynamicSequential(*layers) + + def register_mutables(self): + """Mutate the autoformer.""" + OneShotMutableChannelUnit._register_channel_container( + self, MutableChannelContainer) + + # handle the mutation of depth + self.blocks.register_mutable_attr('depth', self.mutable_depth) + + # handle the mutation of patch embed + MutableChannelContainer.register_mutable_channel_to_module( + self.patch_embed, self.mutable_embed_dims, True) + + # handle the dependencies of TransformerEncoderLayers + for i in range(self.mutable_depth.max_choice): # max depth here + layer = self.blocks[i] + layer.register_mutables( + mutable_num_heads=self.mutable_num_heads[i], + mutable_mlp_ratios=self.mutable_mlp_ratios[i], + mutable_q_embed_dims=self.mutable_q_embed_dims[i], + mutable_head_dims=self.base_embed_dims, + mutable_embed_dims=self.last_mutable) + + # handle the mutable of final norm + if self.use_final_norm: + MutableChannelContainer.register_mutable_channel_to_module( + self.norm1, self.last_mutable, True) + + def forward(self, x: torch.Tensor): + """Forward of Autoformer.""" + B = x.shape[0] + x = self.patch_embed(x) + + embed_dims = int(self.mutable_embed_dims.current_choice) if isinstance( + self.mutable_embed_dims, + BaseMutableChannel) else self.embed_dim_range[-1] + + # cls token + cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # pos embed + x = x + self.pos_embed[..., :embed_dims] + x = self.drop_after_pos(x) + + # dynamic depth + x = self.blocks(x) + + if self.use_final_norm: + x = self.norm1(x) + + return (torch.mean(x[:, 1:], dim=1), ) diff --git a/mmrazor/models/architectures/classifiers/__init__.py b/mmrazor/models/architectures/classifiers/__init__.py new file mode 100644 index 000000000..6bbd245ff --- /dev/null +++ b/mmrazor/models/architectures/classifiers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
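As a quick sanity check, the backbone above can be built directly; a minimal sketch, assuming `AutoformerBackbone` is exported from `mmrazor.models.architectures.backbones` as in this patch and that mmcls is installed so the `mmrazor.DynamicLayerNorm` registry entry resolves:

    import torch
    from mmrazor.models.architectures.backbones import AutoformerBackbone

    arch_setting = dict(
        mlp_ratios=[3.0, 3.5, 4.0],
        num_heads=[8, 9, 10],
        depth=[14, 15, 16],
        embed_dims=[528, 576, 624])

    backbone = AutoformerBackbone(arch_setting=arch_setting)
    # Every mutable starts at its widest choice, so the pooled patch-token
    # feature should come back as a tuple holding one (1, 624) tensor.
    feats = backbone(torch.rand(1, 3, 224, 224))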
+from .image import SearchableImageClassifier
+
+__all__ = ['SearchableImageClassifier']
diff --git a/mmrazor/models/architectures/classifiers/image.py b/mmrazor/models/architectures/classifiers/image.py
new file mode 100644
index 000000000..019815608
--- /dev/null
+++ b/mmrazor/models/architectures/classifiers/image.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from mmrazor.registry import MODELS
+
+try:
+    from mmcls.models import ImageClassifier
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    ImageClassifier = get_placeholder('mmcls')
+
+
+@MODELS.register_module()
+class SearchableImageClassifier(ImageClassifier):
+    """SearchableImageClassifier for sliceable networks.
+
+    Args:
+        backbone (dict): The same as ImageClassifier.
+        neck (dict, optional): The same as ImageClassifier. Defaults to None.
+        head (dict, optional): The same as ImageClassifier. Defaults to None.
+        pretrained (str, optional): The same as ImageClassifier. Defaults to
+            None.
+        train_cfg (dict, optional): The same as ImageClassifier. Defaults to
+            None.
+        data_preprocessor (dict, optional): The same as ImageClassifier.
+            Defaults to None.
+        init_cfg (dict, optional): The same as ImageClassifier. Defaults to
+            None.
+        connect_head (dict, optional): Mapping from a method name on the
+            head to a dotted attribute path on this classifier. Each value
+            is resolved to the referenced attribute and passed to the named
+            head method, so that the search space of one component can be
+            connected to the next one. For example,
+            {'connect_with_backbone': 'backbone.last_mutable'} calls
+            ``head.connect_with_backbone(backbone.last_mutable)``.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 backbone: dict,
+                 neck: Optional[dict] = None,
+                 head: Optional[dict] = None,
+                 pretrained: Optional[str] = None,
+                 train_cfg: Optional[dict] = None,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None,
+                 connect_head: Optional[dict] = None):
+        super().__init__(backbone, neck, head, pretrained, train_cfg,
+                         data_preprocessor, init_cfg)
+
+        if self.with_head and connect_head is not None:
+            for kh, vh in connect_head.items():
+                component, attr = vh.split('.')
+                value = getattr(getattr(self, component), attr)
+                getattr(self.head, kh)(value)
diff --git a/mmrazor/models/architectures/dynamic_ops/__init__.py b/mmrazor/models/architectures/dynamic_ops/__init__.py
index 620c9e4c8..b49693855 100644
--- a/mmrazor/models/architectures/dynamic_ops/__init__.py
+++ b/mmrazor/models/architectures/dynamic_ops/__init__.py
@@ -1,15 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
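For `connect_head=dict(connect_with_backbone='backbone.last_mutable')`, the loop in `SearchableImageClassifier.__init__` above reduces to roughly the following sketch (`_connect` is a hypothetical helper named here only for illustration):

    def _connect(classifier, connect_head):
        # e.g. connect_head = dict(connect_with_backbone='backbone.last_mutable')
        for head_api, path in connect_head.items():
            component, attr = path.split('.')  # -> 'backbone', 'last_mutable'
            value = getattr(getattr(classifier, component), attr)
            # calls classifier.head.connect_with_backbone(backbone.last_mutable)
            getattr(classifier.head, head_api)(value)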
-from .bricks.dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d -from .bricks.dynamic_linear import DynamicLinear -from .bricks.dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, - DynamicBatchNorm3d, SwitchableBatchNorm2d) -from .mixins.dynamic_conv_mixins import DynamicConvMixin -from .mixins.dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, - DynamicLinearMixin, DynamicMixin) - -__all__ = [ - 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', - 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', - 'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', - 'DynamicLinearMixin', 'SwitchableBatchNorm2d', 'DynamicConvMixin' -] +from .bricks import * # noqa: F401,F403 +from .head import * # noqa: F401,F403 +from .mixins import * # noqa: F401,F403 diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py index ef101fec6..e9cde3f7d 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py @@ -1 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .dynamic_container import DynamicSequential +from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d +from .dynamic_embed import DynamicPatchEmbed +from .dynamic_linear import DynamicLinear +from .dynamic_multi_head_attention import DynamicMultiheadAttention +from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, + DynamicBatchNorm3d, DynamicLayerNorm, + SwitchableBatchNorm2d) +from .dynamic_relative_position import DynamicRelativePosition2D + +__all__ = [ + 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', + 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', + 'SwitchableBatchNorm2d', 'DynamicSequential', 'DynamicPatchEmbed', + 'DynamicLayerNorm', 'DynamicRelativePosition2D', + 'DynamicMultiheadAttention' +] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_container.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_container.py new file mode 100644 index 000000000..3696fe38e --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_container.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
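With the wildcard re-exports above, downstream code can keep importing every brick from the flat `dynamic_ops` namespace; for instance (a usage sketch, assuming the package paths in this patch):

    from mmrazor.models.architectures.dynamic_ops import (
        DynamicLayerNorm, DynamicLinear, DynamicMultiheadAttention,
        DynamicPatchEmbed, DynamicRelativePosition2D, DynamicSequential)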
+from typing import Dict, Iterator, Optional, Set
+
+import torch.nn as nn
+from mmengine.model import Sequential
+from torch import Tensor
+from torch.nn import Module
+
+from mmrazor.models.mutables import DerivedMutable, MutableValue
+from mmrazor.models.mutables.base_mutable import BaseMutable
+from ..mixins import DynamicMixin
+
+
+class DynamicSequential(Sequential, DynamicMixin):
+    """Dynamic Sequential Container."""
+    mutable_attrs: nn.ModuleDict
+    accepted_mutable_attrs: Set[str] = {'depth'}
+
+    forward_ignored_module = (MutableValue, DerivedMutable, nn.ModuleDict)
+
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
+        super().__init__(*args, init_cfg=init_cfg)
+
+        self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()
+
+    @property
+    def mutable_depth(self):
+        """Mutable depth, or None if it has not been registered."""
+        assert hasattr(self, 'mutable_attrs')
+        if 'depth' not in self.mutable_attrs:
+            return None
+        return self.mutable_attrs['depth']
+
+    def register_mutable_attr(self: Sequential, attr: str,
+                              mutable: BaseMutable):
+        """Register attribute of mutable."""
+        if attr == 'depth':
+            self._register_mutable_depth(mutable)
+        else:
+            raise NotImplementedError
+
+    def _register_mutable_depth(self: Sequential, mutable_depth: MutableValue):
+        """Register mutable depth."""
+        assert hasattr(self, 'mutable_attrs')
+        assert mutable_depth.current_choice is not None
+        current_depth = mutable_depth.current_choice
+        if current_depth > len(self._modules):
+            raise ValueError(f'Expect depth of mutable to be no larger than '
+                             f'{len(self._modules)} as `depth`, '
+                             f'but got: {current_depth}.')
+        self.mutable_attrs['depth'] = mutable_depth
+
+    @property
+    def static_op_factory(self):
+        """Corresponding Pytorch OP."""
+        return Sequential
+
+    def to_static_op(self: Sequential) -> Sequential:
+        """Convert dynamic Sequential to static one."""
+        self.check_if_mutables_fixed()
+
+        if self.mutable_depth is None:
+            fixed_depth = len(self)
+        else:
+            fixed_depth = self.get_current_choice(self.mutable_depth)
+
+        modules = []
+        passed_module_nums = 0
+        for module in self:
+            if isinstance(module, self.forward_ignored_module):
+                continue
+            else:
+                passed_module_nums += 1
+            if passed_module_nums > fixed_depth:
+                break
+
+            modules.append(module)
+
+        return Sequential(*modules)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward of Dynamic Sequential."""
+        if self.mutable_depth is None:
+            # no depth mutable registered: behave like a plain Sequential
+            return super().forward(x)
+
+        current_depth = self.get_current_choice(self.mutable_depth)
+        passed_module_nums = 0
+        for module in self.pure_modules():
+            passed_module_nums += 1
+            if passed_module_nums > current_depth:
+                break
+            x = module(x)
+        return x
+
+    @property
+    def pure_module_nums(self) -> int:
+        """Number of pure modules."""
+        return sum(1 for _ in self.pure_modules())
+
+    def pure_modules(self) -> Iterator[Module]:
+        """Yield the child modules that take part in the forward, skipping
+        the registered mutables that would otherwise influence the forward
+        of Sequential."""
+        for module in self._modules.values():
+            if isinstance(module, self.forward_ignored_module):
+                continue
+            yield module
+
+    @classmethod
+    def convert_from(cls, module: Sequential):
+        """Convert the static Sequential to dynamic one."""
+        dynamic_m = cls(module._modules)
+        return dynamic_m
diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_embed.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_embed.py
new file mode 100644
index 000000000..4393101c6
--- /dev/null
+++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_embed.py
@@ -0,0 +1,142 @@
+# Copyright (c) OpenMMLab. All rights reserved.
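A minimal sketch of depth mutation with the container above (import paths assumed from this patch; `OneShotMutableValue` is the same mutable the Autoformer backbone uses for its depth):

    import torch
    import torch.nn as nn
    from mmrazor.models.architectures.dynamic_ops.bricks import \
        DynamicSequential
    from mmrazor.models.mutables import OneShotMutableValue

    blocks = DynamicSequential(*[nn.Linear(8, 8) for _ in range(4)])
    mutable_depth = OneShotMutableValue(value_list=[2, 3, 4], default_value=4)
    blocks.register_mutable_attr('depth', mutable_depth)

    mutable_depth.current_choice = 2  # only the first two layers execute
    out = blocks(torch.rand(1, 8))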
+import logging +from typing import Dict, Set, Tuple + +import torch.nn as nn +import torch.nn.functional as F +from mmcls.models.utils import PatchEmbed +from mmengine import print_log +from torch import Tensor + +from mmrazor.models.mutables.base_mutable import BaseMutable +from mmrazor.registry import MODELS +from ..mixins import DynamicChannelMixin + + +@MODELS.register_module() +class DynamicPatchEmbed(PatchEmbed, DynamicChannelMixin): + """Dynamic Patch Embedding. + + Note: + Arguments for ``__init__`` of ``DynamicPatchEmbed`` is totally same as + :obj:`mmcls.models.utils.PatchEmbed`. + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `embed_dims`. The key of the dict must in + ``accepted_mutable_attrs``. + """ + + mutable_attrs: nn.ModuleDict + accepted_mutable_attrs: Set[str] = {'embed_dims'} + attr_mappings: Dict[str, str] = { + 'in_channels': 'embed_dims', + 'out_channels': 'embed_dims' + } + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + @property + def mutable_embed_dims(self): + """Mutable embedding dimension.""" + assert hasattr(self, 'mutable_attrs') + return self.mutable_attrs['embed_dims'] + + def register_mutable_attr(self: PatchEmbed, attr: str, + mutable: BaseMutable): + """Register attribute of mutable.""" + self.check_mutable_attr_valid(attr) + if attr in self.attr_mappings: + attr_map = self.attr_mappings[attr] + assert attr_map in self.accepted_mutable_attrs + if attr_map in self.mutable_attrs: + print_log( + f'{attr_map}({attr}) is already in `mutable_attrs`', + level=logging.WARNING) + else: + self._register_mutable_attr(attr_map, mutable) + elif attr in self.accepted_mutable_attrs: + self._register_mutable_attr(attr, mutable) + else: + raise NotImplementedError + + def _register_mutable_attr(self, attr, mutable): + """Register `embed_dims`.""" + if attr == 'embed_dims': + self._register_embed_dims(mutable) + else: + raise NotImplementedError + + def _register_embed_dims(self: PatchEmbed, + mutable_patch_embedding: BaseMutable) -> None: + """Register mutable embedding dimension.""" + mask_size = mutable_patch_embedding.current_mask.size(0) + + if mask_size != self.embed_dims: + raise ValueError( + f'Expect mask size of mutable to be {self.embed_dims} as ' + f'`embed_dims`, but got: {mask_size}.') + + self.mutable_attrs['embed_dims'] = mutable_patch_embedding + + def _get_dynamic_params(self: PatchEmbed) -> Tuple[Tensor, Tensor]: + """Get mask of ``embed_dims``""" + if 'embed_dims' not in self.mutable_attrs: + return self.projection.weight, self.projection.bias + else: + out_mask = self.mutable_embed_dims.current_mask.to( + self.projection.weight.device) + weight = self.projection.weight[out_mask][:] + bias = self.projection.bias[ + out_mask] if self.projection.bias is not None else None # noqa: E501 + return weight, bias + + def to_static_op(self: PatchEmbed) -> nn.Module: + """Convert dynamic PatchEmbed to static PatchEmbed.""" + self.check_if_mutables_fixed() + assert self.mutable_embed_dims is not None + + weight, bias = self._get_dynamic_params() + static_patch_embed = self.static_op_factory( + img_size=self.img_size, + in_channels=3, + embed_dims=self.mutable_embed_dims.activated_channels) + + static_patch_embed.projection.weight = nn.Parameter(weight.clone()) + static_patch_embed.projection.bias = nn.Parameter(bias.clone()) + + return static_patch_embed + + @property + def static_op_factory(self): + 
"""Corresponding Pytorch OP.""" + return PatchEmbed + + @classmethod + def convert_from(cls, module) -> nn.Module: + """Convert a PatchEmbed to a DynamicPatchEmbed.""" + + dynamic_patch_embed = cls( + img_size=module.img_size, + in_channels=3, + embed_dims=module.embed_dims, + norm_cfg=None, + conv_cfg=None, + init_cfg=None) + + return dynamic_patch_embed + + def forward(self, x: Tensor) -> Tensor: + """Forward of dynamic patch embed.""" + weight, bias = self._get_dynamic_params() + x = F.conv2d( + x, + weight, + bias, + stride=16, + padding=self.projection.padding, + dilation=self.projection.dilation).flatten(2).transpose(1, 2) + + return x diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_multi_head_attention.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_multi_head_attention.py new file mode 100644 index 000000000..b270f870e --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_multi_head_attention.py @@ -0,0 +1,280 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import Dict, Set, Tuple + +import torch.nn as nn +import torch.nn.functional as F +from mmengine import print_log +from torch import Tensor + +from mmrazor.models.architectures.ops import MultiheadAttention +from mmrazor.models.mutables.base_mutable import BaseMutable +from ..mixins import DynamicChannelMixin +from .dynamic_relative_position import DynamicRelativePosition2D # noqa: E501 + + +class DynamicMultiheadAttention(MultiheadAttention, DynamicChannelMixin): + """Dynamic Multihead Attention with iRPE.. + + Note: + Arguments for ``__init__`` of ``DynamicMultiheadAttention`` is + totally same as + :obj:`mmrazor.models.architectures.MultiheadAttention`. + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `num_heads`、 `embed_dims`、 `q_embed_dims`. + The key of the dict must in ``accepted_mutable_attrs``. 
+ """ + + mutable_attrs: nn.ModuleDict + relative_position: bool + max_relative_position: int + w_qs: nn.Linear + w_ks: nn.Linear + w_vs: nn.Linear + embed_dims: int + q_embed_dims: int + proj: nn.Linear + attn_drop_rate: float + accepted_mutable_attrs: Set[str] = { + 'num_heads', 'embed_dims', 'q_embed_dims' + } + attr_mappings: Dict[str, str] = { + 'in_channels': 'embed_dims', + 'out_channels': 'q_embed_dims', + } + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + # dynamic image relative position encoding + if self.relative_position: + self.rel_pos_embed_k = DynamicRelativePosition2D( + self.head_dims, self.max_relative_position) + self.rel_pos_embed_v = DynamicRelativePosition2D( + self.head_dims, self.max_relative_position) + + @property + def mutable_num_heads(self): + """Mutable number of heads.""" + assert hasattr(self, 'mutable_attrs') + return self.mutable_attrs['num_heads'] + + @property + def mutable_embed_dims(self): + """Mutable embedding dimension.""" + assert hasattr(self, 'mutable_attrs') + return self.mutable_attrs['embed_dims'] + + @property + def mutable_q_embed_dims(self): + """Mutable intermediate embedding dimension.""" + assert hasattr(self, 'mutable_attrs') + return self.mutable_attrs['q_embed_dims'] + + def register_mutable_attr(self, attr: str, mutable: BaseMutable): + """Register attribute of mutable.""" + self.check_mutable_attr_valid(attr) + if attr in self.attr_mappings: + attr_map = self.attr_mappings[attr] + assert attr_map in self.accepted_mutable_attrs + # if hasattr(self, 'mutable_attrs'): + if attr_map in self.mutable_attrs: + print_log( + f'{attr_map}({attr}) is already in `mutable_attrs`', + level=logging.WARNING) + else: + self._register_mutable_attr(attr_map, mutable) + elif attr in self.accepted_mutable_attrs: + self._register_mutable_attr(attr, mutable) + else: + raise NotImplementedError + + def _register_mutable_attr(self, attr: str, mutable: BaseMutable): + """Register `embed_dims` `q_embed_dims` `num_heads`""" + if attr == 'num_heads': + self._register_mutable_num_heads(mutable) + elif attr == 'embed_dims': + self._register_mutable_embed_dims(mutable) + elif attr == 'q_embed_dims': + self._register_mutable_q_embed_dims(mutable) + else: + raise NotImplementedError + + def _register_mutable_num_heads(self, mutable_num_heads): + """Register the mutable number of heads.""" + assert hasattr(self, 'mutable_attrs') + current_choice = mutable_num_heads.current_choice + if current_choice > self.num_heads: + raise ValueError( + f'Expect value of mutable to be smaller or equal than ' + f'{self.num_heads} as `num_heads`, but got: {current_choice}.') + + self.mutable_attrs['num_heads'] = mutable_num_heads + + def _register_mutable_embed_dims(self, mutable_embed_dims): + """Register mutable embedding dimension.""" + assert hasattr(self, 'mutable_attrs') + mask_size = mutable_embed_dims.current_mask.size(0) + if mask_size != self.embed_dims: + raise ValueError( + f'Expect mask size of mutable to be {self.embed_dims} as ' + f'`embed_dims`, but got: {mask_size}.') + + self.mutable_attrs['embed_dims'] = mutable_embed_dims + + def _register_mutable_q_embed_dims(self, mutable_q_embed_dims): + """Register intermediate mutable embedding dimension.""" + assert hasattr(self, 'mutable_attrs') + self.mutable_attrs['q_embed_dims'] = mutable_q_embed_dims + + def _get_dynamic_proj_params(self, w: nn.Linear) -> Tuple[Tensor, Tensor]: + """Get parameters of dynamic 
projection.
+
+        Note:
+            The input dimension is decided by `mutable_q_embed_dims`.
+            The output dimension is decided by `mutable_embed_dims`.
+        """
+        # TODO support mask
+        if self.mutable_embed_dims is None and \
+                self.mutable_q_embed_dims is None:
+            return w.weight, w.bias
+
+        if self.mutable_q_embed_dims is not None:
+            in_features = self.mutable_q_embed_dims.activated_channels
+        else:
+            in_features = self.embed_dims
+
+        if self.mutable_embed_dims is not None:
+            out_features = self.mutable_embed_dims.activated_channels
+        else:
+            out_features = self.embed_dims
+
+        weight = w.weight[:out_features, :in_features]
+        bias = w.bias[:out_features] if w.bias is not None else None
+
+        return weight, bias
+
+    def _get_dynamic_qkv_params(self, w: nn.Linear) -> Tuple[Tensor, Tensor]:
+        """Get parameters of dynamic QKV.
+
+        Note:
+            The output dimension is decided by `mutable_q_embed_dims`.
+            The input dimension is decided by `mutable_embed_dims`.
+        """
+        # TODO support mask later
+        if self.mutable_q_embed_dims is None and \
+                self.mutable_embed_dims is None:
+            return w.weight, w.bias
+
+        if self.mutable_embed_dims is not None:
+            in_features = self.mutable_embed_dims.activated_channels
+        else:
+            in_features = self.embed_dims
+
+        if self.mutable_q_embed_dims is not None:
+            out_features = self.mutable_q_embed_dims.activated_channels
+        else:
+            out_features = self.q_embed_dims
+
+        weight = w.weight[:out_features, :in_features]
+        bias = w.bias[:out_features] if w.bias is not None else None
+
+        return weight, bias
+
+    def to_static_op(self) -> MultiheadAttention:
+        """Convert dynamic MultiheadAttention to static one."""
+        self.check_if_mutables_fixed()
+
+        embed_dims = self.mutable_embed_dims.activated_channels
+        num_heads = self.mutable_num_heads.current_choice
+
+        q_w, q_b = self._get_dynamic_qkv_params(self.w_qs)
+        k_w, k_b = self._get_dynamic_qkv_params(self.w_ks)
+        v_w, v_b = self._get_dynamic_qkv_params(self.w_vs)
+
+        proj_w, proj_b = self._get_dynamic_proj_params(self.proj)
+
+        static_mha = MultiheadAttention(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            input_dims=None,
+            attn_drop_rate=self.attn_drop_rate,
+            relative_position=self.relative_position,
+            max_relative_position=self.max_relative_position)
+
+        static_mha.w_qs.weight = nn.Parameter(q_w.clone())
+        static_mha.w_qs.bias = nn.Parameter(q_b.clone())
+
+        static_mha.w_ks.weight = nn.Parameter(k_w.clone())
+        static_mha.w_ks.bias = nn.Parameter(k_b.clone())
+
+        static_mha.w_vs.weight = nn.Parameter(v_w.clone())
+        static_mha.w_vs.bias = nn.Parameter(v_b.clone())
+
+        static_mha.proj.weight = nn.Parameter(proj_w.clone())
+        static_mha.proj.bias = nn.Parameter(proj_b.clone())
+
+        if self.relative_position:
+            static_mha.rel_pos_embed_k = self.rel_pos_embed_k.to_static_op()
+            static_mha.rel_pos_embed_v = self.rel_pos_embed_v.to_static_op()
+
+        return static_mha
+
+    @classmethod
+    def convert_from(cls, module):
+        """Convert the static module to dynamic one."""
+        dynamic_mha = cls(
+            embed_dims=module.embed_dims,
+            num_heads=module.num_heads,
+        )
+        return dynamic_mha
+
+    @property
+    def static_op_factory(self):
+        """Corresponding Pytorch OP."""
+        return MultiheadAttention
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward of dynamic MultiheadAttention."""
+        B, N = x.shape[0], x.shape[1]
+        q_w, q_b = self._get_dynamic_qkv_params(self.w_qs)
+        k_w, k_b = self._get_dynamic_qkv_params(self.w_ks)
+        v_w, v_b = self._get_dynamic_qkv_params(self.w_vs)
+
+        q_embed_dims = self.mutable_q_embed_dims.activated_channels
+        num_heads =
self.mutable_num_heads.current_choice + + q = F.linear(x, q_w, q_b).view(B, N, num_heads, + q_embed_dims // num_heads) + k = F.linear(x, k_w, k_b).view(B, N, num_heads, + q_embed_dims // num_heads) + v = F.linear(x, v_w, v_b).view(B, N, num_heads, + q_embed_dims // num_heads) + + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if self.relative_position: + r_p_k = self.rel_pos_embed_k(N, N) + attn = attn + (q.permute(2, 0, 1, 3).reshape(N, num_heads * B, -1) # noqa: E501 + @ r_p_k.transpose(2, 1)) \ + .transpose(1, 0).reshape(B, num_heads, N, N) * self.scale + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + if self.relative_position: + r_p_v = self.rel_pos_embed_v(N, N) + attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, -1) + x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape( + B, num_heads, N, -1).transpose(2, 1).reshape(B, N, -1) + + # proj + weight, bias = self._get_dynamic_proj_params(self.proj) + x = F.linear(x, weight, bias) + x = self.proj_drop(x) + return x diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py index e3e795fa4..4ac153dc2 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py @@ -4,11 +4,12 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.nn import LayerNorm from torch.nn.modules.batchnorm import _BatchNorm from mmrazor.models.mutables.base_mutable import BaseMutable from mmrazor.registry import MODELS -from ..mixins.dynamic_mixins import DynamicBatchNormMixin +from ..mixins import DynamicBatchNormMixin, DynamicLayerNormMixin class _DynamicBatchNorm(_BatchNorm, DynamicBatchNormMixin): @@ -91,6 +92,7 @@ class DynamicBatchNorm1d(_DynamicBatchNorm): @property def static_op_factory(self): + """Corresponding Pytorch OP.""" return nn.BatchNorm1d def _check_input_dim(self, input: Tensor) -> None: @@ -106,6 +108,7 @@ class DynamicBatchNorm2d(_DynamicBatchNorm): @property def static_op_factory(self): + """Corresponding Pytorch OP.""" return nn.BatchNorm2d def _check_input_dim(self, input: Tensor) -> None: @@ -121,6 +124,7 @@ class DynamicBatchNorm3d(_DynamicBatchNorm): @property def static_op_factory(self): + """Corresponding Pytorch OP.""" return nn.BatchNorm3d def _check_input_dim(self, input: Tensor) -> None: @@ -190,3 +194,61 @@ def _check_candidates(self, candidates: List): def static_op_factory(self): """Return initializer of static op.""" return nn.BatchNorm2d + + +@MODELS.register_module() +class DynamicLayerNorm(LayerNorm, DynamicLayerNormMixin): + """Applies Layer Normalization over a mini-batch of inputs according to the + `mutable_num_channels` dynamically. + + Note: + Arguments for ``__init__`` of ``DynamicLayerNorm`` is totally same as + :obj:`torch.nn.LayerNorm`. + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `num_features`. The key of the dict must in + ``accepted_mutable_attrs``. 
+ """ + accepted_mutable_attrs = {'num_features'} + + def __init__(self, *args, **kwargs): + super(DynamicLayerNorm, self).__init__(*args, **kwargs) + + self.mutable_attrs: Dict[str, Optional[BaseMutable]] = nn.ModuleDict() + + @property + def static_op_factory(self): + """Corresponding Pytorch OP.""" + return LayerNorm + + @classmethod + def convert_from(cls, module: LayerNorm): + """Convert a _BatchNorm module to a DynamicBatchNorm. + + Args: + module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module. + """ + dynamic_ln = cls( + normalized_shape=module.normalized_shape, + eps=module.eps, + elementwise_affine=module.elementwise_affine) + + return dynamic_ln + + def forward(self, input: Tensor) -> Tensor: + """Slice the parameters according to `mutable_num_channels`, and + forward.""" + self._check_input_dim(input) + + weight, bias = self.get_dynamic_params() + self.normalized_shape = ( + self.mutable_num_features.activated_channels, ) + + return F.layer_norm(input, self.normalized_shape, weight, bias, + self.eps) + + def _check_input_dim(self, input: Tensor) -> None: + """Check if input dimension is valid.""" + if input.dim() != 3: + raise ValueError('expected 3D input (got {}D input)'.format( + input.dim())) diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_relative_position.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_relative_position.py new file mode 100644 index 000000000..572880a43 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_relative_position.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import Dict, Set + +import torch +from mmengine import print_log +from torch import Tensor, nn + +from mmrazor.models.architectures.ops import RelativePosition2D +from mmrazor.models.mutables.base_mutable import BaseMutable +from ..mixins import DynamicChannelMixin + + +class DynamicRelativePosition2D(RelativePosition2D, DynamicChannelMixin): + """Searchable RelativePosition module. + + Note: + Arguments for ``__init__`` of ``DynamicRelativePosition2D`` is totally + same as :obj:`mmrazor.models.architectures.RelativePosition2D`. + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `head_dims`. The key of the dict must in + ``accepted_mutable_attrs``. 
+ """ + + mutable_attrs: nn.ModuleDict + head_dims: int + max_relative_position: int + embeddings_table_v: nn.Parameter + embeddings_table_h: nn.Parameter + accepted_mutable_attrs: Set[str] = {'head_dims'} + attr_mappings: Dict[str, str] = { + 'in_channels': 'head_dims', + 'out_channels': 'head_dims', + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + @property + def mutable_head_dims(self): + """Mutable head dimension.""" + assert hasattr(self, 'mutable_attrs') + return self.mutable_attrs['head_dims'] + + def register_mutable_attr(self, attr: str, mutable: BaseMutable): + """Register attribute of mutable.""" + self.check_mutable_attr_valid(attr) + if attr in self.attr_mappings: + attr_map = self.attr_mappings[attr] + assert attr_map in self.accepted_mutable_attrs + if attr_map in self.mutable_attrs: + print_log( + f'{attr_map}({attr}) is already in `mutable_attrs`', + level=logging.WARNING) + else: + self._register_mutable_attr(attr_map, mutable) + elif attr in self.accepted_mutable_attrs: + self._register_mutable_attr(attr, mutable) + else: + raise NotImplementedError + + def _register_mutable_attr(self, attr, mutable): + """Register `head_dims`""" + if attr == 'head_dims': + self._registry_mutable_head_dims(mutable) + else: + raise NotImplementedError + + def _registry_mutable_head_dims(self, + mutable_head_dims: BaseMutable) -> None: + """Register head dimension.""" + assert hasattr(self, 'mutable_attrs') + self.mutable_attrs['head_dims'] = mutable_head_dims + + def to_static_op(self) -> nn.Module: + """Convert dynamic RelativePosition2D to static One.""" + self.check_if_mutables_fixed() + assert self.mutable_head_dims is not None + + self.current_head_dim = self.mutable_head_dims.activated_channels + static_relative_position = self.static_op_factory( + self.current_head_dim) + static_relative_position.embeddings_table_v = \ + nn.Parameter( + self.embeddings_table_v[:, :self.current_head_dim].clone()) + static_relative_position.embeddings_table_h = \ + nn.Parameter( + self.embeddings_table_h[:, :self.current_head_dim].clone()) + + return static_relative_position + + @property + def static_op_factory(self): + """Corresponding Pytorch OP.""" + return RelativePosition2D + + @classmethod + def convert_from(cls, module): + """Convert a RP to a dynamic RP.""" + dynamic_rp = cls( + head_dims=module.head_dims, + max_relative_position=module.max_relative_position) + return dynamic_rp + + def forward(self, length_q, length_k) -> Tensor: + """Forward of Dynamic Relative Position.""" + if self.mutable_head_dims is None: + self.current_head_dim = self.head_dims + else: + self.current_head_dim = self.mutable_head_dims.activated_channels + + self.sample_eb_table_h = self.embeddings_table_h[:, :self. + current_head_dim] + self.sample_eb_table_v = self.embeddings_table_v[:, :self. 
+ current_head_dim] + + # remove the first cls token distance computation + length_q = length_q - 1 + length_k = length_k - 1 + range_vec_q = torch.arange(length_q) + range_vec_k = torch.arange(length_k) + # compute the row and column distance + distance_mat_v = ( + range_vec_k[None, :] // int(length_q**0.5) - + range_vec_q[:, None] // int(length_q**0.5)) + distance_mat_h = ( + range_vec_k[None, :] % int(length_q**0.5) - + range_vec_q[:, None] % int(length_q**0.5)) + distance_mat_clipped_v = torch.clamp(distance_mat_v, + -self.max_relative_position, + self.max_relative_position) + distance_mat_clipped_h = torch.clamp(distance_mat_h, + -self.max_relative_position, + self.max_relative_position) + + final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1 + final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1 + # pad the 0 which represent the cls token + final_mat_v = torch.nn.functional.pad(final_mat_v, (1, 0, 1, 0), + 'constant', 0) + final_mat_h = torch.nn.functional.pad(final_mat_h, (1, 0, 1, 0), + 'constant', 0) + + final_mat_v = torch.LongTensor(final_mat_v) + final_mat_h = torch.LongTensor(final_mat_h) + # get the embeddings with the corresponding distance + + embeddings = self.sample_eb_table_v[final_mat_v] + \ + self.sample_eb_table_h[final_mat_h] + + return embeddings diff --git a/mmrazor/models/architectures/dynamic_ops/head/__init__.py b/mmrazor/models/architectures/dynamic_ops/head/__init__.py new file mode 100644 index 000000000..a9da44d6e --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/head/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dynamic_linear_head import DynamicLinearClsHead # noqa: F401 + +__all__ = ['DynamicLinearClsHead'] diff --git a/mmrazor/models/architectures/dynamic_ops/head/dynamic_linear_head.py b/mmrazor/models/architectures/dynamic_ops/head/dynamic_linear_head.py new file mode 100644 index 000000000..9053a4775 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/head/dynamic_linear_head.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Optional, Tuple + +import torch +from mmcls.models import ClsHead + +from mmrazor.models.mutables.base_mutable import BaseMutable +from mmrazor.models.mutables.mutable_channel import MutableChannelContainer +from mmrazor.models.mutables.mutable_channel.units import \ + OneShotMutableChannelUnit +from mmrazor.registry import MODELS +from ..bricks.dynamic_linear import DynamicLinear + + +class DynamicHead: + + @abstractmethod + def connect_with_backbone(self, + backbone_output_mutable: BaseMutable) -> None: + """Connect with Dynamic Backbone.""" + ... + + +@MODELS.register_module() +class DynamicLinearClsHead(ClsHead, DynamicHead): + """Dynamic Linear classification head for Autoformer. + + Args: + num_classes (int): Number of classes. + in_channels (int): Number of input channels. + init_cfg (Optional[dict], optional): Init config. + Defaults to dict(type='Normal', + layer='DynamicLinear', std=0.01). 
+ """ + + def __init__(self, + num_classes: int = 1000, + in_channels: int = 624, + init_cfg: Optional[dict] = dict( + type='Normal', layer='DynamicLinear', std=0.01), + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + + self.in_channels = in_channels + self.num_classes = num_classes + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.fc = DynamicLinear(self.in_channels, self.num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``LinearClsHead``, we just obtain the + feature of the last stage. + """ + # The LinearClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score + + def connect_with_backbone(self, + backbone_output_mutable: BaseMutable) -> None: + """Connect dynamic backbone.""" + + OneShotMutableChannelUnit._register_channel_container( + self, MutableChannelContainer) + + MutableChannelContainer.register_mutable_channel_to_module( + self.fc, backbone_output_mutable, False) diff --git a/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py b/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py index 7a5097bc5..e97f7ad78 100644 --- a/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py @@ -1,9 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. from .dynamic_conv_mixins import DynamicConvMixin +from .dynamic_layernorm_mixins import DynamicLayerNormMixin from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, DynamicLinearMixin, DynamicMixin) __all__ = [ - 'DynamicChannelMixin', 'DynamicBatchNormMixin', 'DynamicLinearMixin', - 'DynamicMixin', 'DynamicConvMixin' + 'DynamicChannelMixin', + 'DynamicBatchNormMixin', + 'DynamicLinearMixin', + 'DynamicMixin', + 'DynamicConvMixin', + 'DynamicLayerNormMixin', ] diff --git a/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_layernorm_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_layernorm_mixins.py new file mode 100644 index 000000000..785be9935 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_layernorm_mixins.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import logging
+from typing import Dict, Optional, Set, Tuple
+
+import torch
+from mmengine import print_log
+from torch import Tensor, nn
+from torch.nn import LayerNorm
+
+from mmrazor.models.mutables.base_mutable import BaseMutable
+from .dynamic_mixins import DynamicChannelMixin
+
+
+class DynamicLayerNormMixin(DynamicChannelMixin):
+    """A mixin class for Pytorch LayerNorm, which can mutate
+    ``num_features``."""
+    accepted_mutable_attrs: Set[str] = {'num_features'}
+
+    attr_mappings: Dict[str, str] = {
+        'in_channels': 'num_features',
+        'out_channels': 'num_features',
+    }
+
+    @property
+    def num_features(self):
+        """Number of features, i.e. the first element of
+        ``normalized_shape``."""
+        return getattr(self, 'normalized_shape')[0]
+
+    @property
+    def mutable_num_features(self):
+        """Mutable number of features."""
+        assert hasattr(self, 'mutable_attrs')
+        return self.mutable_attrs['num_features']
+
+    def register_mutable_attr(self, attr, mutable):
+        """Register attribute of mutable."""
+        self.check_mutable_attr_valid(attr)
+        if attr in self.attr_mappings:
+            attr_map = self.attr_mappings[attr]
+            assert attr_map in self.accepted_mutable_attrs
+            if attr_map in self.mutable_attrs:
+                print_log(
+                    f'{attr_map}({attr}) is already in `mutable_attrs`',
+                    level=logging.WARNING)
+            else:
+                self._register_mutable_attr(attr_map, mutable)
+        elif attr in self.accepted_mutable_attrs:
+            self._register_mutable_attr(attr, mutable)
+        else:
+            raise NotImplementedError
+
+    def _register_mutable_attr(self, attr, mutable):
+        """Register `num_features`."""
+        if attr == 'num_features':
+            self._register_mutable_num_features(mutable)
+        else:
+            raise NotImplementedError
+
+    def _register_mutable_num_features(
+            self: LayerNorm, mutable_num_features: BaseMutable) -> None:
+        """Mutate ``num_features`` with given mutable.
+
+        Args:
+            mutable_num_features (BaseMutable): Mutable for controlling
+                ``num_features``.
+        Raises:
+            RuntimeError: Error if ``elementwise_affine`` is False.
+            ValueError: Error if the mask size does not match
+                ``num_features``.
+        """
+        if not self.elementwise_affine:
+            raise RuntimeError(
+                'num_features cannot be mutated if `elementwise_affine` '
+                'is False')
+
+        self.check_mutable_channels(mutable_num_features)
+        mask_size = mutable_num_features.current_mask.size(0)
+
+        # normalized_shape is a tuple
+        if mask_size != self.normalized_shape[0]:
+            raise ValueError(
+                f'Expect mask size of mutable to be '
+                f'{self.normalized_shape[0]} as `normalized_shape`, '
+                f'but got: {mask_size}.')
+
+        self.mutable_attrs['num_features'] = mutable_num_features
+
+    def _get_num_features_mask(self: LayerNorm) -> Optional[torch.Tensor]:
+        """Get mask of ``num_features``."""
+        if self.elementwise_affine:
+            refer_tensor = self.weight
+        else:
+            return None
+
+        if 'num_features' in self.mutable_attrs:
+            out_mask = self.mutable_num_features.current_mask.to(
+                refer_tensor.device)
+        else:
+            out_mask = torch.ones_like(refer_tensor).bool()
+
+        return out_mask
+
+    def get_dynamic_params(
+            self: LayerNorm) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+        """Get dynamic parameters that will be used in forward process.
+
+        Returns:
+            Tuple[Optional[Tensor], Optional[Tensor]]: Sliced weight and
+                bias.
+ """ + out_mask = self._get_num_features_mask() + + if self.elementwise_affine: + weight = self.weight[out_mask] + bias = self.bias[out_mask] + else: + weight, bias = self.weight, self.bias + + return weight, bias + + def to_static_op(self: LayerNorm) -> nn.Module: + """Convert dynamic LayerNormxd to :obj:`torch.nn.LayerNormxd`. + + Returns: + torch.nn.LayerNormxd: :obj:`torch.nn.LayerNormxd` with sliced + parameters. + """ + self.check_if_mutables_fixed() + + weight, bias = self.get_dynamic_params() + + if 'num_features' in self.mutable_attrs: + num_features = self.mutable_attrs['num_features'].current_mask.sum( + ).item() + else: + num_features = self.num_features + + static_ln = self.static_op_factory( + normalized_shape=num_features, + eps=self.eps, + elementwise_affine=self.elementwise_affine) + + if weight is not None: + static_ln.weight = nn.Parameter(weight.clone()) + if bias is not None: + static_ln.bias = nn.Parameter(bias.clone()) + + return static_ln diff --git a/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py index 4837bb12b..f11701517 100644 --- a/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py +++ b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py @@ -74,12 +74,16 @@ def check_if_mutables_fixed(self) -> None: Raises: RuntimeError: Error if a existing mutable is not fixed. """ + from mmrazor.models.mutables import (DerivedMutable, + MutableChannelContainer) def check_fixed(mutable: Optional[BaseMutable]) -> None: if mutable is not None and not mutable.is_fixed: raise RuntimeError(f'Mutable {type(mutable)} is not fixed.') for mutable in self.mutable_attrs.values(): # type: ignore + if isinstance(mutable, (MutableChannelContainer, DerivedMutable)): + continue check_fixed(mutable) def check_mutable_attr_valid(self, attr): @@ -115,6 +119,11 @@ class DynamicChannelMixin(DynamicMixin): ``mutable_out_channels`` APIs. """ + attr_mappings: Dict[str, str] = { + 'in_channels': 'in_channels', + 'out_channels': 'out_channels', + } + @staticmethod def check_mutable_channels(mutable_channels: BaseMutable) -> None: """Check if mutable has `currnet_mask` attribute. diff --git a/mmrazor/models/architectures/ops/__init__.py b/mmrazor/models/architectures/ops/__init__.py index cc9862dd7..d3f7e414a 100644 --- a/mmrazor/models/architectures/ops/__init__.py +++ b/mmrazor/models/architectures/ops/__init__.py @@ -6,9 +6,11 @@ from .gather_tensors import GatherTensors from .mobilenet_series import MBBlock from .shufflenet_series import ShuffleBlock, ShuffleXception +from .transformer_series import MultiheadAttention, RelativePosition2D __all__ = [ 'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv', 'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity', - 'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors' + 'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors', + 'RelativePosition2D', 'MultiheadAttention' ] diff --git a/mmrazor/models/architectures/ops/transformer_series.py b/mmrazor/models/architectures/ops/transformer_series.py new file mode 100644 index 000000000..d1bdadf86 --- /dev/null +++ b/mmrazor/models/architectures/ops/transformer_series.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, Optional + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + + +class RelativePosition2D(nn.Module): + """Rethinking and Improving Relative Position Encoding for Vision + Transformer. + + ICCV 2021. https://arxiv.org/pdf/2107.14222.pdf + Image RPE (iRPE for short) methods are new relative position encoding + methods dedicated to 2D images. + Args: + head_dims (int): embedding dims of relative position. + max_relative_position (int): The max relative position distance. + """ + + def __init__(self, head_dims: int, max_relative_position: int = 14): + super().__init__() + + self.head_dims = head_dims + self.max_relative_position = max_relative_position + # The first element in embeddings_table_v is the vertical embedding + # for the class + self.embeddings_table_v = nn.Parameter( + torch.randn(max_relative_position * 2 + 2, head_dims)) + self.embeddings_table_h = nn.Parameter( + torch.randn(max_relative_position * 2 + 2, head_dims)) + + trunc_normal_(self.embeddings_table_v, std=.02) + trunc_normal_(self.embeddings_table_h, std=.02) + + def forward(self, length_q, length_k): + # remove the first cls token distance computation + length_q = length_q - 1 + length_k = length_k - 1 + range_vec_q = torch.arange(length_q) + range_vec_k = torch.arange(length_k) + # compute the row and column distance + distance_mat_v = ( + range_vec_k[None, :] // int(length_q**0.5) - + range_vec_q[:, None] // int(length_q**0.5)) + distance_mat_h = ( + range_vec_k[None, :] % int(length_q**0.5) - + range_vec_q[:, None] % int(length_q**0.5)) + # clip the distance to the range of + # [-max_relative_position, max_relative_position] + distance_mat_clipped_v = torch.clamp(distance_mat_v, + -self.max_relative_position, + self.max_relative_position) + distance_mat_clipped_h = torch.clamp(distance_mat_h, + -self.max_relative_position, + self.max_relative_position) + + # translate the distance from [1, 2 * max_relative_position + 1], + # 0 is for the cls token + final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1 + final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1 + # pad the 0 which represent the cls token + final_mat_v = torch.nn.functional.pad(final_mat_v, (1, 0, 1, 0), + 'constant', 0) + final_mat_h = torch.nn.functional.pad(final_mat_h, (1, 0, 1, 0), + 'constant', 0) + + final_mat_v = torch.LongTensor(final_mat_v) + final_mat_h = torch.LongTensor(final_mat_h) + # get the embeddings with the corresponding distance + embeddings = self.embeddings_table_v[ + final_mat_v] + self.embeddings_table_h[final_mat_h] + + return embeddings + + +class MultiheadAttention(nn.Module): + """Multi-head Attention Module with iRPE. + + This module implements multi-head attention that supports different input + dims and embed dims. And it also supports a shortcut from ``value``, which + is useful if input dims is not the same with embed dims. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. 
+ relative_position (bool, optional): Whether use relative position. + Defaults to True. + max_relative_position (int): The max relative position distance. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + input_dims: Optional[int] = None, + attn_drop_rate: float = 0., + proj_drop: float = 0., + dropout_layer: Dict = dict(type='Dropout', drop_prob=0.), + relative_position: Optional[bool] = True, + max_relative_position: int = 14, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + proj_bias: bool = True, + v_shortcut: bool = False, + init_cfg: Optional[dict] = None): + super().__init__() + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + self.relative_position = relative_position + self.max_relative_position = max_relative_position + self.attn_drop_rate = attn_drop_rate + + self.head_dims = 64 # unit + self.scale = qk_scale or self.head_dims**-0.5 + + self.q_embed_dims = num_heads * self.head_dims + + self.w_qs = nn.Linear( + self.input_dims, num_heads * self.head_dims, bias=qkv_bias) + self.w_ks = nn.Linear( + self.input_dims, num_heads * self.head_dims, bias=qkv_bias) + self.w_vs = nn.Linear( + self.input_dims, num_heads * self.head_dims, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear( + num_heads * self.head_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.out_drop = nn.Dropout(dropout_layer['drop_prob']) + + # image relative position encoding + if self.relative_position: + self.rel_pos_embed_k = RelativePosition2D( + self.head_dims, self.max_relative_position) + self.rel_pos_embed_v = RelativePosition2D( + self.head_dims, self.max_relative_position) + + def forward(self, x): + B, N, _ = x.shape + + q = self.w_qs(x).view(B, N, self.num_heads, self.head_dims) + k = self.w_ks(x).view(B, N, self.num_heads, self.head_dims) + v = self.w_vs(x).view(B, N, self.num_heads, self.head_dims) + + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if self.relative_position: + r_p_k = self.rel_pos_embed_k(N, N) + attn = attn + (q.permute(2, 0, 1, 3).reshape(N, self.num_heads * B, -1) # noqa: E501 + @ r_p_k.transpose(2, 1)) \ + .transpose(1, 0).reshape(B, self.num_heads, N, N) * self.scale + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + if self.relative_position: + r_p_v = self.rel_pos_embed_v(N, N) + t_attn = attn.permute(2, 0, 1, 3).reshape(N, B * self.num_heads, + -1) + x = x + (t_attn @ r_p_v).transpose(1, 0).reshape( + B, self.num_heads, N, -1).transpose(2, 1).reshape(B, N, -1) + + x = self.proj(x) + x = self.out_drop(self.proj_drop(x)) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index ddbf6adeb..5a3f9abb9 100644 --- a/mmrazor/models/mutables/derived_mutable.py 
+++ b/mmrazor/models/mutables/derived_mutable.py @@ -61,8 +61,15 @@ def _expand_mask_fn( def fn(): mask = mutable.current_mask - expand_num_channels = int(mask.size(0) * expand_ratio) - expand_choice = int(mutable.current_choice * expand_ratio) + if isinstance(expand_ratio, int): + expand_num_channels = mask.size(0) * expand_ratio + expand_choice = mutable.current_choice * expand_ratio + elif isinstance(expand_ratio, float): + expand_num_channels = int(mask.size(0) * expand_ratio) + expand_choice = int(mutable.current_choice * expand_ratio) + else: + raise NotImplementedError( + f'Not support type of expand_ratio: {type(expand_ratio)}') expand_mask = torch.zeros(expand_num_channels).bool() expand_mask[:expand_choice] = True @@ -136,25 +143,62 @@ def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable': def derive_expand_mutable( self: MutableProtocol, - expand_ratio: Union[int, float]) -> 'DerivedMutable': + expand_ratio: Union[int, BaseMutable, float]) -> 'DerivedMutable': """Derive expand mutable, usually used with `expand_ratio`.""" - choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) + # avoid circular import + if isinstance(expand_ratio, int): + choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) + elif isinstance(expand_ratio, float): + choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) + elif isinstance(expand_ratio, BaseMutable): + current_ratio = expand_ratio.current_choice + choice_fn = _expand_choice_fn(self, expand_ratio=current_ratio) + else: + raise NotImplementedError( + f'Not support type of ratio: {type(expand_ratio)}') mask_fn: Optional[Callable] = None if hasattr(self, 'current_mask'): - mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio) + if isinstance(expand_ratio, int): + mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio) + elif isinstance(expand_ratio, float): + mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio) + elif isinstance(expand_ratio, BaseMutable): + mask_fn = _expand_mask_fn(self, expand_ratio=current_ratio) + else: + raise NotImplementedError( + f'Not support type of ratio: {type(expand_ratio)}') return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) def derive_divide_mutable(self: MutableProtocol, - ratio: int, + ratio: Union[int, float, BaseMutable], divisor: int = 8) -> 'DerivedMutable': """Derive divide mutable, usually used with `make_divisable`.""" - choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor) + from .mutable_channel import BaseMutableChannel + + # avoid circular import + if isinstance(ratio, int): + choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor) + current_ratio = ratio + elif isinstance(ratio, float): + current_ratio = int(ratio) + choice_fn = _divide_choice_fn(self, ratio=current_ratio, divisor=1) + elif isinstance(ratio, BaseMutable): + current_ratio = int(ratio.current_choice) + choice_fn = _divide_choice_fn(self, ratio=current_ratio, divisor=1) + else: + raise NotImplementedError( + f'Not support type of ratio: {type(ratio)}') mask_fn: Optional[Callable] = None - if hasattr(self, 'current_mask'): - mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor) + if isinstance(self, BaseMutableChannel) and hasattr( + self, 'current_mask'): + mask_fn = _divide_mask_fn( + self, ratio=current_ratio, divisor=divisor) + elif getattr(self, 'mask_fn', None): # OneShotMutableChannel + mask_fn = _divide_mask_fn( + self, ratio=current_ratio, divisor=divisor) return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) diff --git 
a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py index 9292d64c8..f59929b27 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py @@ -3,9 +3,9 @@ import torch -from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin from mmrazor.registry import MODELS from mmrazor.utils import IndexDict +from ...architectures.dynamic_ops.mixins import DynamicChannelMixin from .base_mutable_channel import BaseMutableChannel from .simple_mutable_channel import SimpleMutableChannel @@ -66,7 +66,7 @@ def current_mask(self) -> torch.Tensor: def register_mutable(self, mutable_channel: BaseMutableChannel, start: int, end: int): - """Register/Store BaseMutableChannel in the MutableChannelContainer in + """Register/Store BaseMutableChannel in the MutableChannelContainer in the range [start,end)""" self.mutable_channels[(start, end)] = mutable_channel diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index 9b891e349..07b85f6c6 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -82,7 +82,7 @@ def expand_choice_fn(mutable1: 'SquentialMutableChannel', mutable2: OneShotMutableValue) -> Callable: def fn(): - return mutable1.current_choice * mutable2.current_choice + return int(mutable1.current_choice * mutable2.current_choice) return fn @@ -93,9 +93,10 @@ def fn(): mask = mutable1.current_mask max_expand_ratio = mutable2.max_choice current_expand_ratio = mutable2.current_choice - expand_num_channels = mask.size(0) * max_expand_ratio + expand_num_channels = int(mask.size(0) * max_expand_ratio) - expand_choice = mutable1.current_choice * current_expand_ratio + expand_choice = int(mutable1.current_choice * + current_expand_ratio) expand_mask = torch.zeros(expand_num_channels).bool() expand_mask[:expand_choice] = True @@ -113,10 +114,17 @@ def fn(): def __floordiv__(self, other) -> DerivedMutable: if isinstance(other, int): return self.derive_divide_mutable(other) + elif isinstance(other, float): + return self.derive_divide_mutable(int(other)) if isinstance(other, tuple): assert len(other) == 2 return self.derive_divide_mutable(*other) + from ..mutable_value import OneShotMutableValue + if isinstance(other, OneShotMutableValue): + ratio = other.current_choice + return self.derive_divide_mutable(ratio) + raise TypeError(f'Unsupported type {type(other)} for div!') def _num2ratio(self, choice: Union[int, float]) -> float: diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py index e494b4018..c68b1f491 100644 --- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py @@ -262,14 +262,16 @@ def config_template(self, def add_ouptut_related(self, channel: Channel): """Add a Channel which is output related.""" assert channel.is_output_channel - assert self.num_channels == channel.num_channels + assert self.num_channels == \ + int(channel.num_channels // channel.expand_ratio) if channel not in self.output_related: self.output_related.append(channel) def add_input_related(self, channel: Channel): """Add a Channel which is input related.""" assert 
channel.is_output_channel is False
-        assert self.num_channels == channel.num_channels
+        assert self.num_channels == \
+            int(channel.num_channels // channel.expand_ratio)
         if channel not in self.input_related:
             self.input_related.append(channel)

diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py
index 59039cd83..748b2333b 100644
--- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py
+++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py
@@ -6,26 +6,23 @@

 import torch.nn as nn

-from mmrazor.models.architectures import dynamic_ops
 from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
 from mmrazor.models.mutables import DerivedMutable
 from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel,
                                                      MutableChannelContainer)
+from mmrazor.models.mutables.mutable_value import MutableValue

 from .channel_unit import Channel, ChannelUnit


 class MutableChannelUnit(ChannelUnit):

-    # init methods
     def __init__(self, num_channels: int, **kwargs) -> None:
         """MutableChannelUnit inherits from ChannelUnit, which manages channels
-        with channel-dependency.
-
-        Compared with ChannelUnit, MutableChannelUnit defines the core
-        interfaces for pruning. By inheriting MutableChannelUnit,
-        we can implement a variant pruning and nas algorithm.
+        with channel-dependency. Compared with ChannelUnit, MutableChannelUnit
+        defines the core interfaces for pruning. By inheriting
+        MutableChannelUnit, we can implement various pruning and NAS
+        algorithms. These APIs include:

-        These apis includes
         - basic property
             - name
             - is_mutable
@@ -60,6 +57,7 @@ def process_container(contanier: MutableChannelContainer,
                               mutable2units,
                               is_output=True):
             for index, mutable in contanier.mutable_channels.items():
+                expand_ratio = 1
                 if isinstance(mutable, DerivedMutable):
                     source_mutables: Set = \
                         mutable._trace_source_mutables()
@@ -72,6 +70,17 @@
                             'used in DerivedMutable')
                     mutable = list(source_channel_mutables)[0]

+                    source_value_mutables = [
+                        mutable for mutable in source_mutables
+                        if isinstance(mutable, MutableValue)
+                    ]
+                    assert len(source_value_mutables) <= 1, (
+                        'only support one mutable value '
+                        'used in DerivedMutable')
+                    expand_ratio = int(
+                        list(source_value_mutables)
+                        [0].current_choice) if source_value_mutables else 1
+
                 if mutable not in mutable2units:
                     mutable2units[mutable] = cls.init_from_mutable_channel(
                         mutable)
@@ -83,14 +92,16 @@ def process_container(contanier: MutableChannelContainer,
                         module_name,
                         module,
                         index,
-                        is_output_channel=is_output))
+                        is_output_channel=is_output,
+                        expand_ratio=expand_ratio))
                 else:
                     unit.add_input_related(
                         Channel(
                             module_name,
                             module,
                             index,
-                            is_output_channel=is_output))
+                            is_output_channel=is_output,
+                            expand_ratio=expand_ratio))

         mutable2units: Dict = {}
         for name, module in model.named_modules():

@@ -121,7 +132,7 @@ def traverse(channels: List[Channel]):
                 if channel.is_mutable is False:
                     all_channel_prunable = False
                     break
-                if isinstance(channel.module, dynamic_ops.DynamicChannelMixin):
+                if isinstance(channel.module, DynamicChannelMixin):
                     has_dynamic_op = True
             return has_dynamic_op, all_channel_prunable

@@ -223,29 +234,16 @@ def _register_channel_container(
             model: nn.Module, container_class: Type[MutableChannelContainer]):
         """register channel container for dynamic ops."""
         for module in model.modules():
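In short: when a container entry is a DerivedMutable, process_container now traces it back to exactly one channel mutable (the unit key) and at most one value mutable, whose current choice becomes the Channel's expand_ratio. A doctest-style sketch of the traced relationship (values are made up; it mirrors test_mutable_channel_value_calculation below):

>>> from mmrazor.models.mutables import (OneShotMutableValue,
...                                      SquentialMutableChannel)
>>> mc = SquentialMutableChannel(num_channels=64)
>>> mv = OneShotMutableValue(value_list=[2, 3], default_value=3)
>>> derived = mc * mv
>>> derived.source_mutables == {mc, mv}
True

With these example values the unit is keyed on `mc`, expand_ratio becomes int(mv.current_choice) == 3, the derived mask spans 64 * 3 == 192 channels, and the reworked channel_unit.py asserts check 64 == int(192 // 3).

-            if isinstance(module,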
dynamic_ops.DynamicChannelMixin): + if isinstance(module, DynamicChannelMixin): + in_channels = getattr(module, + module.attr_mappings['in_channels'], 0) if module.get_mutable_attr('in_channels') is None: - in_channels = 0 - if isinstance(module, nn.Conv2d): - in_channels = module.in_channels - elif isinstance(module, nn.modules.batchnorm._BatchNorm): - in_channels = module.num_features - elif isinstance(module, nn.Linear): - in_channels = module.in_features - else: - raise NotImplementedError() module.register_mutable_attr('in_channels', container_class(in_channels)) + out_channels = getattr(module, + module.attr_mappings['out_channels'], 0) if module.get_mutable_attr('out_channels') is None: - out_channels = 0 - if isinstance(module, nn.Conv2d): - out_channels = module.out_channels - elif isinstance(module, nn.modules.batchnorm._BatchNorm): - out_channels = module.num_features - elif isinstance(module, nn.Linear): - out_channels = module.out_features - else: - raise NotImplementedError() + module.register_mutable_attr('out_channels', container_class(out_channels)) @@ -253,7 +251,7 @@ def _register_mutable_channel(self, mutable_channel: BaseMutableChannel): # register mutable_channel for channel in list(self.input_related) + list(self.output_related): module = channel.module - if isinstance(module, dynamic_ops.DynamicChannelMixin): + if isinstance(module, DynamicChannelMixin): container: MutableChannelContainer if channel.is_output_channel and module.get_mutable_attr( 'out_channels') is not None: diff --git a/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py index 235978cfa..8ba55b25e 100644 --- a/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py @@ -39,6 +39,7 @@ def __init__(self, min_ratio=0.9) -> None: super().__init__(num_channels, choice_mode, divisor, min_value, min_ratio) + candidate_choices = copy.copy(candidate_choices) if candidate_choices == []: candidate_choices.append( @@ -50,6 +51,8 @@ def __init__(self, self.candidate_choices, choice_mode) + self.unit_predefined = False + @classmethod def init_from_mutable_channel(cls, mutable_channel: OneShotMutableChannel): unit = cls(mutable_channel.num_channels, @@ -61,7 +64,8 @@ def init_from_mutable_channel(cls, mutable_channel: OneShotMutableChannel): def prepare_for_pruning(self, model: nn.Module): """Prepare for pruning.""" - super().prepare_for_pruning(model) + if not self.unit_predefined: + super().prepare_for_pruning(model) self.current_choice = self.max_choice # ~ diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py index 20055287d..3df551813 100644 --- a/mmrazor/models/mutables/mutable_value/mutable_value.py +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -99,7 +99,7 @@ def num_choices(self) -> int: return len(self.choices) @property - def current_choice(self) -> Optional[Any]: + def current_choice(self) -> Value: """Current choice of mutable value.""" return self._current_choice @@ -116,7 +116,7 @@ def __rmul__(self, other) -> DerivedMutable: """Please refer to method :func:`__mul__`.""" return self * other - def __mul__(self, other: int) -> DerivedMutable: + def __mul__(self, other: Union[int, float]) -> DerivedMutable: """Overload `*` operator. 
Args: @@ -127,7 +127,8 @@ def __mul__(self, other: int) -> DerivedMutable: """ if isinstance(other, int): return self.derive_expand_mutable(other) - + elif isinstance(other, float): + return self.derive_expand_mutable(other) raise TypeError(f'Unsupported type {type(other)} for mul!') def __floordiv__(self, other: Union[int, Tuple[int, @@ -143,6 +144,8 @@ def __floordiv__(self, other: Union[int, Tuple[int, """ if isinstance(other, int): return self.derive_divide_mutable(other) + elif isinstance(other, float): + return self.derive_divide_mutable(int(other)) if isinstance(other, tuple): assert len(other) == 2 return self.derive_divide_mutable(*other) diff --git a/mmrazor/models/mutators/__init__.py b/mmrazor/models/mutators/__init__.py index 82ab48bf4..d11358404 100644 --- a/mmrazor/models/mutators/__init__.py +++ b/mmrazor/models/mutators/__init__.py @@ -3,8 +3,10 @@ SlimmableChannelMutator) from .module_mutator import (DiffModuleMutator, ModuleMutator, OneShotModuleMutator) +from .value_mutator import DynamicValueMutator, ValueMutator __all__ = [ 'OneShotModuleMutator', 'DiffModuleMutator', 'ModuleMutator', - 'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator' + 'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator', + 'ValueMutator', 'DynamicValueMutator' ] diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index 7a19f1c72..28c395acd 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -358,4 +358,7 @@ def _prepare_from_predefined_model(self, model: Module): units = self.unit_class.init_from_predefined_model(model) + for unit in units: + unit.unit_predefined = self.unit_default_args.pop( + 'unit_predefined', False) return units diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py index 7e735b263..2575af2f1 100644 --- a/mmrazor/models/mutators/group_mixin.py +++ b/mmrazor/models/mutators/group_mixin.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. - +import sys from collections import Counter from typing import Dict, List, Type @@ -7,6 +7,11 @@ from ..mutables import BaseMutable +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + class GroupMixin(): """A mixin for :class:`BaseMutator`, which can group mutables by @@ -220,3 +225,49 @@ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], f'When a mutable is set alias attribute :{alias_key},' f'the corresponding module name {mutable_name} should ' f'not be used in `custom_group` {custom_group}.') + + +class MutatorProtocol(Protocol): # pragma: no cover + + @property + def mutable_class_type(self) -> Type[BaseMutable]: + ... + + @property + def search_groups(self) -> Dict: + ... 
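The two mixins below implement sampling purely in terms of `search_groups`, so any mutator that exposes groups of mutables gets random sampling (and, for the dynamic variant, max/min bounds) for free. A usage sketch with DynamicValueMutator, which is added later in this patch; `model` stands for any supernet containing OneShotMutableValues, e.g. the DynamicAttention test model added below:

>>> from mmrazor.models.mutators import DynamicValueMutator
>>> mutator = DynamicValueMutator()
>>> mutator.prepare_from_supernet(model)  # builds search_groups
>>> choices = mutator.sample_choices()    # {group_id: sampled choice}
>>> mutator.set_choices(choices)          # broadcast to every mutable in a group
>>> mutator.max_choices                   # {group_id: the group's max choice}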
+
+
+class OneShotSampleMixin:
+    """Sample mixin for one-shot mutators, implemented on top of
+    ``search_groups``."""
+
+    def sample_choices(self: MutatorProtocol) -> Dict:
+        random_choices = dict()
+        for group_id, modules in self.search_groups.items():
+            random_choices[group_id] = modules[0].sample_choice()
+
+        return random_choices
+
+    def set_choices(self: MutatorProtocol, choices: Dict) -> None:
+        for group_id, modules in self.search_groups.items():
+            choice = choices[group_id]
+            for module in modules:
+                module.current_choice = choice
+
+
+class DynamicSampleMixin(OneShotSampleMixin):
+    """Extend :class:`OneShotSampleMixin` with per-group max/min choices."""
+
+    @property
+    def max_choices(self: MutatorProtocol) -> Dict:
+        max_choices = dict()
+        for group_id, modules in self.search_groups.items():
+            max_choices[group_id] = modules[0].max_choice
+
+        return max_choices
+
+    @property
+    def min_choices(self: MutatorProtocol) -> Dict:
+        min_choices = dict()
+        for group_id, modules in self.search_groups.items():
+            min_choices[group_id] = modules[0].min_choice
+
+        return min_choices
diff --git a/mmrazor/models/mutators/value_mutator/__init__.py b/mmrazor/models/mutators/value_mutator/__init__.py
new file mode 100644
index 000000000..a29577bb1
--- /dev/null
+++ b/mmrazor/models/mutators/value_mutator/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dynamic_value_mutator import DynamicValueMutator
+from .value_mutator import ValueMutator
+
+__all__ = ['ValueMutator', 'DynamicValueMutator']
diff --git a/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py
new file mode 100644
index 000000000..c65c90f80
--- /dev/null
+++ b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.models.mutables import OneShotMutableValue
+from mmrazor.registry import MODELS
+from ..group_mixin import DynamicSampleMixin
+from .value_mutator import ValueMutator
+
+
+@MODELS.register_module()
+class DynamicValueMutator(ValueMutator, DynamicSampleMixin):
+    """Value mutator that groups and samples ``OneShotMutableValue``s."""
+
+    @property
+    def mutable_class_type(self):
+        return OneShotMutableValue
diff --git a/mmrazor/models/mutators/value_mutator/value_mutator.py b/mmrazor/models/mutators/value_mutator/value_mutator.py
new file mode 100644
index 000000000..5127cbe37
--- /dev/null
+++ b/mmrazor/models/mutators/value_mutator/value_mutator.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Type
+
+from torch.nn import Module
+
+from mmrazor.models.mutables import MutableValue
+from mmrazor.registry import MODELS
+from ..base_mutator import BaseMutator
+from ..group_mixin import GroupMixin
+
+
+@MODELS.register_module()
+class ValueMutator(BaseMutator[MutableValue], GroupMixin):
+    """The base class for mutable-based mutators. All subclasses should
+    implement the following APIs:
+
+    - ``mutable_class_type``
+
+    Args:
+        custom_group (list[list[str]], optional): User-defined search groups.
+            All searchable modules that are not in ``custom_group`` will be
+            grouped separately.
+    """
+
+    def __init__(self,
+                 custom_group: Optional[List[List[str]]] = None,
+                 init_cfg: Optional[Dict] = None) -> None:
+        super().__init__(init_cfg)
+
+        if custom_group is None:
+            custom_group = []
+        self._custom_group = custom_group
+        self._search_groups: Optional[Dict[int, List[MutableValue]]] = None
+
+    # TODO
+    # should be a class property
+    @property
+    def mutable_class_type(self) -> Type[MutableValue]:
+        """Corresponding mutable class type.
+
+        Returns:
+            Type[MUTABLE_TYPE]: Mutable class type.
+ """ + return MutableValue + + def prepare_from_supernet(self, supernet: Module) -> None: + """Do some necessary preparations with supernet. + + Note: + For mutable based mutator, we need to build search group first. + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. + """ + self._search_groups = self.build_search_groups(supernet, + self.mutable_class_type, + self._custom_group) + + @property + def search_groups(self) -> Dict[int, List[MutableValue]]: + """Search group of supernet. + + Note: + For mutable based mutator, the search group is composed of + corresponding mutables. + Raises: + RuntimeError: Called before search group has been built. + Returns: + Dict[int, List[MUTABLE_TYPE]]: Search group. + """ + if self._search_groups is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access search group!') + return self._search_groups diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index df0c867c6..dca740214 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -498,6 +498,9 @@ def add_flops_params_counter_variable_or_reset(module): module.__params__ = 0 +counter_warning_list = [] + + def get_counter_type(module) -> str: """Get counter type of the module based on the module class name. @@ -515,10 +518,13 @@ def get_counter_type(module) -> str: for base_cls in module.__class__.mro(): if base_cls in get_modules_list(): counter_type = base_cls.__name__ + 'Counter' - from mmengine import MMLogger - logger = MMLogger.get_current_instance() - logger.warning(f'`{old_counter_type}` not in op_counters. ' - f'Using `{counter_type}` instead.') + global counter_warning_list + if old_counter_type not in counter_warning_list: + from mmengine import MMLogger + logger = MMLogger.get_current_instance() + logger.warning(f'`{old_counter_type}` not in op_counters. ' + f'Using `{counter_type}` instead.') + counter_warning_list.append(old_counter_type) break return counter_type diff --git a/mmrazor/structures/subnet/candidate.py b/mmrazor/structures/subnet/candidate.py index f65f0b48b..50691f85e 100644 --- a/mmrazor/structures/subnet/candidate.py +++ b/mmrazor/structures/subnet/candidate.py @@ -1,35 +1,44 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import UserList -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union class Candidates(UserList): - """The data structure of sampled candidate. The format is [(any, float), - (any, float), ...]. - + """The data structure of sampled candidate. The format is Union[Dict[str, + Dict], List[Dict[str, Dict]]]. Examples: >>> candidates = Candidates() - >>> subnet_1 = {'choice_1': 'layer_1', 'choice_2': 'layer_2'} + >>> subnet_1 = {'1': 'choice1', '2': 'choice2'} >>> candidates.append(subnet_1) >>> candidates - [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.0)] - >>> candidates.set_score(0, 0.9) + [{"{'1': 'choice1', '2': 'choice2'}": + {'score': 0.0, 'flops': 0.0, 'params': 0.0, 'latency': 0.0}}] + >>> candidates.set_resources(0, 49.9, 'flops') + >>> candidates.set_score(0, 100.) 
        >>> candidates
-        [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9)]
+        [{"{'1': 'choice1', '2': 'choice2'}":
+            {'score': 100.0, 'flops': 49.9, 'params': -1, 'latency': -1}}]
         >>> subnet_2 = {'choice_3': 'layer_3', 'choice_4': 'layer_4'}
-        >>> candidates.append((subnet_2, 0.5))
+        >>> candidates.append(subnet_2)
         >>> candidates
-        [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9),
-         ({'choice_3': 'layer_3', 'choice_4': 'layer_4'}, 0.5)]
+        [{"{'1': 'choice1', '2': 'choice2'}":
+            {'score': 100.0, 'flops': 49.9, 'params': -1, 'latency': -1}},
+         {"{'choice_3': 'layer_3', 'choice_4': 'layer_4'}":
+            {'score': -1, 'flops': -1, 'params': -1, 'latency': -1}}]
         >>> candidates.subnets
-        [{'choice_1': 'layer_1', 'choice_2': 'layer_2'},
+        [{'1': 'choice1', '2': 'choice2'},
          {'choice_3': 'layer_3', 'choice_4': 'layer_4'}]
+        >>> candidates.resources('flops')
+        [49.9, -1]
         >>> candidates.scores
-        [0.9, 0.5]
+        [100.0, -1]
    """

-    _format_return = Union[Tuple[Any, float], List[Tuple[Any, float]]]
+    _format_return = Union[Dict[str, Dict], List[Dict[str, Dict]]]
+    _format_input = Union[Dict, List[Dict], Dict[str, Dict], List[Dict[str,
+                                                                       Dict]]]
+    _indicators = ('score', 'flops', 'params', 'latency')

-    def __init__(self, initdata: Optional[Any] = None):
+    def __init__(self, initdata: Optional[_format_input] = None):
         self.data = []
         if initdata is not None:
             initdata = self._format(initdata)
@@ -41,23 +50,59 @@ def __init__(self, initdata: Optional[Any] = None):

     @property
     def scores(self) -> List[float]:
         """The scores of candidates."""
-        return [item[1] for item in self.data]
+        return [
+            round(value.get('score', 0.), 2) for item in self.data
+            for _, value in item.items()
+        ]
+
+    def resources(self, key_indicator: str = 'flops') -> List[float]:
+        """The resources of candidates."""
+        assert key_indicator in ['flops', 'params', 'latency']
+        return [
+            value.get(key_indicator, 0.) for item in self.data
+            for _, value in item.items()
+        ]

     @property
     def subnets(self) -> List[Dict]:
         """The subnets of candidates."""
-        return [item[0] for item in self.data]
+        return [eval(key) for item in self.data for key, _ in item.items()]

-    def _format(self, data: Any) -> _format_return:
-        """Transform [any, ...] to [tuple(any, float), ...] Transform any to
-        tuple(any, float)."""
+    def _format(self, data: _format_input) -> _format_return:
+        """Transform the input candidates to the unified Dict[str, Dict]
+        format.

+        Args:
+            data: Four types of input are supported:
+                1. Dict: only includes network information.
+                2. List[Dict]: multiple candidates that only include network
+                    information.
+                3. Dict[str, Dict]: network information and the corresponding
+                    resources.
+                4. List[Dict[str, Dict]]: multiple candidates with resources.
+        Returns:
+            Union[Dict[str, Dict], List[Dict[str, Dict]]]:
+                A dict or a list of dicts, each pairing network information
+                with its Score | FLOPs | Params | Latency results.
+        Notes:
+            Score | FLOPs | Params | Latency:
+            1. a resource with a default value of -1 indicates that it has
+                not been estimated yet.
+            2. a resource with a default value of 0 indicates that only some
+                indicators of the candidate have been evaluated.
+ """ + + def _format_item( + cond: Union[Dict, Dict[str, Dict]]) -> Dict[str, Dict]: + """Transform Dict to Dict[str, Dict].""" + if isinstance(list(cond.values())[0], dict): + for value in list(cond.values()): + for key in list(self._indicators): + value.setdefault(key, 0.) + return cond else: - return (item, 0.) + return {str(cond): {}.fromkeys(self._indicators, -1)} if isinstance(data, UserList): return [_format_item(i) for i in data.data] @@ -68,12 +113,15 @@ def _format_item(item: Any): else: return _format_item(data) - def append(self, item: Any) -> None: + def append(self, item: _format_input) -> None: """Append operation.""" item = self._format(item) - self.data.append(item) + if isinstance(item, list): + self.data = self.data + item + else: + self.data.append(item) - def insert(self, i: int, item: Any) -> None: + def insert(self, i: int, item: _format_input) -> None: """Insert operation.""" item = self._format(item) self.data.insert(i, item) @@ -88,4 +136,35 @@ def extend(self, other: Any) -> None: def set_score(self, i: int, score: float) -> None: """Set score to the specified subnet by index.""" - self.data[i] = (self.data[i][0], float(score)) + self.set_resource(i, score, 'score') + + def set_resource(self, + i: int, + resources: float, + key_indicator: str = 'flops') -> None: + """Set resources to the specified subnet by index.""" + assert key_indicator in ['score', 'flops', 'params', 'latency'] + for _, value in self.data[i].items(): + value[key_indicator] = resources + + def update_resources(self, resources: list, start: int = 0) -> None: + """Update resources to the specified candidate.""" + end = start + len(resources) + assert len( + self.data) >= end, 'Check the number of candidate resources.' + for i, item in enumerate(self.data[start:end]): + for _, value in item.items(): + value.update(resources[i]) + + def sort_by(self, + key_indicator: str = 'score', + reverse: bool = True) -> None: + """Sort by a specific indicator in descending order. + + Args: + key_indicator (str): sort all candidates by key_indicator. + Defaults to 'score'. + reverse (bool): sort all candidates in descending order. + """ + self.data.sort( + key=lambda x: list(x.values())[0][key_indicator], reverse=reverse) diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index 625e65025..538a88dac 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -43,13 +43,15 @@ def load_fix_subnet(model: nn.Module, raise RuntimeError('Root model can not be dynamic op.') # Avoid circular import - from mmrazor.models.mutables import DerivedMutable + from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer from mmrazor.models.mutables.base_mutable import BaseMutable for name, module in model.named_modules(): # The format of `chosen`` is different for each type of mutable. # In the corresponding mutable, it will check whether the `chosen` # format is correct. 
+        if isinstance(module, (MutableChannelContainer, DerivedMutable)):
+            continue
         if isinstance(module, BaseMutable):
             if not module.is_fixed:
                 if getattr(module, 'alias', None):
@@ -61,8 +63,8 @@
                     chosen = fix_mutable.get(alias, None)
                 else:
                     mutable_name = name.lstrip(prefix)
-                    if mutable_name not in fix_mutable and \
-                            not isinstance(module, DerivedMutable):
+                    if mutable_name not in fix_mutable and not isinstance(
+                            module, (DerivedMutable, MutableChannelContainer)):
                         raise RuntimeError(
                             f'The module name {mutable_name} is not in '
                             'fix_mutable, please check your `fix_mutable`.')
@@ -87,13 +89,15 @@ def export_fix_subnet(model: nn.Module,
             level=logging.WARNING)

     # Avoid circular import
-    from mmrazor.models.mutables import DerivedMutable
+    from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
     from mmrazor.models.mutables.base_mutable import BaseMutable

     fix_subnet = dict()
     for name, module in model.named_modules():
         if isinstance(module, BaseMutable):
-            if isinstance(module, DerivedMutable) and not dump_derived_mutable:
+            if isinstance(module,
+                          (MutableChannelContainer,
+                           DerivedMutable)) and not dump_derived_mutable:
                 continue

             if module.alias:
diff --git a/tests/data/models.py b/tests/data/models.py
index 867adc0c9..e45d8af4d 100644
--- a/tests/data/models.py
+++ b/tests/data/models.py
@@ -3,7 +3,7 @@
 from torch import Tensor
 import torch.nn as nn
 import torch
-from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin
+from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin, DynamicPatchEmbed, DynamicSequential
 from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
 from mmrazor.models.mutables import MutableChannelUnit
 from mmrazor.models.mutables import DerivedMutable
@@ -13,6 +13,9 @@
 from mmengine.model import BaseModel

 # this file includes models for testing.
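The DynamicAttention model defined below wires value and channel mutables into dynamic ops in two steps: first register an empty MutableChannelContainer on every dynamic op, then slot the actual mutable into the container. An isolated sketch of that pattern using the APIs from this patch (argument values are illustrative):

>>> from mmrazor.models.architectures.dynamic_ops import DynamicPatchEmbed
>>> from mmrazor.models.mutables import (MutableChannelContainer,
...                                      OneShotMutableChannel,
...                                      OneShotMutableChannelUnit)
>>> pe = DynamicPatchEmbed(img_size=224, in_channels=3, embed_dims=624)
>>> OneShotMutableChannelUnit._register_channel_container(
...     pe, MutableChannelContainer)
>>> embed_dims = OneShotMutableChannel(
...     num_channels=624, candidate_choices=[576, 624])
>>> MutableChannelContainer.register_mutable_channel_to_module(
...     pe, embed_dims, True)  # True: attach to the output channels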
+from mmrazor.models.mutables import OneShotMutableValue +from mmrazor.models.architectures.backbones.searchable_autoformer import TransformerEncoderLayer + class LinearHead(Module): @@ -475,7 +478,7 @@ def forward(self, x): def register_mutable(module: DynamicChannelMixin, - mutable: OneShotMutableChannelUnit, + mutable: MutableChannelUnit, is_out=True, start=0, end=-1): @@ -581,6 +584,95 @@ def _register_mutable(self): self.linear, mutable2, False) +class DynamicAttention(nn.Module): + """ + x + |blocks: DynamicSequential(depth) + |(blocks) + x1 + |fc (OneShotMutableChannel * OneShotMutableValue) + output + """ + + def __init__(self) -> None: + super().__init__() + + self.mutable_depth = OneShotMutableValue( + value_list=[1, 2], default_value=2) + self.mutable_embed_dims = OneShotMutableChannel( + num_channels=624, candidate_choices=[576, 624]) + self.base_embed_dims = OneShotMutableChannel( + num_channels=64, candidate_choices=[64]) + self.mutable_num_heads = [ + OneShotMutableValue( + value_list=[8, 10], + default_value=10) + for _ in range(2) + ] + self.mutable_mlp_ratios = [ + OneShotMutableValue( + value_list=[3.0, 3.5, 4.0], + default_value=4.0) + for _ in range(2) + ] + self.mutable_q_embed_dims = [ + i * self.base_embed_dims for i in self.mutable_num_heads + ] + + self.patch_embed = DynamicPatchEmbed( + img_size=224, + in_channels=3, + embed_dims=self.mutable_embed_dims.num_channels) + + # cls token and pos embed + self.pos_embed = nn.Parameter( + torch.zeros(1, 197, + self.mutable_embed_dims.num_channels)) + self.cls_token = nn.Parameter( + torch.zeros(1, 1, self.mutable_embed_dims.num_channels)) + + layers = [] + for i in range(self.mutable_depth.max_choice): + layer = TransformerEncoderLayer( + embed_dims=self.mutable_embed_dims.num_channels, + num_heads=self.mutable_num_heads[i].max_choice, + mlp_ratio=self.mutable_mlp_ratios[i].max_choice) + layers.append(layer) + self.blocks = DynamicSequential(*layers) + + # OneShotMutableChannelUnit + OneShotMutableChannelUnit._register_channel_container( + self, MutableChannelContainer) + + self.register_mutables() + + def register_mutables(self): + # mutablevalue + self.blocks.register_mutable_attr('depth', self.mutable_depth) + # mutablechannel + MutableChannelContainer.register_mutable_channel_to_module( + self.patch_embed, self.mutable_embed_dims, True) + + for i in range(self.mutable_depth.max_choice): + layer = self.blocks[i] + layer.register_mutables( + mutable_num_heads=self.mutable_num_heads[i], + mutable_mlp_ratios=self.mutable_mlp_ratios[i], + mutable_q_embed_dims=self.mutable_q_embed_dims[i], + mutable_head_dims=self.base_embed_dims, + mutable_embed_dims=self.mutable_embed_dims) + + def forward(self, x: torch.Tensor): + B = x.shape[0] + x = self.patch_embed(x) + embed_dims = self.mutable_embed_dims.current_choice + cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed[..., :embed_dims] + x = self.blocks(x) + return torch.mean(x[:, 1:], dim=1) + + default_models = [ LineModel, ResBlock, diff --git a/tests/test_models/test_algorithms/test_autoformer.py b/tests/test_models/test_algorithms/test_autoformer.py new file mode 100644 index 000000000..d2e4bf014 --- /dev/null +++ b/tests/test_models/test_algorithms/test_autoformer.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy
+from unittest import TestCase
+
+import torch
+
+from mmrazor.models import Autoformer
+from mmrazor.registry import MODELS
+
+arch_setting = dict(
+    mlp_ratios=[3.0, 3.5, 4.0],
+    num_heads=[8, 9, 10],
+    depth=[14, 15, 16],
+    embed_dims=[528, 576, 624])
+
+MUTATOR_CFG = dict(
+    channel_mutator=dict(
+        type='mmrazor.OneShotChannelMutator',
+        channel_unit_cfg={
+            'type': 'OneShotMutableChannelUnit',
+            'default_args': {
+                'unit_predefined': True
+            }
+        },
+        parse_cfg={'type': 'Predefined'}),
+    value_mutator=dict(type='mmrazor.DynamicValueMutator'))
+
+ARCHITECTURE_CFG = dict(
+    _scope_='mmrazor',
+    type='SearchableImageClassifier',
+    backbone=dict(
+        _scope_='mmrazor',
+        type='AutoformerBackbone',
+        arch_setting=arch_setting),
+    neck=None,
+    head=dict(
+        type='DynamicLinearClsHead',
+        num_classes=1000,
+        in_channels=624,
+        loss=dict(
+            type='mmcls.LabelSmoothLoss',
+            mode='original',
+            num_classes=1000,
+            label_smooth_val=0.1,
+            loss_weight=1.0),
+        topk=(1, 5)),
+    connect_head=dict(connect_with_backbone='backbone.last_mutable'),
+)
+
+ALGORITHM_CFG = dict(
+    type='mmrazor.Autoformer',
+    architecture=ARCHITECTURE_CFG,
+    fix_subnet=None,
+    mutators=dict(
+        channel_mutator=dict(
+            type='mmrazor.OneShotChannelMutator',
+            channel_unit_cfg={
+                'type': 'OneShotMutableChannelUnit',
+                'default_args': {
+                    'unit_predefined': True
+                }
+            },
+            parse_cfg={'type': 'Predefined'}),
+        value_mutator=dict(type='mmrazor.DynamicValueMutator')))
+
+
+class TestAUTOFORMER(TestCase):
+
+    def test_init(self):
+        ALGORITHM_CFG_SUPERNET = copy.deepcopy(ALGORITHM_CFG)
+        # initialize autoformer with a built `algorithm`.
+        autoformer_algo = MODELS.build(ALGORITHM_CFG_SUPERNET)
+        self.assertIsInstance(autoformer_algo, Autoformer)
+        # autoformer mutators include channel_mutator and value_mutator
+        assert 'channel_mutator' in autoformer_algo.mutators
+        assert 'value_mutator' in autoformer_algo.mutators
+
+        # autoformer search_groups
+        random_subnet = autoformer_algo.sample_subnet()
+        self.assertIsInstance(random_subnet, dict)
+
+        # autoformer_algo supports training
+        self.assertTrue(autoformer_algo.is_supernet)
+
+        # initialize autoformer without any `mutator`.
+        ALGORITHM_CFG_SUPERNET.pop('type')
+        ALGORITHM_CFG_SUPERNET['mutators'] = None
+        with self.assertRaisesRegex(
+                AssertionError,
+                'mutator cannot be None when fix_subnet is None.'):
+            _ = Autoformer(**ALGORITHM_CFG_SUPERNET)
+
+        # initialize autoformer with a wrong-type `mutator`.
+ backwardtracer_cfg = dict( + type='OneShotChannelMutator', + channel_unit_cfg=dict( + type='OneShotMutableChannelUnit', + default_args=dict( + candidate_choices=list(i / 12 for i in range(2, 13)), + choice_mode='ratio')), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss'))) + ALGORITHM_CFG_SUPERNET['mutators'] = dict( + channel_mutator=backwardtracer_cfg, + value_mutator=dict(type='mmrazor.DynamicValueMutator')) + with self.assertRaisesRegex(AssertionError, + 'autoformer only support predefined.'): + _ = Autoformer(**ALGORITHM_CFG_SUPERNET) + + def test_loss(self): + # supernet + inputs = torch.randn(1, 3, 224, 224) + autoformer = MODELS.build(ALGORITHM_CFG) + loss = autoformer(inputs) + assert loss.size(1) == 1000 diff --git a/tests/test_models/test_architectures/test_backbones/test_autoformerbackbone.py b/tests/test_models/test_architectures/test_backbones/test_autoformerbackbone.py new file mode 100644 index 000000000..25217d1e8 --- /dev/null +++ b/tests/test_models/test_architectures/test_backbones/test_autoformerbackbone.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.models.architectures.dynamic_ops import ( + DynamicLinear, DynamicMultiheadAttention, DynamicPatchEmbed, + DynamicRelativePosition2D, DynamicSequential) +from mmrazor.models.mutables import MutableChannelContainer +from mmrazor.registry import MODELS + +arch_setting = dict( + mlp_ratios=[3.0, 3.5, 4.0], + num_heads=[8, 9, 10], + depth=[14, 15, 16], + embed_dims=[528, 576, 624]) + +BACKBONE_CFG = dict( + type='mmrazor.AutoformerBackbone', + arch_setting=arch_setting, + img_size=224, + patch_size=16, + in_channels=3, + norm_cfg=dict(type='mmrazor.DynamicLayerNorm'), + act_cfg=dict(type='GELU')) + + +def test_searchable_autoformer_mutable() -> None: + backbone = MODELS.build(BACKBONE_CFG) + + num_heads = backbone.arch_setting['num_heads'] + mlp_ratios = backbone.arch_setting['mlp_ratios'] + depth = backbone.arch_setting['depth'] + embed_dims = backbone.arch_setting['embed_dims'] + embed_dims_expansion = [i * j for i in mlp_ratios for j in embed_dims] + head_expansion = [i * 64 for i in num_heads] + + for name, module in backbone.named_modules(): + if isinstance(module, DynamicRelativePosition2D): + assert len(module.mutable_head_dims.current_choice) == 64 + elif isinstance(module, DynamicMultiheadAttention): + assert len( + module.mutable_embed_dims.current_choice) == max(embed_dims) + assert len(module.mutable_q_embed_dims.current_choice) == max( + head_expansion) + assert module.mutable_num_heads.choices == num_heads + elif isinstance(module, DynamicLinear): + if 'fc1' in name: + assert module.mutable_attrs['in_features'].num_channels == max( + embed_dims) + assert module.mutable_attrs[ + 'out_features'].num_channels == max(embed_dims_expansion) + elif 'fc2' in name: + assert module.mutable_attrs['in_features'].num_channels == max( + embed_dims_expansion) + assert module.mutable_attrs[ + 'out_features'].num_channels == max(embed_dims) + elif isinstance(module, DynamicPatchEmbed): + assert type(module.mutable_embed_dims) == MutableChannelContainer + assert len( + module.mutable_embed_dims.current_choice) == max(embed_dims) + elif isinstance(module, DynamicSequential): + assert module.mutable_depth.choices == depth + assert backbone.last_mutable.num_channels == max(embed_dims) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_attention.py 
b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_attention.py new file mode 100644 index 000000000..4ed47c0ce --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_attention.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmrazor.models.architectures.dynamic_ops import DynamicMultiheadAttention +from mmrazor.models.architectures.ops import MultiheadAttention +from mmrazor.models.mutables import (MutableChannelContainer, + OneShotMutableChannel, + OneShotMutableChannelUnit, + OneShotMutableValue) + + +class TestDynamicMHA(TestCase): + + def setUp(self) -> None: + self.mutable_num_heads = OneShotMutableValue( + value_list=[2, 4, 8], default_value=8) + self.mutable_embed_dims = OneShotMutableChannel(num_channels=128) + self.base_embed_dims = OneShotMutableChannel( + num_channels=8, candidate_choices=[8]) + self.mutable_q_embed_dims = self.mutable_num_heads * \ + self.base_embed_dims + + self.dynamic_m = DynamicMultiheadAttention(embed_dims=128, num_heads=8) + + OneShotMutableChannelUnit._register_channel_container( + self.dynamic_m, MutableChannelContainer) + + self.dynamic_m.register_mutable_attr('num_heads', + self.mutable_num_heads) + + MutableChannelContainer.register_mutable_channel_to_module( + self.dynamic_m, self.mutable_embed_dims, False) + MutableChannelContainer.register_mutable_channel_to_module( + self.dynamic_m, self.mutable_q_embed_dims, True, end=64) + MutableChannelContainer.register_mutable_channel_to_module( + self.dynamic_m.rel_pos_embed_k, self.base_embed_dims, False) + MutableChannelContainer.register_mutable_channel_to_module( + self.dynamic_m.rel_pos_embed_v, self.base_embed_dims, False) + + def test_forward(self) -> None: + x = torch.randn(8, 197, 128) + output = self.dynamic_m(x) + self.assertIsNotNone(output) + + def test_convert(self) -> None: + static_m = MultiheadAttention(embed_dims=100, num_heads=10) + dynamic_m = DynamicMultiheadAttention.convert_from(static_m) + self.assertIsNotNone(dynamic_m) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_container.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_container.py new file mode 100644 index 000000000..469ce0a9b --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_container.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import pytest +import torch.nn as nn +from torch.nn import Sequential + +from mmrazor.models.architectures.dynamic_ops import DynamicSequential +from mmrazor.models.mutables import OneShotMutableValue + + +class TestDynamicSequential(TestCase): + + def setUp(self) -> None: + self.layers = [ + nn.Linear(4, 5), + nn.Linear(5, 6), + nn.Linear(6, 7), + nn.Linear(7, 8), + ] + self.dynamic_m = DynamicSequential(*self.layers) + mutable_depth = OneShotMutableValue( + value_list=[2, 3, 4], default_value=3) + + self.dynamic_m.register_mutable_attr('depth', mutable_depth) + + def test_init(self) -> None: + self.assertEqual( + self.dynamic_m.get_mutable_attr('depth').current_choice, 3) + + def test_to_static_op(self) -> None: + with pytest.raises(RuntimeError): + self.dynamic_m.to_static_op() + + current_mutable = self.dynamic_m.get_mutable_attr('depth') + current_mutable.fix_chosen(current_mutable.dump_chosen().chosen) + + static_op = self.dynamic_m.to_static_op() + self.assertIsNotNone(static_op) + + def test_convert_from(self) -> None: + static_m = Sequential(*self.layers) + + dynamic_m = DynamicSequential.convert_from(static_m) + + self.assertIsNotNone(dynamic_m) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_embed.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_embed.py new file mode 100644 index 000000000..65c2b39a4 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_embed.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +from mmcls.models.utils import PatchEmbed + +from mmrazor.models.architectures.dynamic_ops import DynamicPatchEmbed +from mmrazor.models.mutables import SquentialMutableChannel + + +class TestPatchEmbed(TestCase): + + def setUp(self): + self.dynamic_embed = DynamicPatchEmbed( + img_size=224, in_channels=3, embed_dims=100) + + mutable_embed_dims = SquentialMutableChannel(num_channels=100) + + mutable_embed_dims.current_choice = 50 + self.dynamic_embed.register_mutable_attr('embed_dims', + mutable_embed_dims) + + def test_patch_embed(self): + mutable = SquentialMutableChannel(num_channels=120) + + with pytest.raises(ValueError): + self.dynamic_embed.register_mutable_attr('embed_dims', mutable) + + self.assertTrue( + self.dynamic_embed.get_mutable_attr('embed_dims').current_choice == + 50) + + def test_convert(self): + static_m = PatchEmbed(img_size=224, in_channels=3, embed_dims=768) + + dynamic_m = DynamicPatchEmbed.convert_from(static_m) + + self.assertIsNotNone(dynamic_m) + + def test_to_static_op(self): + mutable_embed_dims = SquentialMutableChannel(num_channels=100) + + mutable_embed_dims.current_choice = 10 + + with pytest.raises(RuntimeError): + self.dynamic_embed.to_static_op() + + mutable_embed_dims.fix_chosen(mutable_embed_dims.dump_chosen().chosen) + self.dynamic_embed.register_mutable_attr('embed_dims', + mutable_embed_dims) + static_op = self.dynamic_embed.to_static_op() + + self.assertIsNotNone(static_op) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_layernorm.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_layernorm.py new file mode 100644 index 000000000..619881f33 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_layernorm.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import pytest +from torch.nn import LayerNorm + +from mmrazor.models.architectures.dynamic_ops import DynamicLayerNorm +from mmrazor.models.mutables import SquentialMutableChannel + + +class TestDynamicLayerNorm(TestCase): + + def setUp(self) -> None: + self.dynamic_m = DynamicLayerNorm(100) + + mutable_num_features = SquentialMutableChannel(num_channels=100) + + mutable_num_features.current_choice = 50 + + self.dynamic_m.register_mutable_attr('num_features', + mutable_num_features) + + def test_init(self) -> None: + mutable = SquentialMutableChannel(num_channels=100) + self.dynamic_m.register_mutable_attr('in_channels', mutable) + self.dynamic_m.register_mutable_attr('out_channels', mutable) + + self.assertEqual( + self.dynamic_m.get_mutable_attr('num_features').current_choice, 50) + + def test_to_static_op(self): + with pytest.raises(RuntimeError): + self.dynamic_m.to_static_op() + + current_mutable = self.dynamic_m.get_mutable_attr('num_features') + current_mutable.fix_chosen(current_mutable.dump_chosen().chosen) + static_op = self.dynamic_m.to_static_op() + + self.assertIsNotNone(static_op) + + def test_convert(self) -> None: + static_m = LayerNorm(100) + dynamic_m = DynamicLayerNorm.convert_from(static_m) + + self.assertIsNotNone(dynamic_m) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_relative_position.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_relative_position.py new file mode 100644 index 000000000..9f82fe1d3 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_relative_position.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch + +from mmrazor.models.architectures.dynamic_ops import DynamicRelativePosition2D +from mmrazor.models.architectures.ops import RelativePosition2D +from mmrazor.models.mutables import SquentialMutableChannel + + +class TestDynamicRP(TestCase): + + def setUp(self) -> None: + mutable_head_dims = SquentialMutableChannel(num_channels=8) + + self.dynamic_rp = DynamicRelativePosition2D( + head_dims=8, max_relative_position=14) + + mutable_head_dims.current_choice = 6 + self.dynamic_rp.register_mutable_attr('head_dims', mutable_head_dims) + + def test_mutable_attrs(self) -> None: + + assert self.dynamic_rp.mutable_head_dims.current_choice == 6 + + embed = self.dynamic_rp.forward(14, 14) + + self.assertIsNotNone(embed) + + def test_convert(self): + static_model = RelativePosition2D( + head_dims=10, max_relative_position=14) + + dynamic_model = DynamicRelativePosition2D.convert_from(static_model) + + self.assertIsNotNone(dynamic_model) + + def test_to_static_op(self): + with pytest.raises(RuntimeError): + static_m = self.dynamic_rp.to_static_op() + + mutable = SquentialMutableChannel(num_channels=8) + mutable.current_choice = 4 + + mutable.fix_chosen(mutable.dump_chosen().chosen) + + self.dynamic_rp.register_mutable_attr('head_dims', mutable) + static_m = self.dynamic_rp.to_static_op() + + self.assertIsNotNone(static_m) + + dynamic_output = self.dynamic_rp.forward(14, 14) + static_output = static_m.forward(14, 14) + self.assertTrue(torch.equal(dynamic_output, static_output)) diff --git a/tests/test_models/test_classifier/test_imageclassifier.py b/tests/test_models/test_classifier/test_imageclassifier.py new file mode 100644 index 000000000..169d34995 --- /dev/null +++ 
b/tests/test_models/test_classifier/test_imageclassifier.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+from mmrazor.models import SearchableImageClassifier
+
+
+class TestSearchableImageClassifier(TestCase):
+
+    def test_init(self):
+
+        arch_setting = dict(
+            mlp_ratios=[3.0, 3.5, 4.0],
+            num_heads=[8, 9, 10],
+            depth=[14, 15, 16],
+            embed_dims=[528, 576, 624])
+
+        supernet_kwargs = dict(
+            backbone=dict(
+                _scope_='mmrazor',
+                type='AutoformerBackbone',
+                arch_setting=arch_setting),
+            neck=None,
+            head=dict(
+                _scope_='mmrazor',
+                type='DynamicLinearClsHead',
+                num_classes=1000,
+                in_channels=624,
+                loss=dict(
+                    type='mmcls.LabelSmoothLoss',
+                    mode='original',
+                    num_classes=1000,
+                    label_smooth_val=0.1,
+                    loss_weight=1.0),
+                topk=(1, 5)),
+            connect_head=dict(connect_with_backbone='backbone.last_mutable'),
+        )
+
+        supernet = SearchableImageClassifier(**supernet_kwargs)
+
+        # test connect_with_backbone
+        self.assertEqual(
+            supernet.backbone.last_mutable.activated_channels,
+            len(
+                supernet.head.fc.get_mutable_attr(
+                    'in_channels').current_choice))
diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py
index 0b5f55e88..8ec7c5cd5 100644
--- a/tests/test_models/test_mutables/test_derived_mutable.py
+++ b/tests/test_models/test_mutables/test_derived_mutable.py
@@ -123,6 +123,27 @@ def test_mutable_divide(self) -> None:
         mv.current_choice == 120
         assert mv_derived.current_choice == 16

+        mc_derived = mc // 8.0
+        assert mc_derived.source_mutables == {mc}
+
+        mc.current_choice = 128.
+        assert mc_derived.current_choice == 16
+        assert torch.equal(mc_derived.current_mask,
+                           torch.ones(16, dtype=torch.bool))
+        mc.current_choice = 120.
+        assert mc_derived.current_choice == 16
+        assert torch.equal(mc_derived.current_mask,
+                           torch.ones(16, dtype=torch.bool))
+
+        mv = OneShotMutableValue(value_list=[112, 120, 128])
+        mv_derived = mv // 8.0
+        assert mv_derived.source_mutables == {mv}
+
+        mv.current_choice = 128.
+        assert mv_derived.current_choice == 16
+        mv.current_choice = 120.
+ assert mv_derived.current_choice == 16 + def test_source_mutables(self) -> None: def useless_fn(x): @@ -207,6 +228,43 @@ def test_nested_mutables(self) -> None: derived_e.current_mask, torch.tensor([1, 0, 1, 1, 1, 1, 0], dtype=torch.bool)) + def test_mutable_channel_value_calculation(self) -> None: + mc = SquentialMutableChannel(num_channels=10) + mv = OneShotMutableValue(value_list=[2.0, 2.5, 3.0, 3.5]) + derived_mutable = mc * mv + assert derived_mutable.source_mutables == {mv, mc} + + mc.current_choice = 6 + mv.current_choice = 3.5 + assert derived_mutable.current_choice == 21 + + mc.current_choice = 9 + mv.current_choice = 3.5 + assert derived_mutable.current_choice == 31 + + mc.current_choice = 7 + mv.current_choice = 2.5 + assert derived_mutable.current_choice == 17 + + assert isinstance(derived_mutable, BaseMutable) + assert isinstance(derived_mutable, DerivedMutable) + assert not derived_mutable.is_fixed + + mc.current_choice = mc.num_channels + mv.current_choice = mv.min_choice + assert derived_mutable.current_choice == \ + mv.current_choice * mc.num_channels + mv.current_choice = mv.max_choice + assert derived_mutable.current_choice == \ + mv.current_choice * mc.current_choice + + with pytest.raises(RuntimeError): + derived_mutable.is_fixed = True + mc.fix_chosen(mc.dump_chosen().chosen) + assert not derived_mutable.is_fixed + mv.fix_chosen(mv.dump_chosen().chosen) + assert derived_mutable.is_fixed + @pytest.mark.parametrize('expand_ratio', [1, 2, 3]) def test_derived_expand_mutable(expand_ratio: int) -> None: @@ -232,3 +290,29 @@ def test_derived_expand_mutable(expand_ratio: int) -> None: mv.current_choice = 5 assert mv_derived.current_choice == 5 * expand_ratio + + +@pytest.mark.parametrize('expand_ratio', [1.5, 2.0, 2.5]) +def test_derived_expand_mutable_float(expand_ratio: float) -> None: + mv = OneShotMutableValue(value_list=[3, 5, 7]) + + mv_derived = mv * expand_ratio + assert mv_derived.source_mutables == {mv} + + assert isinstance(mv_derived, BaseMutable) + assert isinstance(mv_derived, DerivedMutable) + assert not mv_derived.is_fixed + assert mv_derived.num_choices == 1 + + mv.current_choice = mv.max_choice + assert mv_derived.current_choice == int(mv.current_choice * expand_ratio) + mv.current_choice = mv.min_choice + assert mv_derived.current_choice == int(mv.current_choice * expand_ratio) + + with pytest.raises(RuntimeError): + mv_derived.current_choice = 123 + with pytest.raises(RuntimeError): + _ = mv_derived.current_mask + + mv.current_choice = 5 + assert mv_derived.current_choice == int(5 * expand_ratio) diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py b/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py index 253084d07..c807cabe5 100644 --- a/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py @@ -3,7 +3,8 @@ import torch -from mmrazor.models.mutables import SquentialMutableChannel +from mmrazor.models.mutables import (OneShotMutableValue, + SquentialMutableChannel) class TestSquentialMutableChannel(TestCase): @@ -41,3 +42,16 @@ def test_float_choice(self): channel = SquentialMutableChannel(10, choice_mode='ratio') self._test_mutable(channel, 0.5, 0.5, 5, self._generate_mask(5, 10)) self._test_mutable(channel, 2, 0.2, 2, self._generate_mask(2, 10)) + + def test_mutable_channel_mul(self): + channel = SquentialMutableChannel(2) + 
self.assertEqual(channel.current_choice, 2) + mv = OneShotMutableValue(value_list=[1, 2, 3], default_value=3) + derived1 = channel * mv + derived2 = mv * channel + assert derived1.current_choice == 6 + assert derived2.current_choice == 6 + mv.current_choice = mv.min_choice + assert derived1.current_choice == 2 + assert derived2.current_choice == 2 + assert torch.equal(derived1.current_mask, derived2.current_mask) diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py index 690382596..a73673919 100644 --- a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py @@ -2,6 +2,8 @@ from unittest import TestCase from mmrazor.models.mutables import OneShotMutableChannelUnit +from mmrazor.models.mutators.channel_mutator import ChannelMutator +from .....data.models import DynamicAttention class TestSequentialMutableChannelUnit(TestCase): @@ -14,3 +16,20 @@ def test_init(self): unit = OneShotMutableChannelUnit( 48, [0.3, 0.5, 0.7], choice_mode='ratio', divisor=8) self.assertSequenceEqual(unit.candidate_choices, [1 / 3, 0.5, 2 / 3]) + + def test_unit_predefined(self): + model = DynamicAttention() + mutator = ChannelMutator( + channel_unit_cfg={ + 'type': 'OneShotMutableChannelUnit', + 'default_args': { + 'unit_predefined': False + } + }, + parse_cfg={'type': 'Predefined'}) + mutator.prepare_from_supernet(model) + choices = mutator.sample_choices() + mutator.set_choices(choices) + self.assertSequenceEqual(mutator.units[0].candidate_choices, + [576, 624]) + self.assertSequenceEqual(mutator.units[1].candidate_choices, [64]) diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py index b33dfcc98..11ac7d49c 100644 --- a/tests/test_models/test_mutables/test_mutable_value.py +++ b/tests/test_models/test_mutables/test_mutable_value.py @@ -73,9 +73,6 @@ def test_mul(self) -> None: assert mul_derived_mv.current_choice == 4 assert rmul_derived_mv.current_choice == 4 - with pytest.raises(TypeError): - _ = mv * 1.2 - mv = MutableValue(value_list=[1, 2, 3], default_value=3) mc = SquentialMutableChannel(num_channels=4) @@ -114,9 +111,6 @@ def test_floordiv(self) -> None: mv.current_choice = 136 assert derived_mv.current_choice == 18 - with pytest.raises(TypeError): - _ = mv // 1.2 - def test_repr(self) -> None: value_list = [2, 4, 6] mv = MutableValue(value_list=value_list) diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py index 3d6ed7773..b4a702bdf 100644 --- a/tests/test_models/test_mutators/test_channel_mutator.py +++ b/tests/test_models/test_mutators/test_channel_mutator.py @@ -10,7 +10,7 @@ L1MutableChannelUnit, SequentialMutableChannelUnit) from mmrazor.models.mutators.channel_mutator import ChannelMutator from mmrazor.registry import MODELS -from ...data.models import DynamicLinearModel +from ...data.models import DynamicAttention, DynamicLinearModel from ...test_core.test_graph.test_graph import TestGraph sys.setrecursionlimit(2000) @@ -135,6 +135,30 @@ def test_models_with_predefined_dynamic_op(self): mutator.prepare_from_supernet(model) self._test_a_mutator(mutator, model) + def test_models_with_predefined_dynamic_op_without_pruning(self): + for Model 
in [ + DynamicAttention, + ]: + with self.subTest(model=Model): + model = Model() + mutator = ChannelMutator( + channel_unit_cfg={ + 'type': 'OneShotMutableChannelUnit', + 'default_args': { + 'unit_predefined': True + } + }, + parse_cfg={'type': 'Predefined'}) + mutator.prepare_from_supernet(model) + choices = mutator.sample_choices() + mutator.set_choices(choices) + self.assertGreater(len(mutator.mutable_units), 0) + x = torch.rand([2, 3, 224, 224]) + y = model(x) + self.assertEqual( + list(y.shape), + [2, list(mutator.current_choices.values())[0]]) + def test_custom_group(self): ARCHITECTURE_CFG = dict( type='mmcls.ImageClassifier', diff --git a/tests/test_models/test_mutators/test_value_mutator.py b/tests/test_models/test_mutators/test_value_mutator.py new file mode 100644 index 000000000..fefe28195 --- /dev/null +++ b/tests/test_models/test_mutators/test_value_mutator.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch + +from mmrazor.models.mutables import MutableValue +from mmrazor.models.mutators import DynamicValueMutator +from ...data.models import DynamicAttention + + +class TestValueMutator(unittest.TestCase): + + def test_models_with_predefined_dynamic_op(self): + for Model in [ + DynamicAttention, + ]: + with self.subTest(model=Model): + model = Model() + value_mutator = DynamicValueMutator() + value_mutator.prepare_from_supernet(model) + value_choices = value_mutator.sample_choices() + value_mutator.set_choices(value_choices) + + mutable_value_space = [] + for mutable_value, module in model.named_modules(): + if isinstance(module, MutableValue): + mutable_value_space.append(mutable_value) + assert len( + value_mutator.search_groups) == len(mutable_value_space) + + x = torch.rand([2, 3, 224, 224]) + y = model(x) + self.assertEqual(list(y.shape), [2, 624]) diff --git a/tests/test_models/test_subnet/test_candidate.py b/tests/test_models/test_subnet/test_candidate.py index 4cf44846d..7f8bfe640 100644 --- a/tests/test_models/test_subnet/test_candidate.py +++ b/tests/test_models/test_subnet/test_candidate.py @@ -10,7 +10,27 @@ class TestCandidates(TestCase): def setUp(self) -> None: self.fake_subnet = {'1': 'choice1', '2': 'choice2'} - self.fake_subnet_with_score = (self.fake_subnet, 1.) + self.fake_subnet_with_resource = { + str(self.fake_subnet): { + 'score': 0., + 'flops': 50., + 'params': 0., + 'latency': 0. + } + } + self.fake_subnet_with_score = { + str(self.fake_subnet): { + 'score': 99., + 'flops': 0., + 'params': 0., + 'latency': 0. 
+ } + } + self.has_flops_network = { + str(self.fake_subnet): { + 'flops': 50., + } + } def test_init(self): # initlist is None @@ -23,16 +43,25 @@ def test_init(self): # initlist is UserList data = UserList([self.fake_subnet] * 2) self.assertEqual(len(candidates.data), 2) + self.assertEqual(candidates.resources('flops'), [-1, -1]) + # initlist is list(Dict[str, Dict]) + candidates = Candidates([self.has_flops_network] * 2) + self.assertEqual(candidates.resources('flops'), [50., 50.]) def test_scores(self): # test property: scores data = [self.fake_subnet_with_score] * 2 candidates = Candidates(data) - self.assertEqual(candidates.scores, [1., 1.]) + self.assertEqual(candidates.scores, [99., 99.]) + + def test_resources(self): + data = [self.fake_subnet_with_resource] * 2 + candidates = Candidates(data) + self.assertEqual(candidates.resources('flops'), [50., 50.]) def test_subnets(self): # test property: subnets - data = [self.fake_subnet_with_score] * 2 + data = [self.fake_subnet] * 2 candidates = Candidates(data) self.assertEqual(candidates.subnets, [self.fake_subnet] * 2) @@ -41,17 +70,20 @@ def test_append(self): candidates = Candidates() candidates.append(self.fake_subnet) self.assertEqual(len(candidates), 1) - # item is tuple + # item is List candidates = Candidates() - candidates.append(self.fake_subnet_with_score) - self.assertEqual(len(candidates), 1) + candidates.append([self.fake_subnet_with_score]) + # item is Candidates + candidates_2 = Candidates([self.fake_subnet_with_resource]) + candidates.append(candidates_2) + self.assertEqual(len(candidates), 2) def test_insert(self): # item is dict - candidates = Candidates([self.fake_subnet_with_score]) + candidates = Candidates(self.fake_subnet_with_score) candidates.insert(1, self.fake_subnet) self.assertEqual(len(candidates), 2) - # item is tuple + # item is List candidates = Candidates([self.fake_subnet_with_score]) candidates.insert(1, self.fake_subnet_with_score) self.assertEqual(len(candidates), 2) @@ -61,13 +93,60 @@ def test_extend(self): candidates = Candidates([self.fake_subnet_with_score]) candidates.extend([self.fake_subnet]) self.assertEqual(len(candidates), 2) - # other is UserList + # other is Candidates candidates = Candidates([self.fake_subnet_with_score]) - candidates.extend(UserList([self.fake_subnet_with_score])) + candidates_2 = Candidates([self.fake_subnet_with_resource]) + candidates.extend(candidates_2) + self.assertEqual(len(candidates), 2) + + def test_set_resource(self): + # test set_resource + candidates = Candidates([self.fake_subnet]) + for kk in ['flops', 'params', 'latency']: + self.assertEqual(candidates.resources(kk)[0], -1) + candidates.set_resource(0, 49.9, kk) + self.assertEqual(candidates.resources(kk)[0], 49.9) + candidates.insert(0, self.fake_subnet_with_resource) self.assertEqual(len(candidates), 2) + self.assertEqual(candidates.resources('flops'), [50., 49.9]) + self.assertEqual(candidates.resources('latency'), [0., 49.9]) + candidates = Candidates([self.fake_subnet_with_score]) + candidates.set_resource(0, 100.0, 'score') + self.assertEqual(candidates.scores[0], 100.) 
diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py
index f30019274..6d8814a7b 100644
--- a/tests/test_runners/test_evolution_search_loop.py
+++ b/tests/test_runners/test_evolution_search_loop.py
@@ -82,7 +82,7 @@ def setUp(self):
             num_mutation=2,
             num_crossover=2,
             mutate_prob=0.1,
-            flops_range=None,
+            constraints_range=dict(flops=(0, 330)),
             score_key='coco/bbox_mAP')
         self.train_cfg = Config(train_cfg)
         self.runner = MagicMock(spec=ToyRunner)
@@ -103,7 +103,7 @@ def test_init(self):
 
         # test init_candidates is not None
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        fake_candidates = Candidates((fake_subnet, 0.))
+        fake_candidates = Candidates(fake_subnet)
         init_candidates_path = os.path.join(self.temp_dir, 'candidates.yaml')
         fileio.dump(fake_candidates, init_candidates_path)
         loop_cfg.init_candidates = init_candidates_path
@@ -111,8 +111,12 @@ def test_init(self):
         self.assertIsInstance(loop, EvolutionSearchLoop)
         self.assertEqual(loop.candidates, fake_candidates)
 
-    @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
-    def test_run_epoch(self, mock_export_fix_subnet):
+    @patch('mmrazor.engine.runner.utils.check.load_fix_subnet')
+    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
+    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'
+           'get_model_flops_params')
+    def test_run_epoch(self, flops_params, mock_export_fix_subnet,
+                       load_status):
         # test_run_epoch: distributed == False
         loop_cfg = copy.deepcopy(self.train_cfg)
         loop_cfg.runner = self.runner
@@ -120,20 +124,20 @@ def test_run_epoch(self, mock_export_fix_subnet):
         loop_cfg.evaluator = self.evaluator
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
-        loop._epoch = 1
         self.runner.distributed = False
         self.runner.work_dir = self.temp_dir
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
+        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
+        load_status.return_value = True
+        flops_params.return_value = 0, 0
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
-        self.assertEqual(loop._epoch, 2)
+        self.assertEqual(loop._epoch, 1)
 
         # test_run_epoch: distributed == True
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
-        loop._epoch = 1
         self.runner.distributed = True
         self.runner.work_dir = self.temp_dir
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
@@ -141,26 +145,27 @@ def test_run_epoch(self, mock_export_fix_subnet):
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
-        self.assertEqual(loop._epoch, 2)
+        self.assertEqual(loop._epoch, 1)
 
         # test_check_constraints
-        loop_cfg.flops_range = (0, 100)
+        loop_cfg.constraints_range = dict(params=(0, 100))
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
-        loop._epoch = 1
         self.runner.distributed = True
         self.runner.work_dir = self.temp_dir
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
         loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
-        loop._check_constraints = MagicMock(return_value=True)
+        flops_params.return_value = (50., 1)
         mock_export_fix_subnet.return_value = fake_subnet
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
-        self.assertEqual(loop._epoch, 2)
+        self.assertEqual(loop._epoch, 1)
 
-    @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
-    def test_run(self, mock_export_fix_subnet):
+    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
+    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'
+           'get_model_flops_params')
+    def test_run_loop(self, mock_flops, mock_export_fix_subnet):
         # test a new search: resume == None
         loop_cfg = copy.deepcopy(self.train_cfg)
         loop_cfg.runner = self.runner
@@ -169,16 +174,26 @@ def test_run(self, mock_export_fix_subnet):
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
         loop._epoch = 1
+        fake_subnet = {'1': 'choice1', '2': 'choice2'}
         self.runner.work_dir = self.temp_dir
         loop.update_candidate_pool = MagicMock()
         loop.val_candidate_pool = MagicMock()
+
+        mutation_candidates = Candidates([fake_subnet] * loop.num_mutation)
+        for i in range(loop.num_mutation):
+            mutation_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops')
+            mutation_candidates.set_resource(i, 99 + i, 'score')
+        crossover_candidates = Candidates([fake_subnet] * loop.num_crossover)
+        for i in range(loop.num_crossover):
+            crossover_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops')
+            crossover_candidates.set_resource(i, 99 + i, 'score')
         loop.gen_mutation_candidates = \
-            MagicMock(return_value=[fake_subnet]*loop.num_mutation)
+            MagicMock(return_value=mutation_candidates)
         loop.gen_crossover_candidates = \
-            MagicMock(return_value=[fake_subnet]*loop.num_crossover)
-        loop.top_k_candidates = Candidates([(fake_subnet, 1.0),
-                                            (fake_subnet, 0.9)])
+            MagicMock(return_value=crossover_candidates)
+        loop.candidates = Candidates([fake_subnet] * 4)
+        mock_flops.return_value = (0.5, 101)
         mock_export_fix_subnet.return_value = fake_subnet
         loop.run()
         assert os.path.exists(
diff --git a/tests/test_runners/test_subnet_sampler_loop.py b/tests/test_runners/test_subnet_sampler_loop.py
index fca29b823..1c9422fc1 100644
--- a/tests/test_runners/test_subnet_sampler_loop.py
+++ b/tests/test_runners/test_subnet_sampler_loop.py
@@ -119,7 +119,7 @@ def setUp(self):
             max_iters=12,
             val_interval=2,
             score_key='acc',
-            flops_range=None,
+            constraints_range=None,
             num_candidates=4,
             num_samples=2,
             top_k=2,
@@ -190,7 +190,7 @@ def test_sample_subnet(self):
         loop._iter = loop.val_interval
         subnet = loop.sample_subnet()
         self.assertEqual(subnet, fake_subnet)
-        self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)
+        self.assertEqual(len(loop.top_k_candidates), loop.top_k)
 
     def test_run(self):
         # test run with _check_constraints
@@ -200,7 +200,7 @@ def test_run(self):
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
         runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
         loop = runner.build_train_loop(cfg.train_cfg)
-        loop._check_constraints = MagicMock(return_value=True)
+        loop._check_constraints = MagicMock(return_value=(True, dict()))
         runner.train()
         self.assertEqual(runner.iter, runner.max_iters)
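Both search loops now accept constraints_range, a dict mapping resource indicators to (low, high) budgets, in place of the single-purpose flops_range tuple, so several budgets can be checked per sampled subnet. An illustrative train_cfg fragment under that scheme (the loop type string and budget values are placeholders, and required keys such as the dataloader and evaluator are omitted):

train_cfg = dict(
    type='mmrazor.EvolutionSearchLoop',  # assumed registered loop name
    # each key must be an indicator the resource estimator reports
    constraints_range=dict(flops=(0., 330.), params=(0., 70.)),
    num_candidates=50,
    top_k=10,
    score_key='accuracy/top1')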
diff --git a/tests/test_runners/test_utils/test_check.py b/tests/test_runners/test_utils/test_check.py
index b9bd57989..2f3a80eaa 100644
--- a/tests/test_runners/test_utils/test_check.py
+++ b/tests/test_runners/test_utils/test_check.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest.mock import patch
 
-from mmrazor.engine.runner.utils import check_subnet_flops
+from mmrazor.engine.runner.utils import check_subnet_resources
 
 try:
     from mmdet.models.detectors import BaseDetector
@@ -12,29 +12,33 @@
 
 @patch('mmrazor.models.ResourceEstimator')
 @patch('mmrazor.models.SPOS')
-def test_check_subnet_flops(mock_model, mock_estimator):
-    # flops_range = None
-    flops_range = None
+def test_check_subnet_resources(mock_model, mock_estimator):
+    # constraints_range = dict()
+    constraints_range = dict()
     fake_subnet = {'1': 'choice1', '2': 'choice2'}
-    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
-                                flops_range)
-    assert result is True
+    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
+                                        mock_estimator, constraints_range)
+    assert is_pass is True
 
-    # flops_range is not None
+    # constraints_range is not None
     # architecture is BaseDetector
-    flops_range = (0., 100.)
+    constraints_range = dict(flops=(0, 330))
     mock_model.architecture = BaseDetector
     fake_results = {'flops': 50.}
     mock_estimator.estimate.return_value = fake_results
-    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
-                                flops_range)
-    assert result is True
+    is_pass, _ = check_subnet_resources(
+        mock_model,
+        fake_subnet,
+        mock_estimator,
+        constraints_range,
+    )
+    assert is_pass is True
 
-    # flops_range is not None
+    # constraints_range is not None
     # architecture is BaseDetector
-    flops_range = (0., 100.)
+    constraints_range = dict(flops=(0, 330))
     fake_results = {'flops': -50.}
     mock_estimator.estimate.return_value = fake_results
-    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
-                                flops_range)
-    assert result is False
+    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
+                                        mock_estimator, constraints_range)
+    assert is_pass is False
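check_subnet_resources replaces check_subnet_flops and returns an (is_pass, results) tuple rather than a bare bool, so callers can keep the measured numbers for candidates that pass. A sketch of the intended call pattern, assuming results is the estimator's output dict as mocked above, with model, fake_subnet, estimator, candidates and i as illustrative names:

from mmrazor.engine.runner.utils import check_subnet_resources

is_pass, results = check_subnet_resources(model, fake_subnet, estimator,
                                          dict(flops=(0, 330)))
if is_pass:
    # `results` carries the estimator output, e.g. {'flops': 50.}
    candidates.set_resource(i, results['flops'], 'flops')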