From 76c3773e83176d40d9ab8f3e700d4ad0a853a143 Mon Sep 17 00:00:00 2001 From: zengyi <31244134+spynccat@users.noreply.github.com> Date: Wed, 23 Nov 2022 09:55:33 +0800 Subject: [PATCH] [Feature] Add DCFF (#295) * add ChannelGroup (#250) * rebase new dev-1.x * modification for adding config_template * add docstring to channel_group.py * add docstring to mutable_channel_group.py * rm channel_group_cfg from Graph2ChannelGroups * change choice type of SequentialChannelGroup from float to int * add a warning about group-wise conv * restore __init__ of dynamic op * in_channel_mutable -> mutable_in_channel * rm abstractproperty * add a comment about VT * rm registry for ChannelGroup * MUTABLECHANNELGROUP -> ChannelGroupType * refine docstring of IndexDict * update docstring * update docstring * is_prunable -> is_mutable * update docstring * fix error in pre-commit * update unittest * add return type * unify init_xxx apit * add unitest about init of MutableChannelGroup * update according to reviews * sequential_channel_group -> sequential_mutable_channel_group Co-authored-by: liukai * Add BaseChannelMutator and refactor Autoslim (#289) * add BaseChannelMutator * add autoslim * tmp * make SequentialMutableChannelGroup accpeted both of num and ratio as choice. and supports divisior * update OneShotMutableChannelGroup * pass supernet training of autoslim * refine autoslim * fix bug in OneShotMutableChannelGroup * refactor make_divisible * fix spell error: channl -> channel * init_using_backward_tracer -> init_from_backward_tracer init_from_fx_tracer -> init_from_fx_tracer * refine SequentialMutableChannelGroup * let mutator support models with dynamicop * support define search space in model * tracer_cfg -> parse_cfg * refine * using -> from * update docstring * update docstring Co-authored-by: liukai * tmpsave * migrate ut * tmpsave2 * add loss collector * refactor slimmable and add l1-norm (#291) * refactor slimmable and add l1-norm * make l1-norm support convnd * update get_channel_groups * add l1-norm_resnet34_8xb32_in1k.py * add pretrained to resnet34-l1 * remove old channel mutator * BaseChannelMutator -> ChannelMutator * update according to reviews * add readme to l1-norm * MBV2_slimmable -> MBV2_slimmable_config Co-authored-by: liukai * update config * fix md & pytorch support <1.9.0 in batchnorm init * Clean old codes. 
(#296) * remove old dynamic ops * move dynamic ops * clean old mutable_channels * rm OneShotMutableChannel * rm MutableChannel * refine * refine * use SquentialMutableChannel to replace OneshotMutableChannel * refactor dynamicops folder * let SquentialMutableChannel support float Co-authored-by: liukai * fix ci * ci fix py3.6.x & add mmpose * ci fix py3.6.9 in utils/index_dict.py * fix mmpose * minimum_version_cpu=3.7 * fix ci 3.7.13 * fix pruning &meta ci * support python3.6.9 * fix py3.6 import caused by circular import patch in py3.7 * fix py3.6.9 * Add channel-flow (#301) * base_channel_mutator -> channel_mutator * init * update docstring * allow omitting redundant configs for channel * add register_mutable_channel_to_a_module to MutableChannelContainer * update according to reviews 1 * update according to reviews 2 * update according to reviews 3 * remove old docstring * fix error * using->from * update according to reviews * support self-define input channel number * update docstring * chanenl -> channel_elem Co-authored-by: liukai Co-authored-by: jacky * support >=3.7 * support py3.6.9 * Rename: ChannelGroup -> ChannelUnit (#302) * refine repr of MutableChannelGroup * rename folder name * ChannelGroup -> ChannelUnit * filename in units folder * channel_group -> channel_unit * groups -> units * group -> unit * update * get_mutable_channel_groups -> get_mutable_channel_units * fix bug * refine docstring * fix ci * fix bug in tracer Co-authored-by: liukai * update new channel config format * update pruning refactor * update merged pruning * update commit * fix dynamic_conv_mixin * update comments: readme&dynamic_conv_mixins.py * update readme * move kl softmax channel pooling to op by comments * fix comments: fix redundant & split README.md * dcff in ItePruneAlgorithm * partial dynamic params for fuseconv * add step_freq & prune_time check * update comments * update comments * update comments * fix ut * fix gpu ut & revise step_freq in ItePruneAlgorithm * update readme * revise ItePruneAlgorithm * fix docs * fix dynamic_conv attr * fix ci Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: zengyi.vendor Co-authored-by: jacky --- .circleci/test.yml | 3 +- configs/pruning/mmcls/dcff/README.md | 82 +++ .../dcff/dcff_compact_resnet_8xb32_in1k.py | 5 + .../mmcls/dcff/dcff_resnet_8xb32_in1k.py | 81 +++ configs/pruning/mmcls/dcff/resnet_cls.json | 509 +++++++++++++++++ .../l1-norm/l1-norm_resnet34_8xb32_in1k.py | 3 +- configs/pruning/mmdet/dcff/README.md | 82 +++ ..._compact_faster_rcnn_resnet50_8xb4_coco.py | 5 + .../dcff_faster_rcnn_resnet50_8xb4_coco.py | 92 +++ .../dcff/dcff_faster_rcnn_resnet50_fpn.py | 114 ++++ configs/pruning/mmdet/dcff/resnet_det.json | 522 ++++++++++++++++++ configs/pruning/mmpose/dcff/README.md | 82 +++ ...f_compact_topdown_heatmap_resnet50_coco.py | 5 + .../dcff_topdown_heatmap_resnet50_coco.py | 188 +++++++ configs/pruning/mmpose/dcff/resnet_pose.json | 509 +++++++++++++++++ configs/pruning/mmseg/dcff/README.md | 82 +++ ...pact_pointrend_resnet50_8xb2_cityscapes.py | 5 + ...dcff_pointrend_resnet50_8xb2_cityscapes.py | 99 ++++ .../pruning/mmseg/dcff/pointrend_resnet50.py | 63 +++ configs/pruning/mmseg/dcff/resnet_seg.json | 496 +++++++++++++++++ mmrazor/models/algorithms/__init__.py | 3 +- mmrazor/models/algorithms/pruning/__init__.py | 3 +- mmrazor/models/algorithms/pruning/dcff.py | 172 ++++++ .../algorithms/pruning/ite_prune_algorithm.py | 173 ++++-- .../dynamic_ops/bricks/__init__.py | 4 +- 
.../dynamic_ops/bricks/dynamic_conv.py | 49 +- .../dynamic_ops/mixins/dynamic_conv_mixins.py | 173 +++++- mmrazor/models/mutables/__init__.py | 6 +- .../mutables/mutable_channel/__init__.py | 10 +- ..._channel.py => oneshot_mutable_channel.py} | 0 .../mutable_channel/units/__init__.py | 4 +- .../units/dcff_channel_unit.py | 50 ++ .../units/one_shot_mutable_channel_unit.py | 2 +- mmrazor/models/mutators/__init__.py | 6 +- .../mutators/channel_mutator/__init__.py | 4 +- .../channel_mutator/channel_mutator.py | 3 + .../channel_mutator/dcff_channel_mutator.py | 46 ++ .../tracer/loss_calculator/__init__.py | 11 +- ...cascade_encoder_decoder_loss_calculator.py | 26 + ...top_down_pose_estimator_loss_calculator.py | 25 + .../two_stage_detector_loss_calculator.py | 27 + .../test_models/test_algorithm/MBV2_220M.yaml | 474 ++++++++++++++++ .../test_models/test_mutator/subnet1.json | 15 + .../test_algorithms/test_dcff_network.py | 231 ++++++++ .../test_algorithms/test_prune_algorithm.py | 67 ++- .../test_bricks/test_dynamic_conv.py | 47 +- .../test_units/test_dcff_channel_unit.py | 77 +++ .../test_sequential_mutable_channel.py | 14 + .../test_mutators/test_channel_mutator.py | 1 + .../test_mutators/test_dcff_mutator.py | 110 ++++ 50 files changed, 4746 insertions(+), 114 deletions(-) create mode 100644 configs/pruning/mmcls/dcff/README.md create mode 100644 configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/dcff/resnet_cls.json create mode 100644 configs/pruning/mmdet/dcff/README.md create mode 100644 configs/pruning/mmdet/dcff/dcff_compact_faster_rcnn_resnet50_8xb4_coco.py create mode 100644 configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py create mode 100644 configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_fpn.py create mode 100644 configs/pruning/mmdet/dcff/resnet_det.json create mode 100644 configs/pruning/mmpose/dcff/README.md create mode 100644 configs/pruning/mmpose/dcff/dcff_compact_topdown_heatmap_resnet50_coco.py create mode 100644 configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py create mode 100644 configs/pruning/mmpose/dcff/resnet_pose.json create mode 100644 configs/pruning/mmseg/dcff/README.md create mode 100644 configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py create mode 100644 configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py create mode 100644 configs/pruning/mmseg/dcff/pointrend_resnet50.py create mode 100644 configs/pruning/mmseg/dcff/resnet_seg.json create mode 100644 mmrazor/models/algorithms/pruning/dcff.py rename mmrazor/models/mutables/mutable_channel/{oneshot_mutalbe_channel.py => oneshot_mutable_channel.py} (100%) create mode 100644 mmrazor/models/mutables/mutable_channel/units/dcff_channel_unit.py create mode 100644 mmrazor/models/mutators/channel_mutator/dcff_channel_mutator.py create mode 100644 mmrazor/models/task_modules/tracer/loss_calculator/cascade_encoder_decoder_loss_calculator.py create mode 100644 mmrazor/models/task_modules/tracer/loss_calculator/top_down_pose_estimator_loss_calculator.py create mode 100644 mmrazor/models/task_modules/tracer/loss_calculator/two_stage_detector_loss_calculator.py create mode 100644 tests/data/test_models/test_algorithm/MBV2_220M.yaml create mode 100644 tests/data/test_models/test_mutator/subnet1.json create mode 100644 tests/test_models/test_algorithms/test_dcff_network.py create mode 100644 
tests/test_models/test_mutables/test_mutable_channel/test_units/test_dcff_channel_unit.py create mode 100644 tests/test_models/test_mutables/test_sequential_mutable_channel.py create mode 100644 tests/test_models/test_mutators/test_dcff_mutator.py diff --git a/.circleci/test.yml b/.circleci/test.yml index 5da20de36..f8dc9d212 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -70,6 +70,7 @@ jobs: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + python -m pip install git+ssh://git@github.com/open-mmlab/mmpose.git@dev-1.x pip install -r requirements.txt - run: name: Build and install @@ -163,7 +164,7 @@ workflows: torchvision: 0.13.1 python: 3.9.0 requires: - - minimum_version_cpu + - lint - hold: type: approval requires: diff --git a/configs/pruning/mmcls/dcff/README.md b/configs/pruning/mmcls/dcff/README.md new file mode 100644 index 000000000..9f7157c8a --- /dev/null +++ b/configs/pruning/mmcls/dcff/README.md @@ -0,0 +1,82 @@ +# Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion + +## Abstract + +The mainstream approach for filter pruning is usually either to force a hard-coded importance estimation upon a computation-heavy pretrained model to select “important” filters, or to impose a hyperparameter-sensitive sparse constraint on the loss objective to regularize the network training. In this paper, we present a novel filter pruning method, dubbed dynamic-coded filter fusion (DCFF), to derive compact CNNs in a computationeconomical and regularization-free manner for efficient image classification. Each filter in our DCFF is firstly given an intersimilarity distribution with a temperature parameter as a filter proxy, on top of which, a fresh Kullback-Leibler divergence based dynamic-coded criterion is proposed to evaluate the filter importance. In contrast to simply keeping high-score filters in other methods, we propose the concept of filter fusion, i.e., the weighted averages using the assigned proxies, as our preserved filters. We obtain a one-hot inter-similarity distribution as the temperature parameter approaches infinity. Thus, the relative importance of each filter can vary along with the training of the compact CNN, leading to dynamically changeable fused filters without both the dependency on the pretrained model and the introduction of sparse constraints. Extensive experiments on classification benchmarks demonstrate the superiority of our DCFF over the compared counterparts. For example, our DCFF derives a compact VGGNet-16 with only 72.77M FLOPs and 1.06M parameters while reaching top-1 accuracy of 93.47% on CIFAR-10. A compact ResNet-50 is obtained with 63.8% FLOPs and 58.6% parameter reductions, retaining 75.60% top1 accuracy on ILSVRC-2012. + +![pipeline](https://user-images.githubusercontent.com/31244134/189286581-722853ba-c6d7-4a39-b902-37995b444c71.jpg) + +## Results and models + +### 1. 
Classification + +| Dataset | Backbone | Params(M) | FLOPs(M) | lr_type | Top-1 (%) | Top-5 (%) | CPrate | Config | Download | +| :------: | :----------: | :-------: | :------: | :-----: | :-------: | :-------: | :---------------------------------------------: | :--------------------------------------------------: | :--------------------------: | +| ImageNet | DCFFResNet50 | 15.16 | 2260 | step | 73.96 | 91.66 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmcls/dcff/dcff_resnet_8xb32_in1k.py) | [model](<>) \| \[log\] (\<>) | + +### 2. Detection + +| Dataset | Method | Backbone | Style | Lr schd | Params(M) | FLOPs(M) | bbox AP | CPrate | Config | Download | +| :-----: | :---------: | :----------: | :-----: | :-----: | :-------: | :------: | :-----: | :---------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | Faster_RCNN | DCFFResNet50 | pytorch | step | 33.31 | 168320 | 35.8 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py) | [model](<>) \| \[log\] (\<>) | + +### 3. Segmentation + +| Dataset | Method | Backbone | crop size | Lr schd | Params(M) | FLOPs(M) | mIoU | CPrate | Config | Download | +| :--------: | :-------: | :-------------: | :-------: | :-----: | :-------: | :------: | :---: | :-----------------------------------------------------------------: | :-------------------------------------------------------------------: | :--------------------------: | +| Cityscapes | PointRend | DCFFResNetV1c50 | 512x1024 | 160k | 18.43 | 74410 | 76.75 | \[0.0, 0.0, 0.0\] + \[0.35, 0.4, 0.1\] * 10 + \[0.3, 0.3, 0.1\] * 6 | [config](../../mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py) | [model](<>) \| \[log\] (\<>) | + +### 4. Pose + +| Dataset | Method | Backbone | crop size | total epochs | Params(M) | FLOPs(M) | AP | CPrate | Config | Download | +| :-----: | :-------------: | :----------: | :-------: | :----------: | :-------: | :------: | :--: | :--------------------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | TopDown HeatMap | DCFFResNet50 | 256x192 | 300 | 26.95 | 4290 | 68.3 | \[0.0\] + \[0.2, 0.2, 0.1\] * 10 + \[0.15, 0.15, 0.1\] * 6 | [config](../../mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py) | [model](<>) \| \[log\] (\<>) | + +## Citation + +```latex +@article{lin2021training, + title={Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion}, + author={Lin, Mingbao and Ji, Rongrong and Chen, Bohong and Chao, Fei and Liu, Jianzhuang and Zeng, Wei and Tian, Yonghong and Tian, Qi}, + journal={arXiv preprint arXiv:2107.06916}, + year={2021} +} +``` + +## Getting Started + +### Generate channel_config file + +Generate `resnet_cls.json` with `tools/get_channel_units.py`. + +```bash +python tools/get_channel_units.py + configs/pruning/mmcls/dcff/dcff_resnet50_8xb32_in1k.py \ + -c -i --output-path=configs/pruning/mmcls/dcff/resnet_cls.json +``` + +Then set layers' pruning rates `target_pruning_ratio` by `resnet_cls.json`. 
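If you prefer ratio-style targets (as used in `dcff_resnet_8xb32_in1k.py`) over the absolute channel numbers stored in the generated JSON, the minimal sketch below shows one possible conversion. It is only an illustration assuming the `resnet_cls.json` layout shown in this PR (`init_args.num_channels`, `choice` per unit); the helper name is hypothetical and not part of the repository.

```python
# Hypothetical helper (not part of this PR): derive ratio-style pruning targets
# from a units JSON produced by tools/get_channel_units.py. Field names follow
# the resnet_cls.json layout added in this PR.
import json


def ratios_from_units_json(json_path: str) -> dict:
    """Map every unit name to choice / num_channels, e.g. 41 / 64 -> 0.640625."""
    with open(json_path) as f:
        units = json.load(f)
    return {
        name: cfg['choice'] / cfg['init_args']['num_channels']
        for name, cfg in units.items()
    }


if __name__ == '__main__':
    target_pruning_ratio = ratios_from_units_json(
        'configs/pruning/mmcls/dcff/resnet_cls.json')
    for name, ratio in list(target_pruning_ratio.items())[:3]:
        print(f'{name}: {ratio:.3f}')
```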
+ +### Train DCFF + +#### Classification + +##### ImageNet + +```bash +sh tools/slurm_train.sh $PARTITION $JOB_NAME \ + configs/pruning/mmcls/dcff/dcff_resnet50_8xb32_in1k.py \ + $WORK_DIR +``` + +### Test DCFF + +#### Classification + +##### ImageNet + +```bash +sh tools/slurm_test.sh $PARTITION $JOB_NAME \ + configs/pruning/mmcls/dcff/dcff_compact_resnet50_8xb32_in1k.py \ + $WORK_DIR +``` diff --git a/configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py b/configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py new file mode 100644 index 000000000..66a2587cd --- /dev/null +++ b/configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py @@ -0,0 +1,5 @@ +_base_ = ['dcff_resnet_8xb32_in1k.py'] + +# model settings +model = _base_.model +model['is_deployed'] = True diff --git a/configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py b/configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py new file mode 100644 index 000000000..34a9a15c7 --- /dev/null +++ b/configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py @@ -0,0 +1,81 @@ +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs32.py', + 'mmcls::_base_/schedules/imagenet_bs256.py', + 'mmcls::_base_/default_runtime.py' +] + +stage_ratio_1 = 0.65 +stage_ratio_2 = 0.6 +stage_ratio_3 = 0.9 +stage_ratio_4 = 0.7 + +# the config template of target_pruning_ratio can be got by +# python ./tools/get_channel_units.py {config_file} --choice +target_pruning_ratio = { + 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.0.conv3_(0, 256)_256': stage_ratio_3, + 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.1.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_2, + # block 1 [0.65, 0.6] downsample=[0.9] + 'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.0.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.0.conv3_(0, 512)_512': stage_ratio_3, + 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.1.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.2.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2, + # block 2 [0.65, 0.6] downsample=[0.9] + 'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.0.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.0.conv3_(0, 1024)_1024': stage_ratio_3, + 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.1.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.2.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.3.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_4, + # block 3 [0.65, 0.6]*2+[0.7, 0.7]*2 downsample=[0.9] + 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv3_(0, 2048)_2048': stage_ratio_3, + 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.1.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4, + 
'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4 + # block 4 [0.7, 0.7] downsample=[0.9] +} + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)) +param_scheduler = dict( + type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1) +train_cfg = dict(by_epoch=True, max_epochs=120, val_interval=1) + +data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'} + +# model settings +model = dict( + _scope_='mmrazor', + type='DCFF', + architecture=dict( + cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False), + mutator_cfg=dict( + type='DCFFChannelMutator', + channel_unit_cfg=dict( + type='DCFFChannelUnit', default_args=dict(choice_mode='ratio')), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss'))), + target_pruning_ratio=target_pruning_ratio, + step_freq=1, + linear_schedule=False, + is_deployed=False) diff --git a/configs/pruning/mmcls/dcff/resnet_cls.json b/configs/pruning/mmcls/dcff/resnet_cls.json new file mode 100644 index 000000000..3fafa125d --- /dev/null +++ b/configs/pruning/mmcls/dcff/resnet_cls.json @@ -0,0 +1,509 @@ +{ + "backbone.conv1_(0, 3)_3":{ + "init_args":{ + "num_channels":3, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 3 + ], + "choice_mode":"number" + }, + "choice":3 + }, + "backbone.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 64 + ], + "choice_mode":"number" + }, + "choice":64 + }, + "backbone.layer1.0.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.0.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer1.0.conv3_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 230 + ], + "choice_mode":"number" + }, + "choice":230 + }, + "backbone.layer1.1.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.1.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer1.2.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.2.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer2.0.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.0.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.0.conv3_(0, 512)_512":{ + "init_args":{ + 
"num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 460 + ], + "choice_mode":"number" + }, + "choice":460 + }, + "backbone.layer2.1.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.1.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.2.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.2.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.3.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.3.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer3.0.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.0.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.0.conv3_(0, 1024)_1024":{ + "init_args":{ + "num_channels":1024, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 921 + ], + "choice_mode":"number" + }, + "choice":921 + }, + "backbone.layer3.1.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.1.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.2.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.2.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.3.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.3.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.4.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + 
"candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.4.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.5.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.5.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer4.0.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.0.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.0.conv3_(0, 2048)_2048":{ + "init_args":{ + "num_channels":2048, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 1843 + ], + "choice_mode":"number" + }, + "choice":1843 + }, + "backbone.layer4.1.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.1.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.2.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.2.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "head.fc_(0, 1000)_1000":{ + "init_args":{ + "num_channels":1000, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 1000 + ], + "choice_mode":"number" + }, + "choice":1000 + } +} diff --git a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py index 89ef4138f..1da311cf4 100644 --- a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py +++ b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py @@ -51,6 +51,5 @@ type='L1MutableChannelUnit', default_args=dict(choice_mode='ratio'))), target_pruning_ratio=target_pruning_ratio, - step_epoch=1, - prune_times=1, + step_freq=1, ) diff --git a/configs/pruning/mmdet/dcff/README.md b/configs/pruning/mmdet/dcff/README.md new file mode 100644 index 000000000..8742f2997 --- /dev/null +++ b/configs/pruning/mmdet/dcff/README.md @@ -0,0 +1,82 @@ +# Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion + +## Abstract + +The mainstream approach for filter pruning is usually either to force a hard-coded importance estimation upon a computation-heavy pretrained model to select “important” filters, or to impose a hyperparameter-sensitive sparse constraint on the loss objective to regularize the network 
training. In this paper, we present a novel filter pruning method, dubbed dynamic-coded filter fusion (DCFF), to derive compact CNNs in a computationeconomical and regularization-free manner for efficient image classification. Each filter in our DCFF is firstly given an intersimilarity distribution with a temperature parameter as a filter proxy, on top of which, a fresh Kullback-Leibler divergence based dynamic-coded criterion is proposed to evaluate the filter importance. In contrast to simply keeping high-score filters in other methods, we propose the concept of filter fusion, i.e., the weighted averages using the assigned proxies, as our preserved filters. We obtain a one-hot inter-similarity distribution as the temperature parameter approaches infinity. Thus, the relative importance of each filter can vary along with the training of the compact CNN, leading to dynamically changeable fused filters without both the dependency on the pretrained model and the introduction of sparse constraints. Extensive experiments on classification benchmarks demonstrate the superiority of our DCFF over the compared counterparts. For example, our DCFF derives a compact VGGNet-16 with only 72.77M FLOPs and 1.06M parameters while reaching top-1 accuracy of 93.47% on CIFAR-10. A compact ResNet-50 is obtained with 63.8% FLOPs and 58.6% parameter reductions, retaining 75.60% top1 accuracy on ILSVRC-2012. + +![pipeline](https://user-images.githubusercontent.com/31244134/189286581-722853ba-c6d7-4a39-b902-37995b444c71.jpg) + +## Results and models + +### 1. Classification + +| Dataset | Backbone | Params(M) | FLOPs(M) | lr_type | Top-1 (%) | Top-5 (%) | CPrate | Config | Download | +| :------: | :----------: | :-------: | :------: | :-----: | :-------: | :-------: | :---------------------------------------------: | :--------------------------------------------------: | :--------------------------: | +| ImageNet | DCFFResNet50 | 15.16 | 2260 | step | 73.96 | 91.66 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmcls/dcff/dcff_resnet_8xb32_in1k.py) | [model](<>) \| \[log\] (\<>) | + +### 2. Detection + +| Dataset | Method | Backbone | Style | Lr schd | Params(M) | FLOPs(M) | bbox AP | CPrate | Config | Download | +| :-----: | :---------: | :----------: | :-----: | :-----: | :-------: | :------: | :-----: | :---------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | Faster_RCNN | DCFFResNet50 | pytorch | step | 33.31 | 168320 | 35.8 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py) | [model](<>) \| \[log\] (\<>) | + +### 3. Segmentation + +| Dataset | Method | Backbone | crop size | Lr schd | Params(M) | FLOPs(M) | mIoU | CPrate | Config | Download | +| :--------: | :-------: | :-------------: | :-------: | :-----: | :-------: | :------: | :---: | :-----------------------------------------------------------------: | :-------------------------------------------------------------------: | :--------------------------: | +| Cityscapes | PointRend | DCFFResNetV1c50 | 512x1024 | 160k | 18.43 | 74410 | 76.75 | \[0.0, 0.0, 0.0\] + \[0.35, 0.4, 0.1\] * 10 + \[0.3, 0.3, 0.1\] * 6 | [config](../../mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py) | [model](<>) \| \[log\] (\<>) | + +### 4. 
Pose + +| Dataset | Method | Backbone | crop size | total epochs | Params(M) | FLOPs(M) | AP | CPrate | Config | Download | +| :-----: | :-------------: | :----------: | :-------: | :----------: | :-------: | :------: | :--: | :--------------------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | TopDown HeatMap | DCFFResNet50 | 256x192 | 300 | 26.95 | 4290 | 68.3 | \[0.0\] + \[0.2, 0.2, 0.1\] * 10 + \[0.15, 0.15, 0.1\] * 6 | [config](../../mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py) | [model](<>) \| \[log\] (\<>) | + +## Citation + +```latex +@article{lin2021training, + title={Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion}, + author={Lin, Mingbao and Ji, Rongrong and Chen, Bohong and Chao, Fei and Liu, Jianzhuang and Zeng, Wei and Tian, Yonghong and Tian, Qi}, + journal={arXiv preprint arXiv:2107.06916}, + year={2021} +} +``` + +## Getting Started + +### Generate channel_config file + +Generate `resnet_det.json` with `tools/get_channel_units.py`. + +```bash +python tools/get_channel_units.py + configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py \ + -c -i --output-path=configs/pruning/mmcls/dcff/resnet_det.json +``` + +Then set layers' pruning rates `target_pruning_ratio` by `resnet_det.json`. + +### Train DCFF + +#### Detection + +##### COCO + +```bash +sh tools/slurm_train.sh $PARTITION $JOB_NAME \ + configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py \ + $WORK_DIR +``` + +### Test DCFF + +#### Detection + +##### COCO + +```bash +sh tools/slurm_test.sh $PARTITION $JOB_NAME \ + configs/pruning/mmdet/dcff/dcff_compact_faster_rcnn_resnet50_8xb4_coco.py \ + $WORK_DIR +``` diff --git a/configs/pruning/mmdet/dcff/dcff_compact_faster_rcnn_resnet50_8xb4_coco.py b/configs/pruning/mmdet/dcff/dcff_compact_faster_rcnn_resnet50_8xb4_coco.py new file mode 100644 index 000000000..7efb17b7e --- /dev/null +++ b/configs/pruning/mmdet/dcff/dcff_compact_faster_rcnn_resnet50_8xb4_coco.py @@ -0,0 +1,5 @@ +_base_ = ['dcff_faster_rcnn_resnet50_8xb4_coco.py'] + +# model settings +model = _base_.model +model['is_deployed'] = True diff --git a/configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py b/configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py new file mode 100644 index 000000000..d19828b73 --- /dev/null +++ b/configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py @@ -0,0 +1,92 @@ +_base_ = [ + './dcff_faster_rcnn_resnet50_fpn.py', + 'mmdet::_base_/datasets/coco_detection.py', + 'mmdet::_base_/schedules/schedule_2x.py', + 'mmdet::_base_/default_runtime.py' +] + +stage_ratio_1 = 0.65 +stage_ratio_2 = 0.6 +stage_ratio_3 = 0.9 +stage_ratio_4 = 0.7 + +# the config template of target_pruning_ratio can be got by +# python ./tools/get_channel_units.py {config_file} --choice +target_pruning_ratio = { + 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.0.conv3_(0, 256)_256': stage_ratio_3, + 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.1.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_2, + # block 1 [0.65, 0.6] downsample=[0.9] + 'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.0.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.0.conv3_(0, 512)_512': stage_ratio_3, + 
'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.1.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.2.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2, + # block 2 [0.65, 0.6] downsample=[0.9] + 'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.0.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.0.conv3_(0, 1024)_1024': stage_ratio_3, + 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.1.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.2.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.3.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_4, + # block 3 [0.65, 0.6]*2+[0.7, 0.7]*2 downsample=[0.9] + 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv3_(0, 2048)_2048': stage_ratio_3, + 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.1.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4 + # block 4 [0.7, 0.7] downsample=[0.9] +} + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.04, momentum=0.9, weight_decay=0.0001)) +param_scheduler = dict( + type='MultiStepLR', + by_epoch=True, + milestones=[60, 80, 95], + gamma=0.1, + _delete_=True) +train_cfg = dict(max_epochs=120, val_interval=1) + +# !dataset config +# ========================================================================== +# data preprocessor + +model = dict( + _scope_='mmrazor', + type='DCFF', + architecture=_base_.architecture, + mutator_cfg=dict( + type='DCFFChannelMutator', + channel_unit_cfg=dict( + type='DCFFChannelUnit', + units='configs/pruning/mmdet/dcff/resnet_det.json'), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='TwoStageDetectorPseudoLoss'))), + target_pruning_ratio=target_pruning_ratio, + step_freq=1, + linear_schedule=False, + is_deployed=False) + +model_wrapper = dict( + type='mmcv.MMDistributedDataParallel', find_unused_parameters=True) + +val_cfg = dict(_delete_=True) diff --git a/configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_fpn.py b/configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_fpn.py new file mode 100644 index 000000000..0ce540338 --- /dev/null +++ b/configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_fpn.py @@ -0,0 +1,114 @@ +# architecture settings +architecture = dict( + _scope_='mmdet', + type='FasterRCNN', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + 
scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + )) diff --git a/configs/pruning/mmdet/dcff/resnet_det.json b/configs/pruning/mmdet/dcff/resnet_det.json new file mode 100644 index 000000000..7e3de46b3 --- /dev/null +++ b/configs/pruning/mmdet/dcff/resnet_det.json @@ -0,0 +1,522 @@ +{ + "backbone.conv1_(0, 3)_3":{ + "init_args":{ + "num_channels":3, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 3 + ], + "choice_mode":"number" + }, + "choice":3 + }, + "backbone.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 64 + ], + "choice_mode":"number" + }, + "choice":64 + }, + "backbone.layer1.0.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.0.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer1.0.conv3_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 230 + ], + "choice_mode":"number" + }, + "choice":230 + }, + "backbone.layer1.1.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + 
"choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.1.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer1.2.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.2.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer2.0.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.0.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.0.conv3_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 460 + ], + "choice_mode":"number" + }, + "choice":460 + }, + "backbone.layer2.1.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.1.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.2.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.2.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.3.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.3.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer3.0.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.0.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.0.conv3_(0, 1024)_1024":{ + "init_args":{ + "num_channels":1024, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 921 + ], + "choice_mode":"number" + }, + "choice":921 + }, + "backbone.layer3.1.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.1.conv2_(0, 256)_256":{ + "init_args":{ + 
"num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.2.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.2.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.3.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.3.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.4.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.4.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.5.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.5.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer4.0.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.0.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.0.conv3_(0, 2048)_2048":{ + "init_args":{ + "num_channels":2048, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 1843 + ], + "choice_mode":"number" + }, + "choice":1843 + }, + "backbone.layer4.1.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.1.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.2.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.2.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.2.conv3_(0, 2048)_2048":{ + "init_args":{ + "num_channels":2048, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + 
"candidate_choices":[ + 1843 + ], + "choice_mode":"number" + }, + "choice":1843 + }, + "head.fc_(0, 1000)_1000":{ + "init_args":{ + "num_channels":1000, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 1000 + ], + "choice_mode":"number" + }, + "choice":1000 + } +} diff --git a/configs/pruning/mmpose/dcff/README.md b/configs/pruning/mmpose/dcff/README.md new file mode 100644 index 000000000..95d7d7db6 --- /dev/null +++ b/configs/pruning/mmpose/dcff/README.md @@ -0,0 +1,82 @@ +# Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion + +## Abstract + +The mainstream approach for filter pruning is usually either to force a hard-coded importance estimation upon a computation-heavy pretrained model to select “important” filters, or to impose a hyperparameter-sensitive sparse constraint on the loss objective to regularize the network training. In this paper, we present a novel filter pruning method, dubbed dynamic-coded filter fusion (DCFF), to derive compact CNNs in a computationeconomical and regularization-free manner for efficient image classification. Each filter in our DCFF is firstly given an intersimilarity distribution with a temperature parameter as a filter proxy, on top of which, a fresh Kullback-Leibler divergence based dynamic-coded criterion is proposed to evaluate the filter importance. In contrast to simply keeping high-score filters in other methods, we propose the concept of filter fusion, i.e., the weighted averages using the assigned proxies, as our preserved filters. We obtain a one-hot inter-similarity distribution as the temperature parameter approaches infinity. Thus, the relative importance of each filter can vary along with the training of the compact CNN, leading to dynamically changeable fused filters without both the dependency on the pretrained model and the introduction of sparse constraints. Extensive experiments on classification benchmarks demonstrate the superiority of our DCFF over the compared counterparts. For example, our DCFF derives a compact VGGNet-16 with only 72.77M FLOPs and 1.06M parameters while reaching top-1 accuracy of 93.47% on CIFAR-10. A compact ResNet-50 is obtained with 63.8% FLOPs and 58.6% parameter reductions, retaining 75.60% top1 accuracy on ILSVRC-2012. + +![pipeline](https://user-images.githubusercontent.com/31244134/189286581-722853ba-c6d7-4a39-b902-37995b444c71.jpg) + +## Results and models + +### 1. Classification + +| Dataset | Backbone | Params(M) | FLOPs(M) | lr_type | Top-1 (%) | Top-5 (%) | CPrate | Config | Download | +| :------: | :----------: | :-------: | :------: | :-----: | :-------: | :-------: | :---------------------------------------------: | :--------------------------------------------------: | :--------------------------: | +| ImageNet | DCFFResNet50 | 15.16 | 2260 | step | 73.96 | 91.66 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmcls/dcff/dcff_resnet_8xb32_in1k.py) | [model](<>) \| \[log\] (\<>) | + +### 2. 
Detection + +| Dataset | Method | Backbone | Style | Lr schd | Params(M) | FLOPs(M) | bbox AP | CPrate | Config | Download | +| :-----: | :---------: | :----------: | :-----: | :-----: | :-------: | :------: | :-----: | :---------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | Faster_RCNN | DCFFResNet50 | pytorch | step | 33.31 | 168320 | 35.8 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py) | [model](<>) \| \[log\] (\<>) | + +### 3. Segmentation + +| Dataset | Method | Backbone | crop size | Lr schd | Params(M) | FLOPs(M) | mIoU | CPrate | Config | Download | +| :--------: | :-------: | :-------------: | :-------: | :-----: | :-------: | :------: | :---: | :-----------------------------------------------------------------: | :-------------------------------------------------------------------: | :--------------------------: | +| Cityscapes | PointRend | DCFFResNetV1c50 | 512x1024 | 160k | 18.43 | 74410 | 76.75 | \[0.0, 0.0, 0.0\] + \[0.35, 0.4, 0.1\] * 10 + \[0.3, 0.3, 0.1\] * 6 | [config](../../mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py) | [model](<>) \| \[log\] (\<>) | + +### 4. Pose + +| Dataset | Method | Backbone | crop size | total epochs | Params(M) | FLOPs(M) | AP | CPrate | Config | Download | +| :-----: | :-------------: | :----------: | :-------: | :----------: | :-------: | :------: | :--: | :--------------------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | TopDown HeatMap | DCFFResNet50 | 256x192 | 300 | 26.95 | 4290 | 68.3 | \[0.0\] + \[0.2, 0.2, 0.1\] * 10 + \[0.15, 0.15, 0.1\] * 6 | [config](../../mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py) | [model](<>) \| \[log\] (\<>) | + +## Citation + +```latex +@article{lin2021training, + title={Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion}, + author={Lin, Mingbao and Ji, Rongrong and Chen, Bohong and Chao, Fei and Liu, Jianzhuang and Zeng, Wei and Tian, Yonghong and Tian, Qi}, + journal={arXiv preprint arXiv:2107.06916}, + year={2021} +} +``` + +## Getting Started + +### Generate channel_config file + +Generate `resnet_pose.json` with `tools/get_channel_units.py`. + +```bash +python tools/get_channel_units.py + configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50.py \ + -c -i --output-path=configs/pruning/mmpose/dcff/resnet_pose.json +``` + +Then set layers' pruning rates `target_pruning_ratio` by `resnet_pose.json`. 
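For reference, the `CPrate` column in the result tables above is plain Python list arithmetic; the snippet below only unpacks that notation (assuming it spans the stem conv plus the 16 ResNet-50 bottleneck blocks) and is not code used by the configs. Note that `CPrate` lists the fraction of channels pruned, whereas `target_pruning_ratio` in the configs stores the fraction kept, so a rate of 0.1 corresponds to a ratio of 0.9.

```python
# Reading aid only: expand the CPrate notation from the tables above. The
# layout (1 stem conv + 3 convs per block for 10 + 6 bottleneck blocks) is an
# interpretation, not something asserted by the configs themselves.
cprate = [0.0] + [0.2, 0.2, 0.1] * 10 + [0.15, 0.15, 0.1] * 6
print(len(cprate))  # 49 = 1 stem conv + 16 blocks * 3 convs
print(cprate[:7])   # [0.0, 0.2, 0.2, 0.1, 0.2, 0.2, 0.1]
```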
+ +### Train DCFF + +#### Pose + +##### COCO + +```bash +sh tools/slurm_train.sh $PARTITION $JOB_NAME \ + configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50.py \ + $WORK_DIR +``` + +### Test DCFF + +#### Pose + +##### COCO + +```bash +sh tools/slurm_test.sh $PARTITION $JOB_NAME \ + configs/pruning/mmpose/dcff/dcff_compact_topdown_heatmap_resnet50.py \ + $WORK_DIR +``` diff --git a/configs/pruning/mmpose/dcff/dcff_compact_topdown_heatmap_resnet50_coco.py b/configs/pruning/mmpose/dcff/dcff_compact_topdown_heatmap_resnet50_coco.py new file mode 100644 index 000000000..8ec4867b2 --- /dev/null +++ b/configs/pruning/mmpose/dcff/dcff_compact_topdown_heatmap_resnet50_coco.py @@ -0,0 +1,5 @@ +_base_ = ['dcff_topdown_heatmap_resnet50_coco.py'] + +# model settings +model = _base_.model +model['is_deployed'] = True diff --git a/configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py b/configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py new file mode 100644 index 000000000..3981a54b4 --- /dev/null +++ b/configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py @@ -0,0 +1,188 @@ +_base_ = [ + 'mmpose::_base_/default_runtime.py', +] +train_cfg = dict(max_epochs=300, val_interval=10) + +optim_wrapper = dict(optimizer=dict(type='Adam', lr=5e-4), clip_grad=None) + +# learning policy +param_scheduler = [ + dict( + type='LinearLR', begin=0, end=500, start_factor=0.001, + by_epoch=False), # warm-up + dict( + type='MultiStepLR', + begin=0, + end=300, + milestones=[170, 220, 280], + gamma=0.1, + by_epoch=True) +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# hooks +default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) + +# codec settings +codec = dict( + type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2) + +# model settings +architecture = dict( + type='mmpose.TopdownPoseEstimator', + data_preprocessor=dict( + type='mmpose.PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='mmpose.ResNet', + depth=50, + num_stages=4, + out_indices=(3, ), + ), + head=dict( + type='mmpose.HeatmapHead', + in_channels=1843, + out_channels=17, + loss=dict(type='mmpose.KeypointMSELoss', use_target_weight=True), + decoder=codec), + test_cfg=dict( + flip_test=True, + flip_mode='heatmap', + shift_heatmap=True, + )) + +stage_ratio_1 = 0.8 +stage_ratio_2 = 0.8 +stage_ratio_3 = 0.9 +stage_ratio_4 = 0.85 + +# the config template of target_pruning_ratio can be got by +# python ./tools/get_channel_units.py {config_file} --choice +target_pruning_ratio = { + 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.0.conv3_(0, 256)_256': stage_ratio_3, + 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.1.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_2, + # block 1 [0.8, 0.8] downsample=[0.9] + 'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.0.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.0.conv3_(0, 512)_512': stage_ratio_3, + 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.1.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.2.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv1_(0, 128)_128': 
stage_ratio_1, + 'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2, + # block 2 [0.8, 0.8] downsample=[0.9] + 'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.0.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.0.conv3_(0, 1024)_1024': stage_ratio_3, + 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.1.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.2.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.3.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_4, + # block 3 [0.8, 0.8]*2+[0.8, 0.85]*2 downsample=[0.9] + 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv3_(0, 2048)_2048': stage_ratio_3, + 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.1.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4 + # block 4 [0.85, 0.85] downsample=[0.9] +} + +model = dict( + _scope_='mmrazor', + type='DCFF', + architecture=architecture, + mutator_cfg=dict( + type='DCFFChannelMutator', + channel_unit_cfg=dict( + type='DCFFChannelUnit', + units='configs/pruning/mmpose/dcff/resnet_pose.json'), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='TopdownPoseEstimatorPseudoLoss'))), + target_pruning_ratio=target_pruning_ratio, + step_freq=1, + linear_schedule=False, + is_deployed=False) + +dataset_type = 'CocoDataset' +data_mode = 'topdown' +data_root = 'data/coco/' + +file_client_args = dict(backend='disk') + +train_pipeline = [ + dict(type='LoadImage', file_client_args=file_client_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict(type='RandomBBoxTransform'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='GenerateTarget', target_type='heatmap', encoder=codec), + dict(type='PackPoseInputs') +] + +test_pipeline = [ + dict(type='LoadImage', file_client_args=file_client_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_dataloader = dict( + batch_size=32, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=train_pipeline, + )) +val_dataloader = dict( + batch_size=32, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + pipeline=test_pipeline, + )) +test_dataloader = val_dataloader + +model_wrapper = dict( +
type='mmcv.MMDistributedDataParallel', find_unused_parameters=True) + +val_evaluator = dict( + type='mmpose.CocoMetric', + ann_file=data_root + 'annotations/person_keypoints_val2017.json') +test_evaluator = val_evaluator diff --git a/configs/pruning/mmpose/dcff/resnet_pose.json b/configs/pruning/mmpose/dcff/resnet_pose.json new file mode 100644 index 000000000..a08b40503 --- /dev/null +++ b/configs/pruning/mmpose/dcff/resnet_pose.json @@ -0,0 +1,509 @@ +{ + "backbone.conv1_(0, 3)_3":{ + "init_args":{ + "num_channels":3, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 3 + ], + "choice_mode":"number" + }, + "choice":3 + }, + "backbone.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 64 + ], + "choice_mode":"number" + }, + "choice":64 + }, + "backbone.layer1.0.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 51 + ], + "choice_mode":"number" + }, + "choice":51 + }, + "backbone.layer1.0.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 51 + ], + "choice_mode":"number" + }, + "choice":51 + }, + "backbone.layer1.0.conv3_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 230 + ], + "choice_mode":"number" + }, + "choice":230 + }, + "backbone.layer1.1.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 51 + ], + "choice_mode":"number" + }, + "choice":51 + }, + "backbone.layer1.1.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 51 + ], + "choice_mode":"number" + }, + "choice":51 + }, + "backbone.layer1.2.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 51 + ], + "choice_mode":"number" + }, + "choice":51 + }, + "backbone.layer1.2.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 51 + ], + "choice_mode":"number" + }, + "choice":51 + }, + "backbone.layer2.0.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer2.0.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer2.0.conv3_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 460 + ], + "choice_mode":"number" + }, + "choice":460 + }, + "backbone.layer2.1.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer2.1.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer2.2.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + 
"min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer2.2.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer2.3.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer2.3.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 102 + ], + "choice_mode":"number" + }, + "choice":102 + }, + "backbone.layer3.0.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 204 + ], + "choice_mode":"number" + }, + "choice":204 + }, + "backbone.layer3.0.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 204 + ], + "choice_mode":"number" + }, + "choice":204 + }, + "backbone.layer3.0.conv3_(0, 1024)_1024":{ + "init_args":{ + "num_channels":1024, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 921 + ], + "choice_mode":"number" + }, + "choice":921 + }, + "backbone.layer3.1.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 204 + ], + "choice_mode":"number" + }, + "choice":204 + }, + "backbone.layer3.1.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 204 + ], + "choice_mode":"number" + }, + "choice":204 + }, + "backbone.layer3.2.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 204 + ], + "choice_mode":"number" + }, + "choice":204 + }, + "backbone.layer3.2.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 204 + ], + "choice_mode":"number" + }, + "choice":204 + }, + "backbone.layer3.3.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 217 + ], + "choice_mode":"number" + }, + "choice":217 + }, + "backbone.layer3.3.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 217 + ], + "choice_mode":"number" + }, + "choice":217 + }, + "backbone.layer3.4.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 217 + ], + "choice_mode":"number" + }, + "choice":217 + }, + "backbone.layer3.4.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 217 + ], + "choice_mode":"number" + }, + "choice":217 + }, + "backbone.layer3.5.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 217 + ], + "choice_mode":"number" + }, + "choice":217 + }, + "backbone.layer3.5.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 217 + ], + "choice_mode":"number" + 
}, + "choice":217 + }, + "backbone.layer4.0.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 435 + ], + "choice_mode":"number" + }, + "choice":435 + }, + "backbone.layer4.0.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 435 + ], + "choice_mode":"number" + }, + "choice":435 + }, + "backbone.layer4.0.conv3_(0, 2048)_2048":{ + "init_args":{ + "num_channels":2048, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 1843 + ], + "choice_mode":"number" + }, + "choice":1843 + }, + "backbone.layer4.1.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 435 + ], + "choice_mode":"number" + }, + "choice":435 + }, + "backbone.layer4.1.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 435 + ], + "choice_mode":"number" + }, + "choice":435 + }, + "backbone.layer4.2.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 435 + ], + "choice_mode":"number" + }, + "choice":435 + }, + "backbone.layer4.2.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 435 + ], + "choice_mode":"number" + }, + "choice":435 + }, + "backbone.layer4.2.conv3_(0, 2048)_2048":{ + "init_args":{ + "num_channels":2048, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 1843 + ], + "choice_mode":"number" + }, + "choice":1843 + } +} diff --git a/configs/pruning/mmseg/dcff/README.md b/configs/pruning/mmseg/dcff/README.md new file mode 100644 index 000000000..e6a2fb3e6 --- /dev/null +++ b/configs/pruning/mmseg/dcff/README.md @@ -0,0 +1,82 @@ +# Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion + +## Abstract + +The mainstream approach for filter pruning is usually either to force a hard-coded importance estimation upon a computation-heavy pretrained model to select “important” filters, or to impose a hyperparameter-sensitive sparse constraint on the loss objective to regularize the network training. In this paper, we present a novel filter pruning method, dubbed dynamic-coded filter fusion (DCFF), to derive compact CNNs in a computationeconomical and regularization-free manner for efficient image classification. Each filter in our DCFF is firstly given an intersimilarity distribution with a temperature parameter as a filter proxy, on top of which, a fresh Kullback-Leibler divergence based dynamic-coded criterion is proposed to evaluate the filter importance. In contrast to simply keeping high-score filters in other methods, we propose the concept of filter fusion, i.e., the weighted averages using the assigned proxies, as our preserved filters. We obtain a one-hot inter-similarity distribution as the temperature parameter approaches infinity. Thus, the relative importance of each filter can vary along with the training of the compact CNN, leading to dynamically changeable fused filters without both the dependency on the pretrained model and the introduction of sparse constraints. Extensive experiments on classification benchmarks demonstrate the superiority of our DCFF over the compared counterparts. 
For example, our DCFF derives a compact VGGNet-16 with only 72.77M FLOPs and 1.06M parameters while reaching top-1 accuracy of 93.47% on CIFAR-10. A compact ResNet-50 is obtained with 63.8% FLOPs and 58.6% parameter reductions, retaining 75.60% top1 accuracy on ILSVRC-2012. + +![pipeline](https://user-images.githubusercontent.com/31244134/189286581-722853ba-c6d7-4a39-b902-37995b444c71.jpg) + +## Results and models + +### 1. Classification + +| Dataset | Backbone | Params(M) | FLOPs(M) | lr_type | Top-1 (%) | Top-5 (%) | CPrate | Config | Download | +| :------: | :----------: | :-------: | :------: | :-----: | :-------: | :-------: | :---------------------------------------------: | :--------------------------------------------------: | :--------------------------: | +| ImageNet | DCFFResNet50 | 15.16 | 2260 | step | 73.96 | 91.66 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmcls/dcff/dcff_resnet_8xb32_in1k.py) | [model](<>) \| \[log\] (\<>) | + +### 2. Detection + +| Dataset | Method | Backbone | Style | Lr schd | Params(M) | FLOPs(M) | bbox AP | CPrate | Config | Download | +| :-----: | :---------: | :----------: | :-----: | :-----: | :-------: | :------: | :-----: | :---------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | Faster_RCNN | DCFFResNet50 | pytorch | step | 33.31 | 168320 | 35.8 | \[0.0\]+\[0.35,0.4,0.1\]\*10+\[0.3,0.3,0.1\]\*6 | [config](../../mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py) | [model](<>) \| \[log\] (\<>) | + +### 3. Segmentation + +| Dataset | Method | Backbone | crop size | Lr schd | Params(M) | FLOPs(M) | mIoU | CPrate | Config | Download | +| :--------: | :-------: | :-------------: | :-------: | :-----: | :-------: | :------: | :---: | :-----------------------------------------------------------------: | :-------------------------------------------------------------------: | :--------------------------: | +| Cityscapes | PointRend | DCFFResNetV1c50 | 512x1024 | 160k | 18.43 | 74410 | 76.75 | \[0.0, 0.0, 0.0\] + \[0.35, 0.4, 0.1\] * 10 + \[0.3, 0.3, 0.1\] * 6 | [config](../../mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py) | [model](<>) \| \[log\] (\<>) | + +### 4. Pose + +| Dataset | Method | Backbone | crop size | total epochs | Params(M) | FLOPs(M) | AP | CPrate | Config | Download | +| :-----: | :-------------: | :----------: | :-------: | :----------: | :-------: | :------: | :--: | :--------------------------------------------------------: | :---------------------------------------------------------------: | :--------------------------: | +| COCO | TopDown HeatMap | DCFFResNet50 | 256x192 | 300 | 26.95 | 4290 | 68.3 | \[0.0\] + \[0.2, 0.2, 0.1\] * 10 + \[0.15, 0.15, 0.1\] * 6 | [config](../../mmpose/dcff/dcff_topdown_heatmap_resnet50_coco.py) | [model](<>) \| \[log\] (\<>) | + +## Citation + +```latex +@article{lin2021training, + title={Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion}, + author={Lin, Mingbao and Ji, Rongrong and Chen, Bohong and Chao, Fei and Liu, Jianzhuang and Zeng, Wei and Tian, Yonghong and Tian, Qi}, + journal={arXiv preprint arXiv:2107.06916}, + year={2021} +} +``` + +## Getting Started + +### Generate channel_config file + +Generate `resnet_seg.json` with `tools/get_channel_units.py`. 
+ +```bash +python tools/get_channel_units.py \ + configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py \ + -c -i --output-path=configs/pruning/mmseg/dcff/resnet_seg.json +``` + +Then set the layer-wise pruning rates in `target_pruning_ratio` according to `resnet_seg.json`. + +### Train DCFF + +#### Segmentation + +##### Cityscapes + +```bash +sh tools/slurm_train.sh $PARTITION $JOB_NAME \ + configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py \ + $WORK_DIR +``` + +### Test DCFF + +#### Segmentation + +##### Cityscapes + +```bash +sh tools/slurm_test.sh $PARTITION $JOB_NAME \ + configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py \ + $WORK_DIR +``` diff --git a/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py new file mode 100644 index 000000000..2914b7d84 --- /dev/null +++ b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py @@ -0,0 +1,5 @@ +_base_ = ['dcff_pointrend_resnet50_8xb2_cityscapes.py'] + +# model settings +model = _base_.model +model['is_deployed'] = True diff --git a/configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py b/configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py new file mode 100644 index 000000000..8e4b7f342 --- /dev/null +++ b/configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py @@ -0,0 +1,99 @@ +_base_ = [ + # TODO: use autoaug pipeline. + 'mmseg::_base_/datasets/cityscapes.py', + 'mmseg::_base_/schedules/schedule_160k.py', + 'mmseg::_base_/default_runtime.py', + './pointrend_resnet50.py' +] + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005), + clip_grad=dict(max_norm=25, norm_type=2), + _delete_=True) +train_cfg = dict(type='IterBasedTrainLoop', max_iters=160000, val_interval=800) + +param_scheduler = [ + # warm up + dict(type='LinearLR', by_epoch=False, start_factor=0.1, begin=0, end=200), + dict( + type='PolyLR', + eta_min=1e-4, + power=0.9, + begin=200, + end=80000, + by_epoch=False, + ) +] + +stage_ratio_1 = 0.65 +stage_ratio_2 = 0.6 +stage_ratio_3 = 0.9 +stage_ratio_4 = 0.7 + +# the config template of target_pruning_ratio can be got by +# python ./tools/get_channel_units.py {config_file} --choice +target_pruning_ratio = { + 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.0.conv3_(0, 256)_256': stage_ratio_3, + 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.1.conv2_(0, 64)_64': stage_ratio_2, + 'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_2, + # block 1 [0.8, 0.8] downsample=[0.9] + 'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.0.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.0.conv3_(0, 512)_512': stage_ratio_3, + 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.1.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.2.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv1_(0, 128)_128': stage_ratio_1, + 'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2, + # block 2 [0.8, 0.8] downsample=[0.9] + 'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.0.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.0.conv3_(0, 1024)_1024': stage_ratio_3, +
'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.1.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_1, + 'backbone.layer3.2.conv2_(0, 256)_256': stage_ratio_2, + 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.3.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.4.conv2_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_4, + 'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_4, + # block 3 [0.8, 0.8]*2+[0.8, 0.85]*2 downsample=[0.9] + 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv3_(0, 2048)_2048': stage_ratio_3, + 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.1.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4 + # block 4 [0.85, 0.85] downsample=[0.9] +} + +# model settings +model = dict( + _scope_='mmrazor', + type='DCFF', + architecture=_base_.architecture, + mutator_cfg=dict( + type='DCFFChannelMutator', + channel_unit_cfg=dict( + type='DCFFChannelUnit', + units='configs/pruning/mmseg/dcff/resnet_seg.json'), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='CascadeEncoderDecoderPseudoLoss'))), + target_pruning_ratio=target_pruning_ratio, + step_freq=200, + linear_schedule=False, + is_deployed=False) + +model_wrapper = dict( + type='mmcv.MMDistributedDataParallel', find_unused_parameters=True) diff --git a/configs/pruning/mmseg/dcff/pointrend_resnet50.py b/configs/pruning/mmseg/dcff/pointrend_resnet50.py new file mode 100644 index 000000000..816ec8386 --- /dev/null +++ b/configs/pruning/mmseg/dcff/pointrend_resnet50.py @@ -0,0 +1,63 @@ +data_preprocessor = dict( + _scope_='mmseg', + type='SegDataPreProcessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + size=(512, 1024), + pad_val=0, + seg_pad_val=255) +architecture = dict( + _scope_='mmseg', + type='CascadeEncoderDecoder', + data_preprocessor=data_preprocessor, + num_stages=2, + pretrained=None, + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=[ + dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=-1, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='PointHead', + in_channels=[256], + in_index=[0], + channels=256, + num_fcs=3, + coarse_pred_each_layer=True, + dropout_ratio=-1, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + # model training and testing settings + train_cfg=dict( + num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75), + test_cfg=dict( + mode='whole', + subdivision_steps=2, + subdivision_num_points=8196, + scale_factor=2)) diff --git a/configs/pruning/mmseg/dcff/resnet_seg.json b/configs/pruning/mmseg/dcff/resnet_seg.json new file mode 100644 index
000000000..317fba020 --- /dev/null +++ b/configs/pruning/mmseg/dcff/resnet_seg.json @@ -0,0 +1,496 @@ +{ + "backbone.conv1_(0, 3)_3":{ + "init_args":{ + "num_channels":3, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 3 + ], + "choice_mode":"number" + }, + "choice":3 + }, + "backbone.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 64 + ], + "choice_mode":"number" + }, + "choice":64 + }, + "backbone.layer1.0.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.0.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer1.0.conv3_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 230 + ], + "choice_mode":"number" + }, + "choice":230 + }, + "backbone.layer1.1.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.1.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer1.2.conv1_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 41 + ], + "choice_mode":"number" + }, + "choice":41 + }, + "backbone.layer1.2.conv2_(0, 64)_64":{ + "init_args":{ + "num_channels":64, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 38 + ], + "choice_mode":"number" + }, + "choice":38 + }, + "backbone.layer2.0.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.0.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.0.conv3_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 460 + ], + "choice_mode":"number" + }, + "choice":460 + }, + "backbone.layer2.1.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.1.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.2.conv1_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.2.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer2.3.conv1_(0, 128)_128":{ + 
"init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 83 + ], + "choice_mode":"number" + }, + "choice":83 + }, + "backbone.layer2.3.conv2_(0, 128)_128":{ + "init_args":{ + "num_channels":128, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 76 + ], + "choice_mode":"number" + }, + "choice":76 + }, + "backbone.layer3.0.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.0.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.0.conv3_(0, 1024)_1024":{ + "init_args":{ + "num_channels":1024, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 921 + ], + "choice_mode":"number" + }, + "choice":921 + }, + "backbone.layer3.1.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.1.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.2.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 166 + ], + "choice_mode":"number" + }, + "choice":166 + }, + "backbone.layer3.2.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 153 + ], + "choice_mode":"number" + }, + "choice":153 + }, + "backbone.layer3.3.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.3.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.4.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.4.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.5.conv1_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer3.5.conv2_(0, 256)_256":{ + "init_args":{ + "num_channels":256, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 179 + ], + "choice_mode":"number" + }, + "choice":179 + }, + "backbone.layer4.0.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.0.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + 
"min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.0.conv3_(0, 2048)_2048":{ + "init_args":{ + "num_channels":2048, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 1843 + ], + "choice_mode":"number" + }, + "choice":1843 + }, + "backbone.layer4.1.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.1.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.2.conv1_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + }, + "backbone.layer4.2.conv2_(0, 512)_512":{ + "init_args":{ + "num_channels":512, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 358 + ], + "choice_mode":"number" + }, + "choice":358 + } +} diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index d7e4dc3af..00b2b9c0a 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -5,7 +5,7 @@ SelfDistill, SingleTeacherDistill) from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP, Darts, DartsDDP) -from .pruning import SlimmableNetwork, SlimmableNetworkDDP +from .pruning import DCFF, SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm __all__ = [ @@ -19,6 +19,7 @@ 'AutoSlimDDP', 'Darts', 'DartsDDP', + 'DCFF', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', diff --git a/mmrazor/models/algorithms/pruning/__init__.py b/mmrazor/models/algorithms/pruning/__init__.py index 0b426146b..ea7d77901 100644 --- a/mmrazor/models/algorithms/pruning/__init__.py +++ b/mmrazor/models/algorithms/pruning/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .dcff import DCFF from .slimmable_network import SlimmableNetwork, SlimmableNetworkDDP -__all__ = ['SlimmableNetwork', 'SlimmableNetworkDDP'] +__all__ = ['SlimmableNetwork', 'SlimmableNetworkDDP', 'DCFF'] diff --git a/mmrazor/models/algorithms/pruning/dcff.py b/mmrazor/models/algorithms/pruning/dcff.py new file mode 100644 index 000000000..12c827556 --- /dev/null +++ b/mmrazor/models/algorithms/pruning/dcff.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine import MMLogger +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement + +from mmrazor.models.mutables import BaseMutable +from mmrazor.models.mutators import DCFFChannelMutator +from mmrazor.registry import MODELS +from mmrazor.structures.subnet.fix_subnet import _dynamic_to_static +from .ite_prune_algorithm import ItePruneAlgorithm, ItePruneConfigManager + +LossResults = Dict[str, torch.Tensor] +TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] +PredictResults = List[BaseDataElement] +ForwardResults = Union[LossResults, TensorResults, PredictResults] + + +@MODELS.register_module() +class DCFF(ItePruneAlgorithm): + """DCFF Networks. 
+ + Please refer to the paper + [Dynamic-coded Filter Fusion](https://arxiv.org/abs/2107.06916). + + Args: + architecture (Union[BaseModel, Dict]): The model to be pruned. + mutator_cfg (Union[Dict, DCFFChannelMutator], optional): The config + of a mutator. Defaults to dict(type='DCFFChannelMutator', + channel_unit_cfg=dict(type='DCFFChannelUnit')). + data_preprocessor (Optional[Union[Dict, nn.Module]], optional): + Defaults to None. + target_pruning_ratio (dict, optional): The prune-target. The template + of the prune-target can be obtained by calling + mutator.choice_template(). Defaults to {}. + step_freq (int, optional): The step between two pruning operations. + Defaults to 1. Legal input lies in [1, self._max_iters]. + Exactly one of (step_freq, prune_times) should be set to a legal int. + prune_times (int, optional): The total times to prune a model. + Defaults to 0. Legal input lies in [1, self._max_iters]. + Exactly one of (step_freq, prune_times) should be set to a legal int. + init_cfg (Optional[Dict], optional): Init config for the architecture. + Defaults to None. + linear_schedule (bool, optional): flag to set a linear ratio schedule. + Defaults to False because DCFF uses a fixed pruning rate. + is_deployed (bool, optional): flag indicating whether the algorithm is + deployed (converted to static ops). Defaults to False. + """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator_cfg: Union[Dict, DCFFChannelMutator] = dict( + type='DCFFChannelMutator', + channel_unit_cfg=dict(type='DCFFChannelUnit')), + data_preprocessor: Optional[Union[Dict, nn.Module]] = None, + target_pruning_ratio: Optional[Dict[str, float]] = None, + step_freq=1, + prune_times=0, + init_cfg: Optional[Dict] = None, + linear_schedule=False, + is_deployed=False) -> None: + # prune_times is invalid here; it is reset after message_hub gets [max_epoch] + super().__init__(architecture, mutator_cfg, data_preprocessor, + target_pruning_ratio, step_freq, prune_times, + init_cfg, linear_schedule) + self.is_deployed = is_deployed + if (self.is_deployed): + # To static ops for loaded pruned network. + self._deploy() + + def _fix_archtecture(self): + """Fix all mutables so the architecture can be converted to static ops.""" + for module in self.architecture.modules(): + if isinstance(module, BaseMutable): + if not module.is_fixed: + module.fix_chosen(None) + + def _deploy(self): + """Convert the pruned dynamic architecture into a static one.""" + config = self.prune_config_manager.prune_at(self._iter) + self.mutator.set_choices(config) + self.mutator.fix_channel_mutables() + self._fix_archtecture() + _dynamic_to_static(self.architecture) + self.is_deployed = True + + def _calc_temperature(self, cur_num: int, max_num: int): + """Calculate temperature param.""" + # Set the fixed parameters required to calculate the temperature t + t_s, t_e, k = 1, 10000, 1 + + A = 2 * (t_e - t_s) * (1 + math.exp(-k * max_num)) / ( + 1 - math.exp(-k * max_num)) + T = A / (1 + math.exp(-k * cur_num)) + t_s - A / 2 + t = 1 / T + return t + + def _legal_freq_time(self, freq_time): + """Check whether step_freq or prune_times belongs to the legal range: + + [1, self._max_iters] + + Args: + freq_time (Int): step_freq or prune_times. + """ + return (freq_time > 0) and (freq_time < self._max_iters) + + def _init_prune_config_manager(self): + """Init prune_config_manager and check step_freq & prune_times. + + In DCFF, prune_times is derived from step_freq and self._max_iters.
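+        For example, with an ``IterBasedTrainLoop`` of ``max_iters=160000``
+        and ``step_freq=200`` (as in the segmentation config of this PR),
+        ``prune_times`` is computed as ``160000 // 200 = 800``.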
+ """ + if self.target_pruning_ratio is None: + group_target_ratio = self.mutator.current_choices + else: + group_target_ratio = self.group_target_pruning_ratio( + self.target_pruning_ratio, self.mutator.search_groups) + + if self.by_epoch: + # step_freq based on iterations + self.step_freq *= self._iters_per_epoch + + if self._legal_freq_time(self.step_freq) ^ self._legal_freq_time( + self.prune_times): + if self._legal_freq_time(self.step_freq): + self.prune_times = self._max_iters // self.step_freq + else: + self.step_freq = self._max_iters // self.prune_times + else: + raise RuntimeError('One and only one of (step_freq, prune_times)' + 'can be set to legal int.') + + # config_manager move to forward. + # message_hub['max_epoch'] unaccessible when init + prune_config_manager = ItePruneConfigManager( + group_target_ratio, + self.mutator.current_choices, + self.step_freq, + prune_times=self.prune_times, + linear_schedule=self.linear_schedule) + + return prune_config_manager + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor') -> ForwardResults: + """Forward.""" + # In DCFF prune_message is related to total_num + # Set self.prune_config_manager after message_hub has['max_epoch/iter'] + if not hasattr(self, 'prune_config_manager'): + # iter num per epoch only available after initiation + self.prune_config_manager = self._init_prune_config_manager() + if self.prune_config_manager.is_prune_time(self._iter): + config = self.prune_config_manager.prune_at(self._iter) + self.mutator.set_choices(config) + + # calc fusion channel + temperature = self._calc_temperature(self._iter, self._max_iters) + self.mutator.calc_information(temperature) + + logger = MMLogger.get_current_instance() + if (self.by_epoch): + logger.info( + f'The model is pruned at {self._epoch}th epoch once.') + else: + logger.info( + f'The model is pruned at {self._iter}th iter once.') + + return super().forward(inputs, data_samples, mode) diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py index 4b592740a..d0aab73fd 100644 --- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -26,46 +26,54 @@ class ItePruneConfigManager: target (Dict[str, Union[int, float]]): The target structure to prune. supernet (Dict[str, Union[int, float]]): The sturecture of the supernet. - epoch_step (int, optional): The prune step to prune. Defaults to 1. - times (int, optional): The times to prune. Defaults to 1. + step_freq (int, optional): The prune step of epoch/iter to prune. + Defaults to 1. + prune_times (int, optional): The times to prune. Defaults to 1. + linear_schedule (bool, optional): flag to set linear ratio schedule. + Defaults to True. 
""" def __init__(self, target: Dict[str, Union[int, float]], supernet: Dict[str, Union[int, float]], - epoch_step=1, - times=1) -> None: + step_freq=1, + prune_times=1, + linear_schedule=True) -> None: self.supernet = supernet self.target = target - self.epoch_step = epoch_step - self.prune_times = times + self.step_freq = step_freq + self.prune_times = prune_times + self.linear_schedule = linear_schedule - self.delta: Dict = self._get_delta_each_epoch(self.target, - self.supernet, - self.prune_times) + self.delta: Dict = self._get_delta_each_iter(self.target, + self.supernet, + self.prune_times) - def is_prune_time(self, epoch, ite): + def is_prune_time(self, iteration): """Is the time to prune during training process.""" - return epoch % self.epoch_step == 0 \ - and epoch//self.epoch_step < self.prune_times \ - and ite == 0 + return iteration % self.step_freq == 0 \ + and iteration // self.step_freq < self.prune_times - def prune_at(self, epoch): - """Get the pruning structure in a time(epoch).""" - times = epoch // self.epoch_step + 1 + def prune_at(self, iteration): + """Get the pruning structure in a time(iteration).""" + times = iteration // self.step_freq + 1 assert times <= self.prune_times prune_current = {} ratio = times / self.prune_times for key in self.target: - prune_current[key] = (self.target[key] - self.supernet[key] - ) * ratio + self.supernet[key] + if self.linear_schedule: + # TO DO: add scheduler for more pruning rate schedule + prune_current[key] = (self.target[key] - self.supernet[key] + ) * ratio + self.supernet[key] + else: + prune_current[key] = self.target[key] if isinstance(self.supernet[key], int): prune_current[key] = int(prune_current[key]) return prune_current - def _get_delta_each_epoch(self, target: Dict, supernet: Dict, times: int): + def _get_delta_each_iter(self, target: Dict, supernet: Dict, times: int): """Get the structure change for pruning once.""" delta = {} for key in target: @@ -94,11 +102,14 @@ class ItePruneAlgorithm(BaseAlgorithm): target_pruning_ratio (dict, optional): The prune-target. The template of the prune-target can be get by calling mutator.choice_template(). Defaults to {}. - step_epoch (int, optional): The step between two pruning operations. + step_freq (int, optional): The step between two pruning operations. + Defaults to 1. + prune_times (int, optional): The total times to prune a model. Defaults to 1. - prune_times (int, optional): The times to prune a model. Defaults to 1. init_cfg (Optional[Dict], optional): init config for architecture. Defaults to None. + linear_schedule (bool, optional): flag to set linear ratio schedule. + Defaults to True. 
""" def __init__(self, @@ -109,29 +120,23 @@ def __init__(self, type='SequentialMutableChannelUnit')), data_preprocessor: Optional[Union[Dict, nn.Module]] = None, target_pruning_ratio: Optional[Dict[str, float]] = None, - step_epoch=1, - prune_times=1, - init_cfg: Optional[Dict] = None) -> None: + step_freq=-1, + prune_times=-1, + init_cfg: Optional[Dict] = None, + linear_schedule=True) -> None: super().__init__(architecture, data_preprocessor, init_cfg) + # decided by EpochBasedRunner or IterBasedRunner + self.target_pruning_ratio = target_pruning_ratio + self.step_freq = step_freq + self.prune_times = prune_times + self.linear_schedule = linear_schedule + # mutator self.mutator: ChannelMutator = MODELS.build(mutator_cfg) self.mutator.prepare_from_supernet(self.architecture) - if target_pruning_ratio is None: - group_target_ratio = self.mutator.current_choices - else: - group_target_ratio = self.group_target_pruning_ratio( - target_pruning_ratio, self.mutator.search_groups) - - # config_manager - self.prune_config_manager = ItePruneConfigManager( - group_target_ratio, - self.mutator.current_choices, - step_epoch, - times=prune_times) - def group_target_pruning_ratio( self, target: Dict[str, float], search_groups: Dict[int, @@ -158,7 +163,6 @@ def group_target_pruning_ratio( unit_target = target[unit_name] assert isinstance(unit_target, (float, int)) group_target[group_id] = unit_target - return group_target def check_prune_target(self, config: Dict): @@ -166,21 +170,54 @@ def check_prune_target(self, config: Dict): for value in config.values(): assert isinstance(value, int) or isinstance(value, float) + def _init_prune_config_manager(self): + """init prune_config_manager and check step_freq & prune_times. + + message_hub['max_epoch/iter'] unaccessible when initiation. + """ + if self.target_pruning_ratio is None: + group_target_ratio = self.mutator.current_choices + else: + group_target_ratio = self.group_target_pruning_ratio( + self.target_pruning_ratio, self.mutator.search_groups) + + if self.by_epoch: + # step_freq based on iterations + self.step_freq *= self._iters_per_epoch + + # config_manager move to forward. 
+ # message_hub['max_epoch'] unaccessible when init + prune_config_manager = ItePruneConfigManager( + group_target_ratio, + self.mutator.current_choices, + self.step_freq, + prune_times=self.prune_times, + linear_schedule=self.linear_schedule) + + return prune_config_manager + def forward(self, inputs: torch.Tensor, data_samples: Optional[List[BaseDataElement]] = None, mode: str = 'tensor') -> ForwardResults: """Forward.""" - print(self._epoch, self._iteration) - if self.prune_config_manager.is_prune_time(self._epoch, - self._iteration): + if not hasattr(self, 'prune_config_manager'): + # self._iters_per_epoch() only available after initiation + self.prune_config_manager = self._init_prune_config_manager() + + if self.prune_config_manager.is_prune_time(self._iter): - config = self.prune_config_manager.prune_at(self._epoch) + config = self.prune_config_manager.prune_at(self._iter) self.mutator.set_choices(config) logger = MMLogger.get_current_instance() - logger.info(f'The model is pruned at {self._epoch}th epoch once.') + if (self.by_epoch): + logger.info( + f'The model is pruned at {self._epoch}th epoch once.') + else: + logger.info( + f'The model is pruned at {self._iter}th iter once.') return super().forward(inputs, data_samples, mode) @@ -189,6 +226,13 @@ def init_weights(self): # private methods + @property + def by_epoch(self): + """Get epoch/iter based train loop.""" + # IterBasedTrainLoop max_epochs default to 1 + # TO DO: Add by_epoch params or change default max_epochs? + return self._max_epochs != 1 + @property def _epoch(self): """Get current epoch number.""" @@ -196,16 +240,43 @@ def _epoch(self): if 'epoch' in message_hub.runtime_info: return message_hub.runtime_info['epoch'] else: - return 0 + raise RuntimeError('Use MessageHub before initiation.' + 'epoch is inited in before_run_epoch().') @property - def _iteration(self): - """Get current iteration number.""" + def _iter(self): + """Get current sum iteration number.""" message_hub = MessageHub.get_current_instance() if 'iter' in message_hub.runtime_info: - iter = message_hub.runtime_info['iter'] - max_iter = message_hub.runtime_info['max_iters'] - max_epoch = message_hub.runtime_info['max_epochs'] - return iter % (max_iter // max_epoch) + return message_hub.runtime_info['iter'] else: - return 0 + raise RuntimeError('Use MessageHub before initiation.' + 'iter is inited in before_run_iter().') + + @property + def _max_epochs(self): + """Get max epoch number. + + Default 1 for IterTrainLoop + """ + message_hub = MessageHub.get_current_instance() + if 'max_epochs' in message_hub.runtime_info: + return message_hub.runtime_info['max_epochs'] + else: + raise RuntimeError('Use MessageHub before initiation.' + 'max_epochs is inited in before_run_epoch().') + + @property + def _max_iters(self): + """Get max iteration number.""" + message_hub = MessageHub.get_current_instance() + if 'max_iters' in message_hub.runtime_info: + return message_hub.runtime_info['max_iters'] + else: + raise RuntimeError('Use MessageHub before initiation.' + 'max_iters is inited in before_run_iter().') + + @property + def _iters_per_epoch(self): + """Get iter num per epoch.""" + return self._max_iters / self._max_epochs diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py index e9cde3f7d..0ae0438e9 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. 
All rights reserved. from .dynamic_container import DynamicSequential -from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d +from .dynamic_conv import BigNasConv2d, DynamicConv2d, FuseConv2d, OFAConv2d from .dynamic_embed import DynamicPatchEmbed from .dynamic_linear import DynamicLinear from .dynamic_multi_head_attention import DynamicMultiheadAttention @@ -13,6 +13,6 @@ 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', 'SwitchableBatchNorm2d', 'DynamicSequential', 'DynamicPatchEmbed', - 'DynamicLayerNorm', 'DynamicRelativePosition2D', + 'DynamicLayerNorm', 'DynamicRelativePosition2D', 'FuseConv2d', 'DynamicMultiheadAttention' ] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py index 71fc7ab98..d1e10c4d0 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py @@ -8,7 +8,7 @@ from mmrazor.models.mutables.base_mutable import BaseMutable from mmrazor.registry import MODELS from ..mixins.dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, - OFAConvMixin) + FuseConvMixin, OFAConvMixin) GroupWiseConvWarned = False @@ -188,3 +188,50 @@ def static_op_factory(self): def forward(self, x: Tensor) -> Tensor: """Forward of OFA's conv2d.""" return self.forward_mixin(x) + + +@MODELS.register_module() +class FuseConv2d(nn.Conv2d, FuseConvMixin): + """FuseConv2d used in `DCFF`. + + Refers to `Training Compact CNNs for Image Classification + using Dynamic-coded Filter Fusion `_. + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `in_channels`. The key of the dict must in + ``accepted_mutable_attrs``. + """ + accepted_mutable_attrs = {'in_channels', 'out_channels'} + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + @classmethod + def convert_from(cls, module: nn.Conv2d) -> 'FuseConv2d': + """Convert an instance of `nn.Conv2d` to a new instance of + `FuseConv2d`.""" + return cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) + + @property + def conv_func(self) -> Callable: + """The function that will be used in ``forward_mixin``.""" + return F.conv2d + + @property + def static_op_factory(self): + """Corresponding Pytorch OP.""" + return nn.Conv2d + + def forward(self, x: Tensor) -> Tensor: + """Forward of fused conv2d.""" + return self.forward_mixin(x) diff --git a/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py index e3ed46ded..82cfcf390 100644 --- a/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py +++ b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from abc import abstractmethod +from functools import partial from itertools import repeat -from typing import Callable, Iterable, Optional, Tuple +from typing import Any, Callable, Iterable, Optional, Tuple import torch import torch.nn.functional as F @@ -11,6 +12,8 @@ from mmrazor.models.mutables.base_mutable import BaseMutable from .dynamic_mixins import DynamicChannelMixin +PartialType = Callable[[Any, Optional[nn.Parameter]], Any] + def _ntuple(n: int) -> Callable: # pragma: no cover """Repeat a number n times.""" @@ -317,16 +320,6 @@ def _get_dynamic_params_by_mutable_kernel_size( return current_weight, current_padding - def forward_mixin(self: _ConvNd, x: Tensor) -> Tensor: - """Forward of dynamic conv2d OP.""" - groups = self.groups - if self.groups == self.in_channels == self.out_channels: - groups = x.size(1) - weight, bias, padding = self.get_dynamic_params() - - return self.conv_func(x, weight, bias, self.stride, padding, - self.dilation, groups) - class OFAConvMixin(BigNasConvMixin): """A mixin class for Pytorch conv, which can mutate ``in_channels``, @@ -404,3 +397,161 @@ def _get_dynamic_params_by_mutable_kernel_size( current_weight = target_weight return current_weight, current_padding + + +class FuseConvMixin(DynamicConvMixin): + """A mixin class for fuse conv, which can mutate ``in_channels``, + ``out_channels`` .""" + + def set_forward_args(self, choice: Tensor) -> None: + """Interface for modifying the arch_param using partial.""" + param_channel_with_default_args: PartialType = \ + partial( + self._get_dynamic_params_by_mutable_channels_choice, + choice=choice) + setattr(self, '_get_dynamic_params_by_mutable_channels', + param_channel_with_default_args) + + def get_dynamic_params( + self: _ConvNd) -> Tuple[Tensor, Optional[Tensor], Tuple[int]]: + """Get dynamic parameters that will be used in forward process. + + Returns: + Tuple[Tensor, Optional[Tensor], Tuple[int]]: Sliced weight, bias + and padding. + """ + # slice in/out channel of weight according to mutable in_channels + # and mutable out channels. + weight, bias = self._get_dynamic_params_by_mutable_channels( + self.weight, self.bias) + return weight, bias, self.padding + + def _get_dynamic_params_by_mutable_channels_choice( + self: _ConvNd, weight: Tensor, bias: Optional[Tensor], + choice: Tensor) -> Tuple[Tensor, Optional[Tensor]]: + """Get sliced weight and bias according to ``mutable_in_channels`` and + ``mutable_out_channels``. + + Returns: + Tuple[Tensor, Optional[Tensor]]: Sliced weight and bias. 
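+
+        Note:
+            ``choice`` is typically the softmax-pooled matrix produced by
+            ``get_pooled_channel``, with shape ``(preserved_out, out)``; its
+            matrix product with the flattened weight makes every preserved
+            filter a weighted average (fusion) of all original filters, as
+            described in the DCFF paper.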
+ """ + + mutable_in_channels = 0 + mutable_out_channels = 0 + + if 'in_channels' in self.mutable_attrs: + mutable_in_channels = self.mutable_attrs[ + 'in_channels'].current_mask.sum().item() + + if 'out_channels' in self.mutable_attrs: + mutable_out_channels = self.mutable_attrs[ + 'out_channels'].current_mask.sum().item() + + if mutable_in_channels == 0: + mutable_in_channels = self.in_channels + if mutable_out_channels == 0: + mutable_out_channels = self.out_channels + + # if channel not in mutable_attrs or unchanged + if mutable_in_channels == self.in_channels and \ + mutable_out_channels == self.out_channels: + return weight, bias + + weight = self.weight[:, 0:mutable_in_channels, :, :] + if self.groups == 1: + cout, cin, k, _ = weight.shape + fused_weight = torch.mm(choice, + weight.reshape(cout, + -1)).reshape(-1, cin, k, k) + elif self.groups == self.in_channels == self.out_channels: + # depth-wise conv + cout, cin, k, _ = weight.shape + fused_weight = torch.mm(choice, + weight.reshape(cout, + -1)).reshape(-1, cin, k, k) + else: + raise NotImplementedError( + 'Current `ChannelMutator` only support pruning the depth-wise ' + '`nn.Conv2d` or `nn.Conv2d` module whose group number equals ' + f'to one, but got {self.groups}.') + if (self.bias is not None): + fused_bias = torch.mm(choice, self.bias.unsqueeze(1)).squeeze(1) + else: + fused_bias = self.bias + return fused_weight, fused_bias + + def to_static_op(self: _ConvNd) -> nn.Conv2d: + """Convert dynamic conv2d to :obj:`torch.nn.Conv2d`. + + Returns: + torch.nn.Conv2d: :obj:`torch.nn.Conv2d` with sliced parameters. + """ + self.check_if_mutables_fixed() + + weight, bias, padding = self.get_dynamic_params() + groups = self.groups + if groups == self.in_channels == self.out_channels and \ + self.mutable_in_channels is not None: + mutable_in_channels = self.mutable_attrs['in_channels'] + groups = mutable_in_channels.current_mask.sum().item() + out_channels = weight.size(0) + in_channels = weight.size(1) * groups + + kernel_size = tuple(weight.shape[2:]) + + static_conv = self.static_op_factory( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=self.stride, + padding=padding, + padding_mode=self.padding_mode, + dilation=self.dilation, + groups=groups, + bias=True if bias is not None else False) + + static_conv.weight = nn.Parameter(weight) + if bias is not None: + static_conv.bias = nn.Parameter(bias) + + return static_conv + + def get_pooled_channel(self: _ConvNd, tau: float) -> Tensor: + """Calculate channel's kl and apply softmax pooling on channel. Return + `layeri_softmaxp` as pooling result. + + Args: + tau (float): Temperature by epoch/iter. + + Returns: + Tensor: softmax pooled channel. + """ + param = self.weight + + # Compute layeri_param. 
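+        # Flatten each filter into a row vector, compute pairwise Euclidean
+        # distances between filters, and turn the negated distances scaled by
+        # the temperature tau into per-filter proxy distributions via softmax
+        # (the inter-similarity distributions used by DCFF).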
+ layeri_param = torch.reshape(param.detach(), (param.shape[0], -1)) + layeri_Eudist = torch.cdist(layeri_param, layeri_param, p=2) + layeri_negaEudist = -layeri_Eudist + softmax = nn.Softmax(dim=1) + layeri_softmaxp = softmax(layeri_negaEudist / tau) + + # KL = [c, 1, c] * ([c, 1 ,c] / [c, c, 1]).log() + # = [c, 1, c] * ([c, 1, c].log() - [c, c, 1].log()) + # only dim0 is required, dim1 and dim2 are pooled + # calc mean(dim=1) first + + # avoid frequent NaN + eps = 1e-7 + layeri_kl = layeri_softmaxp[:, None, :] + log_p = layeri_kl * (layeri_kl + eps).log() + log_q = layeri_kl * torch.mean((layeri_softmaxp + eps).log(), dim=1) + + layeri_kl = torch.mean((log_p - log_q), dim=2) + del log_p, log_q + real_out = self.mutable_attrs['out_channels'].activated_channels + + layeri_iscore_kl = torch.sum(layeri_kl, dim=1) + _, topm_ids_order = torch.topk( + layeri_iscore_kl, int(real_out), sorted=False) + del param, layeri_param, layeri_negaEudist, layeri_kl + return layeri_softmaxp[topm_ids_order, :] diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index abfceab7f..70315482f 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -4,8 +4,8 @@ from .mutable_channel import (BaseMutableChannel, MutableChannelContainer, OneShotMutableChannel, SimpleMutableChannel, SquentialMutableChannel) -from .mutable_channel.units import (ChannelUnitType, L1MutableChannelUnit, - MutableChannelUnit, +from .mutable_channel.units import (ChannelUnitType, DCFFChannelUnit, + L1MutableChannelUnit, MutableChannelUnit, OneShotMutableChannelUnit, SequentialMutableChannelUnit, SlimmableChannelUnit) @@ -22,5 +22,5 @@ 'SimpleMutableChannel', 'MutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel', 'MutableChannelContainer', 'ChannelUnitType', 'SquentialMutableChannel', 'OneHotMutableOP', 'OneShotMutableChannel', - 'BaseMutable' + 'BaseMutable', 'DCFFChannelUnit' ] diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index 0ef09dc78..618766e4e 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -1,17 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 from .base_mutable_channel import BaseMutableChannel
 from .mutable_channel_container import MutableChannelContainer
-from .oneshot_mutalbe_channel import OneShotMutableChannel
+from .oneshot_mutable_channel import OneShotMutableChannel
 from .sequential_mutable_channel import SquentialMutableChannel
 from .simple_mutable_channel import SimpleMutableChannel
-from .units import (ChannelUnitType, L1MutableChannelUnit, MutableChannelUnit,
-                    OneShotMutableChannelUnit, SequentialMutableChannelUnit,
-                    SlimmableChannelUnit)
+from .units import (ChannelUnitType, DCFFChannelUnit, L1MutableChannelUnit,
+                    MutableChannelUnit, OneShotMutableChannelUnit,
+                    SequentialMutableChannelUnit, SlimmableChannelUnit)

 __all__ = [
     'SimpleMutableChannel', 'L1MutableChannelUnit',
     'SequentialMutableChannelUnit', 'MutableChannelUnit',
     'OneShotMutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel',
     'MutableChannelContainer', 'SquentialMutableChannel', 'ChannelUnitType',
-    'OneShotMutableChannel'
+    'DCFFChannelUnit', 'OneShotMutableChannel'
 ]
diff --git a/mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py b/mmrazor/models/mutables/mutable_channel/oneshot_mutable_channel.py
similarity index 100%
rename from mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py
rename to mmrazor/models/mutables/mutable_channel/oneshot_mutable_channel.py
diff --git a/mmrazor/models/mutables/mutable_channel/units/__init__.py b/mmrazor/models/mutables/mutable_channel/units/__init__.py
index a61816718..8cf814163 100644
--- a/mmrazor/models/mutables/mutable_channel/units/__init__.py
+++ b/mmrazor/models/mutables/mutable_channel/units/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
+from .dcff_channel_unit import DCFFChannelUnit
 from .l1_mutable_channel_unit import L1MutableChannelUnit
 from .mutable_channel_unit import ChannelUnitType, MutableChannelUnit
 from .one_shot_mutable_channel_unit import OneShotMutableChannelUnit
@@ -9,5 +9,5 @@
 __all__ = [
     'L1MutableChannelUnit', 'MutableChannelUnit',
     'SequentialMutableChannelUnit', 'OneShotMutableChannelUnit',
-    'SlimmableChannelUnit', 'ChannelUnitType'
+    'SlimmableChannelUnit', 'ChannelUnitType', 'DCFFChannelUnit'
 ]
diff --git a/mmrazor/models/mutables/mutable_channel/units/dcff_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/dcff_channel_unit.py
new file mode 100644
index 000000000..743a6473c
--- /dev/null
+++ b/mmrazor/models/mutables/mutable_channel/units/dcff_channel_unit.py
@@ -0,0 +1,50 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
+import torch.nn as nn
+
+from mmrazor.models.architectures import dynamic_ops
+from mmrazor.registry import MODELS
+from ..mutable_channel_container import MutableChannelContainer
+from .sequential_mutable_channel_unit import SequentialMutableChannelUnit
+
+
+@MODELS.register_module()
+class DCFFChannelUnit(SequentialMutableChannelUnit):
+    """``DCFFChannelUnit`` is used in the DCFF supernet and derives from
+    ``SequentialMutableChannelUnit``. In the DCFF supernet each unit has only
+    one choice, and the channel choice is fixed before training.
+
+    Args:
+        num_channels (int): The raw number of channels.
+        candidate_choices (List[Union[int, float]], optional):
+            A list of candidate width numbers or ratios. Each
+            candidate indicates how many channels to be reserved.
+            Defaults to [1.0], i.e. the full width when
+            ``choice_mode='ratio'``.
+        choice_mode (str, optional): Mode of candidates.
+            One of "ratio" or "number". Defaults to 'ratio'.
+        divisor (int): Used to make choice divisible.
+        min_value (int): The minimal value used when making divisible.
+        min_ratio (float): The minimal ratio used when making divisible.
+    """
+
+    def __init__(self,
+                 num_channels: int,
+                 candidate_choices: List[Union[int, float]] = [1.0],
+                 choice_mode: str = 'ratio',
+                 divisor: int = 1,
+                 min_value: int = 1,
+                 min_ratio: float = 0.9) -> None:
+        super().__init__(num_channels, choice_mode, divisor, min_value,
+                         min_ratio)
+
+    def prepare_for_pruning(self, model: nn.Module):
+        """In ``DCFFChannelUnit``, ``nn.Conv2d`` is replaced with
+        ``FuseConv2d``."""
+        self._replace_with_dynamic_ops(
+            model, {
+                nn.Conv2d: dynamic_ops.FuseConv2d,
+                nn.BatchNorm2d: dynamic_ops.DynamicBatchNorm2d,
+                nn.Linear: dynamic_ops.DynamicLinear
+            })
+        self._register_channel_container(model, MutableChannelContainer)
+        self._register_mutable_channel(self.mutable_channel)
diff --git a/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py
index 8ba55b25e..220d49b41 100644
--- a/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py
+++ b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py
@@ -6,7 +6,7 @@ import torch.nn as nn

 from mmrazor.registry import MODELS
-from ..oneshot_mutalbe_channel import OneShotMutableChannel
+from ..oneshot_mutable_channel import OneShotMutableChannel
 from .sequential_mutable_channel_unit import SequentialMutableChannelUnit


diff --git a/mmrazor/models/mutators/__init__.py b/mmrazor/models/mutators/__init__.py
index d11358404..6a430bed3 100644
--- a/mmrazor/models/mutators/__init__.py
+++ b/mmrazor/models/mutators/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .channel_mutator import (ChannelMutator, OneShotChannelMutator,
-                              SlimmableChannelMutator)
+from .channel_mutator import (ChannelMutator, DCFFChannelMutator,
+                              OneShotChannelMutator, SlimmableChannelMutator)
 from .module_mutator import (DiffModuleMutator, ModuleMutator,
                              OneShotModuleMutator)
 from .value_mutator import DynamicValueMutator, ValueMutator
@@ -8,5 +8,5 @@
 __all__ = [
     'OneShotModuleMutator', 'DiffModuleMutator', 'ModuleMutator',
     'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator',
-    'ValueMutator', 'DynamicValueMutator'
+    'ValueMutator', 'DynamicValueMutator', 'DCFFChannelMutator'
 ]
diff --git a/mmrazor/models/mutators/channel_mutator/__init__.py b/mmrazor/models/mutators/channel_mutator/__init__.py
index dc4b1c86d..3b64c1cf8 100644
--- a/mmrazor/models/mutators/channel_mutator/__init__.py
+++ b/mmrazor/models/mutators/channel_mutator/__init__.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .channel_mutator import ChannelMutator
+from .dcff_channel_mutator import DCFFChannelMutator
 from .one_shot_channel_mutator import OneShotChannelMutator
 from .slimmable_channel_mutator import SlimmableChannelMutator

 __all__ = [
-    'SlimmableChannelMutator', 'ChannelMutator', 'OneShotChannelMutator'
+    'SlimmableChannelMutator', 'ChannelMutator', 'OneShotChannelMutator',
+    'DCFFChannelMutator'
 ]
diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
index 28c395acd..7ced693b9 100644
--- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
@@ -238,6 +238,9 @@ def set_choices(self, choices: Dict[int, Any]) -> None:
                 corresponding to this group.
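+                Groups missing from ``choices`` keep their current choice,
+                so a partial ``target_prune_ratio`` is allowed.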
""" for group_id, modules in self.search_groups.items(): + if group_id not in choices: + # allow optional target_prune_ratio + continue choice = choices[group_id] for module in modules: module.current_choice = choice diff --git a/mmrazor/models/mutators/channel_mutator/dcff_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/dcff_channel_mutator.py new file mode 100644 index 000000000..8dd335bff --- /dev/null +++ b/mmrazor/models/mutators/channel_mutator/dcff_channel_mutator.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Type, Union + +from mmrazor.models.architectures.dynamic_ops import FuseConv2d +from mmrazor.models.mutables import DCFFChannelUnit +from mmrazor.registry import MODELS +from .channel_mutator import ChannelMutator, ChannelUnitType + + +@MODELS.register_module() +class DCFFChannelMutator(ChannelMutator[DCFFChannelUnit]): + """DCFF channel mutable based channel mutator. It uses DCFFChannelUnit. + + Args: + channel_unit_cfg (Union[dict, Type[ChannelUnitType]], optional): + Config of MutableChannelUnits. Defaults to + dict( type='DCFFChannelUnit', units={}). + parse_cfg (Dict): The config of the tracer to parse the model. + Defaults to dict( type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')). + Change loss_calculator according to task and backbone. + """ + + def __init__(self, + channel_unit_cfg: Union[dict, Type[ChannelUnitType]] = dict( + type='DCFFChannelUnit', units={}), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')), + **kwargs) -> None: + super().__init__(channel_unit_cfg, parse_cfg, **kwargs) + + def calc_information(self, tau: float): + """Calculate channel's kl and apply softmax pooling on channel to solve + CUDA out of memory problem. KL calculation & pool are conducted in ops. + + Args: + tau (float): temporature calculated by iter or epoch + """ + # Calculate the filter importance of the current epoch. + for layerid, unit in enumerate(self.units): + for channel in unit.output_related: + if isinstance(channel.module, FuseConv2d): + layeri_softmaxp = channel.module.get_pooled_channel(tau) + # update fuseconv op's selected layeri_softmax + channel.module.set_forward_args(choice=layeri_softmaxp) diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py b/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py index 0371a713a..bcc272da2 100644 --- a/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py +++ b/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py @@ -1,6 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .cascade_encoder_decoder_loss_calculator import \ + CascadeEncoderDecoderPseudoLoss from .image_classifier_loss_calculator import ImageClassifierPseudoLoss from .single_stage_detector_loss_calculator import \ SingleStageDetectorPseudoLoss +from .top_down_pose_estimator_loss_calculator import \ + TopdownPoseEstimatorPseudoLoss +from .two_stage_detector_loss_calculator import TwoStageDetectorPseudoLoss -__all__ = ['ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss'] +__all__ = [ + 'ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss', + 'TwoStageDetectorPseudoLoss', 'TopdownPoseEstimatorPseudoLoss', + 'CascadeEncoderDecoderPseudoLoss' +] diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/cascade_encoder_decoder_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/cascade_encoder_decoder_loss_calculator.py new file mode 100644 index 000000000..f4f60c843 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/loss_calculator/cascade_encoder_decoder_loss_calculator.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor.registry import TASK_UTILS + +try: + from mmseg.models import CascadeEncoderDecoder +except ImportError: + from mmrazor.utils import get_placeholder + CascadeEncoderDecoder = get_placeholder('mmseg') + + +@TASK_UTILS.register_module() +class CascadeEncoderDecoderPseudoLoss: + """Calculate the pseudo loss to trace the topology of a + `CascadeEncoderDecoder` in MMSegmentation with `BackwardTracer`.""" + + def __call__(self, model: CascadeEncoderDecoder) -> torch.Tensor: + pseudo_img = torch.rand(1, 3, 224, 224) + pseudo_output = model.backbone(pseudo_img) + pseudo_output = model.neck(pseudo_output) + # unmodified decode_heads + out = torch.tensor(0.) + for levels in pseudo_output: + out += sum([level.sum() for level in levels]) + return out diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/top_down_pose_estimator_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/top_down_pose_estimator_loss_calculator.py new file mode 100644 index 000000000..9720194f4 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/loss_calculator/top_down_pose_estimator_loss_calculator.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor.registry import TASK_UTILS + +try: + from mmpose.models import TopdownPoseEstimator +except ImportError: + from mmrazor.utils import get_placeholder + TopdownPoseEstimator = get_placeholder('mmpose') + + +@TASK_UTILS.register_module() +class TopdownPoseEstimatorPseudoLoss: + """Calculate the pseudo loss to trace the topology of a + `TopdownPoseEstimator` in MMPose with `BackwardTracer`.""" + + def __call__(self, model: TopdownPoseEstimator) -> torch.Tensor: + pseudo_img = torch.rand(1, 3, 224, 224) + pseudo_output = model.backbone(pseudo_img) + # immutable decode_heads + out = torch.tensor(0.) + for levels in pseudo_output: + out += sum([level.sum() for level in levels]) + return out diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/two_stage_detector_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/two_stage_detector_loss_calculator.py new file mode 100644 index 000000000..97ff7d282 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/loss_calculator/two_stage_detector_loss_calculator.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
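+# Sums the backbone and neck outputs of a two-stage detector into a scalar so
+# that ``BackwardTracer`` can follow the gradient through all prunable
+# channels.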
+import torch + +from mmrazor.registry import TASK_UTILS + +try: + from mmdet.models import TwoStageDetector +except ImportError: + from mmrazor.utils import get_placeholder + TwoStageDetector = get_placeholder('mmdet') + + +# todo: adapt to mmdet 2.0 +@TASK_UTILS.register_module() +class TwoStageDetectorPseudoLoss: + """Calculate the pseudo loss to trace the topology of a `TwoStageDetector` + in MMDet with `BackwardTracer`.""" + + def __call__(self, model: TwoStageDetector) -> torch.Tensor: + pseudo_img = torch.rand(1, 3, 224, 224) + pseudo_output = model.backbone(pseudo_img) + pseudo_output = model.neck(pseudo_output) + out = torch.tensor(0.) + for levels in pseudo_output: + out += sum([level.sum() for level in levels]) + + return out diff --git a/tests/data/test_models/test_algorithm/MBV2_220M.yaml b/tests/data/test_models/test_algorithm/MBV2_220M.yaml new file mode 100644 index 000000000..b96ebeb49 --- /dev/null +++ b/tests/data/test_models/test_algorithm/MBV2_220M.yaml @@ -0,0 +1,474 @@ +backbone.conv1.bn.mutable_num_features: + current_choice: 8 + origin_channels: 48 +backbone.conv1.conv.mutable_in_channels: + current_choice: 3 + origin_channels: 3 +backbone.conv1.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 48 +backbone.conv2.bn.mutable_num_features: + current_choice: 1920 + origin_channels: 1920 +backbone.conv2.conv.mutable_in_channels: + current_choice: 280 + origin_channels: 480 +backbone.conv2.conv.mutable_out_channels: + current_choice: 1920 + origin_channels: 1920 +backbone.layer1.0.conv.0.bn.mutable_num_features: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.0.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.0.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.1.bn.mutable_num_features: + current_choice: 8 + origin_channels: 24 +backbone.layer1.0.conv.1.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.1.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 24 +backbone.layer2.0.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.0.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 24 +backbone.layer2.0.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.2.bn.mutable_num_features: + current_choice: 16 + origin_channels: 40 +backbone.layer2.0.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.2.conv.mutable_out_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.0.conv.mutable_in_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 
+backbone.layer2.1.conv.2.bn.mutable_num_features: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.2.conv.mutable_out_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer3.0.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.0.conv.mutable_in_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer3.0.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.0.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer4.0.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer4.0.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 
+backbone.layer4.0.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.1.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.1.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 +backbone.layer4.1.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.2.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.2.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 +backbone.layer4.2.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.3.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.3.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 +backbone.layer4.3.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer5.0.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer5.0.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.2.bn.mutable_num_features: + current_choice: 64 + origin_channels: 144 +backbone.layer5.0.conv.2.conv.mutable_in_channels: + current_choice: 288 + 
origin_channels: 576 +backbone.layer5.0.conv.2.conv.mutable_out_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.1.conv.0.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.0.conv.mutable_in_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.1.conv.0.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.2.bn.mutable_num_features: + current_choice: 64 + origin_channels: 144 +backbone.layer5.1.conv.2.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.2.conv.mutable_out_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.2.conv.0.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.0.conv.mutable_in_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.2.conv.0.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.2.bn.mutable_num_features: + current_choice: 64 + origin_channels: 144 +backbone.layer5.2.conv.2.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.2.conv.mutable_out_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer6.0.conv.0.bn.mutable_num_features: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.0.conv.mutable_in_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer6.0.conv.0.conv.mutable_out_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.1.bn.mutable_num_features: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.1.conv.mutable_in_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.1.conv.mutable_out_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.2.bn.mutable_num_features: + current_choice: 176 + origin_channels: 240 +backbone.layer6.0.conv.2.conv.mutable_in_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.2.conv.mutable_out_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.1.conv.0.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.0.conv.mutable_in_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.1.conv.0.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.1.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.1.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.1.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.2.bn.mutable_num_features: + current_choice: 176 + origin_channels: 240 +backbone.layer6.1.conv.2.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 
+backbone.layer6.1.conv.2.conv.mutable_out_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.2.conv.0.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.0.conv.mutable_in_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.2.conv.0.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.1.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.1.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.1.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.2.bn.mutable_num_features: + current_choice: 176 + origin_channels: 240 +backbone.layer6.2.conv.2.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.2.conv.mutable_out_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer7.0.conv.0.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.0.conv.mutable_in_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer7.0.conv.0.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.2.bn.mutable_num_features: + current_choice: 280 + origin_channels: 480 +backbone.layer7.0.conv.2.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.2.conv.mutable_out_channels: + current_choice: 280 + origin_channels: 480 +head.fc.mutable_in_features: + current_choice: 1920 + origin_channels: 1920 +head.fc.mutable_out_features: + current_choice: 1000 + origin_channels: 1000 \ No newline at end of file diff --git a/tests/data/test_models/test_mutator/subnet1.json b/tests/data/test_models/test_mutator/subnet1.json new file mode 100644 index 000000000..2fed960b2 --- /dev/null +++ b/tests/data/test_models/test_mutator/subnet1.json @@ -0,0 +1,15 @@ +{ + "op1_(0, 8)_8": { + "init_args":{ + "num_channels":8, + "divisor":1, + "min_value":1, + "min_ratio":0.9, + "candidate_choices":[ + 6 + ], + "choice_mode":"number" + }, + "choice":6 + } +} diff --git a/tests/test_models/test_algorithms/test_dcff_network.py b/tests/test_models/test_algorithms/test_dcff_network.py new file mode 100644 index 000000000..fd108a172 --- /dev/null +++ b/tests/test_models/test_algorithms/test_dcff_network.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
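+# Tests for the DCFF algorithm: prune-step scheduling, iterative pruning,
+# loading pretrained weights and grouped target pruning ratios.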
+import copy +import os +import unittest + +import torch +from mmcls.structures import ClsDataSample +from mmengine import MessageHub +from mmengine.model import BaseModel + +from mmrazor.models.algorithms.pruning.dcff import DCFF +from mmrazor.models.algorithms.pruning.ite_prune_algorithm import \ + ItePruneConfigManager +from mmrazor.registry import MODELS + + +# @TASK_UTILS.register_module() +class ImageClassifierPseudoLoss: + """Calculate the pseudo loss to trace the topology of a `ImageClassifier` + in MMClassification with `BackwardTracer`.""" + + def __call__(self, model) -> torch.Tensor: + pseudo_img = torch.rand(2, 3, 32, 32) + pseudo_output = model(pseudo_img) + return pseudo_output.sum() + + +MODEL_CFG = dict( + _scope_='mmcls', + type='ImageClassifier', + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5), + )) + +MUTATOR_CONFIG_NUM = dict( + type='DCFFChannelMutator', + channel_unit_cfg={ + 'type': 'DCFFChannelUnit', + 'default_args': { + 'choice_mode': 'number' + } + }) +MUTATOR_CONFIG_FLOAT = dict( + type='DCFFChannelMutator', + channel_unit_cfg={ + 'type': 'DCFFChannelUnit', + 'default_args': { + 'choice_mode': 'ratio' + } + }) + +if torch.cuda.is_available(): + DEVICE = torch.device('cuda:0') +else: + DEVICE = torch.device('cpu') + + +class TestDCFFAlgorithm(unittest.TestCase): + + def _set_epoch_ite(self, epoch, ite, max_epoch): + iter_per_epoch = 10 + message_hub = MessageHub.get_current_instance() + message_hub.update_info('epoch', epoch) + message_hub.update_info('max_epochs', max_epoch) + message_hub.update_info('max_iters', max_epoch * iter_per_epoch) + message_hub.update_info('iter', ite + iter_per_epoch * epoch) + + def fake_cifar_data(self): + imgs = torch.randn(16, 3, 32, 32).to(DEVICE) + data_samples = [ + ClsDataSample().set_gt_label(torch.randint(0, 10, + (16, ))).to(DEVICE) + ] + + return {'inputs': imgs, 'data_samples': data_samples} + + def test_ite_prune_config_manager(self): + iter_per_epoch = 10 + float_origin, float_target = 1.0, 0.5 + int_origin, int_target = 10, 5 + for origin, target, manager in [ + (float_origin, float_target, + ItePruneConfigManager({'a': float_target}, {'a': float_origin}, + 2 * iter_per_epoch, 5)), + (int_origin, int_target, + ItePruneConfigManager({'a': int_target}, {'a': int_origin}, + 2 * iter_per_epoch, 5)) + ]: + times = 1 + for e in range(1, 10): + for ite in range(iter_per_epoch): + self._set_epoch_ite(e, ite, 10) + if (e, ite) in [(0, 0), (2, 0), (4, 0), (6, 0), (8, 0)]: + self.assertTrue( + manager.is_prune_time(e * iter_per_epoch + ite)) + times += 1 + self.assertEqual( + manager.prune_at(e * iter_per_epoch + ite)['a'], + origin - (origin - target) * times / 5) + else: + self.assertFalse( + manager.is_prune_time(e * iter_per_epoch + ite)) + + def test_iterative_prune_int(self): + + data = self.fake_cifar_data() + + model = MODELS.build(MODEL_CFG) + mutator = MODELS.build(MUTATOR_CONFIG_FLOAT) + mutator.prepare_from_supernet(model) + mutator.set_choices(mutator.sample_choices()) + prune_target = mutator.choice_template + + iter_per_epoch = 10 + epoch = 10 + epoch_step = 2 + times = 5 + + algorithm = DCFF( + MODEL_CFG, + target_pruning_ratio=prune_target, + mutator_cfg=MUTATOR_CONFIG_FLOAT, + step_freq=epoch_step).to(DEVICE) + + for e in range(epoch): + for ite in range(10): + 
self._set_epoch_ite(e, ite, epoch) + + algorithm.forward( + data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(times, algorithm.prune_times) + self.assertEqual(epoch_step * iter_per_epoch, + algorithm.step_freq) + + current_choices = algorithm.mutator.current_choices + group_prune_target = algorithm.group_target_pruning_ratio( + prune_target, mutator.search_groups) + for key in current_choices: + self.assertAlmostEqual( + current_choices[key], group_prune_target[key], delta=0.1) + + def test_load_pretrained(self): + iter_per_epoch = 10 + epoch_step = 20 + data = self.fake_cifar_data() + + # prepare checkpoint + model_cfg = copy.deepcopy(MODEL_CFG) + model: BaseModel = MODELS.build(model_cfg) + checkpoint_path = os.path.dirname(__file__) + '/checkpoint' + torch.save(model.state_dict(), checkpoint_path) + + # build algorithm + model_cfg['init_cfg'] = { + 'type': 'Pretrained', + 'checkpoint': checkpoint_path + } + algorithm = DCFF( + model_cfg, + mutator_cfg=MUTATOR_CONFIG_FLOAT, + target_pruning_ratio=None, + step_freq=epoch_step).to(DEVICE) + algorithm.init_weights() + self._set_epoch_ite(10, 5, 200) + algorithm.forward(data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch) + + # delete checkpoint + os.remove(checkpoint_path) + + def test_group_target_ratio(self): + + model = MODELS.build(MODEL_CFG) + mutator = MODELS.build(MUTATOR_CONFIG_FLOAT) + mutator.prepare_from_supernet(model) + mutator.set_choices(mutator.sample_choices()) + prune_target = mutator.choice_template + + custom_groups = [[ + 'backbone.layer1.0.conv1_(0, 64)_64', + 'backbone.layer1.1.conv1_(0, 64)_64' + ]] + mutator_cfg = copy.deepcopy(MUTATOR_CONFIG_FLOAT) + mutator_cfg['custom_groups'] = custom_groups + + iter_per_epoch = 10 + epoch_step = 2 + epoch = 6 + data = self.fake_cifar_data() + + prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1 + prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.1 + + algorithm = DCFF( + MODEL_CFG, + target_pruning_ratio=prune_target, + mutator_cfg=mutator_cfg, + step_freq=epoch_step).to(DEVICE) + + algorithm.init_weights() + self._set_epoch_ite(1, 2, epoch) + algorithm.forward(data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch) + + prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1 + prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.2 + + with self.assertRaises(ValueError): + + algorithm = DCFF( + MODEL_CFG, + target_pruning_ratio=prune_target, + mutator_cfg=mutator_cfg, + step_freq=epoch_step).to(DEVICE) + + algorithm.init_weights() + self._set_epoch_ite(1, 2, epoch) + algorithm.forward( + data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch) diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py index 3a00e93b9..b90b9cc62 100644 --- a/tests/test_models/test_algorithms/test_prune_algorithm.py +++ b/tests/test_models/test_algorithms/test_prune_algorithm.py @@ -85,27 +85,31 @@ def fake_cifar_data(self): return {'inputs': imgs, 'data_samples': data_samples} def test_ite_prune_config_manager(self): + iter_per_epoch = 10 float_origin, float_target = 1.0, 0.5 int_origin, int_target = 10, 5 for origin, target, manager in [ (float_origin, float_target, - ItePruneConfigManager({'a': float_target}, {'a': float_origin}, 2, - 5)), + ItePruneConfigManager({'a': float_target}, {'a': 
float_origin}, + 2 * iter_per_epoch, 5)), (int_origin, int_target, - ItePruneConfigManager({'a': int_target}, {'a': int_origin}, 2, 5)) + ItePruneConfigManager({'a': int_target}, {'a': int_origin}, + 2 * iter_per_epoch, 5)) ]: times = 1 - for e in range(1, 20): - for ite in range(1, 5): - self._set_epoch_ite(e, ite, 5) + for e in range(1, 10): + for ite in range(iter_per_epoch): + self._set_epoch_ite(e, ite, 10) if (e, ite) in [(0, 0), (2, 0), (4, 0), (6, 0), (8, 0)]: - self.assertTrue(manager.is_prune_time(e, ite)) + self.assertTrue( + manager.is_prune_time(e * iter_per_epoch + ite)) + times += 1 self.assertEqual( - manager.prune_at(e)['a'], + manager.prune_at(e * iter_per_epoch + ite)['a'], origin - (origin - target) * times / 5) - times += 1 else: - self.assertFalse(manager.is_prune_time(e, ite)) + self.assertFalse( + manager.is_prune_time(e * iter_per_epoch + ite)) def test_iterative_prune_int(self): @@ -117,6 +121,7 @@ def test_iterative_prune_int(self): mutator.set_choices(mutator.sample_choices()) prune_target = mutator.choice_template + iter_per_epoch = 10 epoch = 10 epoch_step = 2 times = 3 @@ -125,15 +130,18 @@ def test_iterative_prune_int(self): MODEL_CFG, target_pruning_ratio=prune_target, mutator_cfg=MUTATOR_CONFIG_FLOAT, - step_epoch=epoch_step, + step_freq=epoch_step, prune_times=times).to(DEVICE) for e in range(epoch): - for ite in range(5): - self._set_epoch_ite(e, ite, 5) + for ite in range(10): + self._set_epoch_ite(e, ite, epoch) algorithm.forward( data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(times, algorithm.prune_times) + self.assertEqual(epoch_step * iter_per_epoch, + algorithm.step_freq) current_choices = algorithm.mutator.current_choices group_prune_target = algorithm.group_target_pruning_ratio( @@ -143,6 +151,7 @@ def test_iterative_prune_int(self): current_choices[key], group_prune_target[key], delta=0.1) def test_load_pretrained(self): + iter_per_epoch = 10 epoch_step = 2 times = 3 data = self.fake_cifar_data() @@ -162,11 +171,13 @@ def test_load_pretrained(self): model_cfg, mutator_cfg=MUTATOR_CONFIG_NUM, target_pruning_ratio=None, - step_epoch=epoch_step, + step_freq=epoch_step, prune_times=times, ).to(DEVICE) algorithm.init_weights() + self._set_epoch_ite(4, 5, 6) algorithm.forward(data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch) # delete checkpoint os.remove(checkpoint_path) @@ -186,27 +197,41 @@ def test_group_target_ratio(self): mutator_cfg = copy.deepcopy(MUTATOR_CONFIG_FLOAT) mutator_cfg['custom_groups'] = custom_groups + iter_per_epoch = 10 epoch_step = 2 - times = 3 + time = 2 + epoch = 6 + data = self.fake_cifar_data() prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1 prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.1 - _ = ItePruneAlgorithm( + algorithm = ItePruneAlgorithm( MODEL_CFG, target_pruning_ratio=prune_target, mutator_cfg=mutator_cfg, - step_epoch=epoch_step, - prune_times=times).to(DEVICE) + step_freq=epoch_step, + prune_times=time).to(DEVICE) + + algorithm.init_weights() + self._set_epoch_ite(1, 2, epoch) + algorithm.forward(data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch) prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1 prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.2 with self.assertRaises(ValueError): - _ = ItePruneAlgorithm( + algorithm = ItePruneAlgorithm( MODEL_CFG, target_pruning_ratio=prune_target, mutator_cfg=mutator_cfg, - 
step_epoch=epoch_step, - prune_times=times).to(DEVICE) + step_freq=epoch_step, + prune_times=time).to(DEVICE) + + algorithm.init_weights() + self._set_epoch_ite(1, 2, epoch) + algorithm.forward( + data['inputs'], data['data_samples'], mode='loss') + self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py index 8eab78af8..d7fc48cbc 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py @@ -9,7 +9,8 @@ from torch import nn from mmrazor.models.architectures.dynamic_ops import (BigNasConv2d, - DynamicConv2d, OFAConv2d) + DynamicConv2d, + FuseConv2d, OFAConv2d) from mmrazor.models.mutables import (OneShotMutableValue, SquentialMutableChannel) from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet @@ -72,12 +73,30 @@ def test_dynamic_conv2d_depthwise(self) -> None: assert torch.equal(out1, out2) +def mock_layeri_choice(d_conv2d: FuseConv2d) -> None: + # mock selected out channel proxy for `FuseConv2d` + c_out, _, _, _ = d_conv2d.weight.size() + print('d_conv2d.mutable_attrs:', d_conv2d.mutable_attrs) + if ('out_channels' in d_conv2d.mutable_attrs): + c_current_out = \ + d_conv2d.mutable_attrs['out_channels'].current_mask.sum().item() + else: + c_current_out = c_out + device = d_conv2d.weight.device + layeri_mock = torch.rand(c_current_out, c_out).to(device) + d_conv2d.set_forward_args(choice=layeri_mock) + + +@pytest.mark.parametrize('dynamic_class', + [BigNasConv2d, DynamicConv2d, FuseConv2d, OFAConv2d]) @pytest.mark.parametrize('bias', [True, False]) -def test_dynamic_conv2d(bias: bool) -> None: - d_conv2d = DynamicConv2d( +def test_dynamic_conv2d(bias: bool, dynamic_class: Type[nn.Conv2d]) -> None: + d_conv2d = dynamic_class( in_channels=4, out_channels=10, kernel_size=3, stride=1, bias=bias) x_max = torch.rand(10, 4, 224, 224) + if (isinstance(d_conv2d, FuseConv2d)): + mock_layeri_choice(d_conv2d) out_before_mutate = d_conv2d(x_max) mutable_in_channels = SquentialMutableChannel(4) @@ -91,6 +110,8 @@ def test_dynamic_conv2d(bias: bool) -> None: d_conv2d.get_mutable_attr('in_channels').current_choice = 4 d_conv2d.mutate_out_channels = 10 + if (isinstance(d_conv2d, FuseConv2d)): + mock_layeri_choice(d_conv2d) out_max = d_conv2d(x_max) assert torch.equal(out_before_mutate, out_max) @@ -98,6 +119,8 @@ def test_dynamic_conv2d(bias: bool) -> None: d_conv2d.mutable_out_channels.current_choice = 4 x = torch.rand(10, 3, 224, 224) + if (isinstance(d_conv2d, FuseConv2d)): + mock_layeri_choice(d_conv2d) out1 = d_conv2d(x) assert out1.size(1) == 4 @@ -116,13 +139,15 @@ def test_dynamic_conv2d(bias: bool) -> None: assert torch.equal(out1, out2) +@pytest.mark.parametrize('dynamic_class', + [BigNasConv2d, DynamicConv2d, FuseConv2d, OFAConv2d]) @pytest.mark.parametrize( ['is_mutate_in_channels', 'in_channels', 'out_channels'], [(True, 6, 10), (False, 10, 4)]) -def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, - in_channels: int, - out_channels: int) -> None: - d_conv2d = DynamicConv2d( +def test_dynamic_conv2d_mutable_single_channels( + is_mutate_in_channels: bool, in_channels: int, out_channels: int, + dynamic_class: Type[nn.Conv2d]) -> None: + d_conv2d = dynamic_class( in_channels=10, out_channels=10, kernel_size=3, stride=1, 
bias=True) mutable_channels = SquentialMutableChannel(10) @@ -131,6 +156,8 @@ def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, else: d_conv2d.register_mutable_attr('out_channels', mutable_channels) + if (isinstance(d_conv2d, FuseConv2d)): + mock_layeri_choice(d_conv2d) with pytest.raises(RuntimeError): d_conv2d.to_static_op() @@ -142,6 +169,8 @@ def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, assert d_conv2d.get_mutable_attr('in_channels') is None x = torch.rand(3, in_channels, 224, 224) + if (isinstance(d_conv2d, FuseConv2d)): + mock_layeri_choice(d_conv2d) out1 = d_conv2d(x) assert out1.size(1) == out_channels @@ -203,6 +232,8 @@ def test_kernel_dynamic_conv2d(dynamic_class: Type[nn.Conv2d], d_conv2d.mutable_attrs['kernel_size'].current_choice = kernel_size x = torch.rand(3, 8, 224, 224) + if (isinstance(d_conv2d, FuseConv2d)): + mock_layeri_choice(d_conv2d) out1 = d_conv2d(x) assert out1.size(1) == 8 @@ -245,6 +276,8 @@ def test_mutable_kernel_dynamic_conv2d_grad( for kernel_size in kernel_size_list: mutable_kernel_size.current_choice = kernel_size + if (isinstance(d_conv2d, FuseConv2d)): + mock_layeri_choice(d_conv2d) out = d_conv2d(x).sum() out.backward() diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_dcff_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_dcff_channel_unit.py new file mode 100644 index 000000000..d0462886c --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_dcff_channel_unit.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List +from unittest import TestCase + +import torch + +from mmrazor.models.architectures.dynamic_ops import FuseConv2d +from mmrazor.models.mutables import DCFFChannelUnit +from mmrazor.structures.graph import ModuleGraph as ModuleGraph +from .....data.models import LineModel + +DEVICE = torch.device('cuda:0') if torch.cuda.is_available() \ + else torch.device('cpu') + + +class TestDCFFChannelUnit(TestCase): + + def test_num(self): + unit = DCFFChannelUnit(48, choice_mode='number') + unit.current_choice = 24 + self.assertEqual(unit.current_choice, 24) + + unit.current_choice = 0.5 + self.assertEqual(unit.current_choice, 24) + + def test_ratio(self): + unit = DCFFChannelUnit(48, choice_mode='ratio') + unit.current_choice = 0.5 + self.assertEqual(unit.current_choice, 0.5) + unit.current_choice = 24 + self.assertEqual(unit.current_choice, 0.5) + + def test_divisor(self): + unit = DCFFChannelUnit(48, choice_mode='number', divisor=8) + unit.current_choice = 20 + self.assertEqual(unit.current_choice, 24) + self.assertTrue(unit.sample_choice() % 8 == 0) + + unit = DCFFChannelUnit(48, choice_mode='ratio', divisor=8) + unit.current_choice = 0.3 + self.assertEqual(unit.current_choice, 1 / 3) + + def test_config_template(self): + unit = DCFFChannelUnit(48, choice_mode='ratio', divisor=8) + config = unit.config_template(with_init_args=True) + unit2 = DCFFChannelUnit.init_from_cfg(None, config) + self.assertDictEqual( + unit2.config_template(with_init_args=True)['init_args'], + config['init_args']) + + def test_init_from_channel_unit(self): + # init using tracer + model = LineModel() + graph = ModuleGraph.init_from_backward_tracer(model) + units: List[DCFFChannelUnit] = DCFFChannelUnit.init_from_graph(graph) + mutable_units = [ + DCFFChannelUnit.init_from_channel_unit(unit) for unit in units + ] + model = model.to(DEVICE) + self._test_units(mutable_units, 
model) + + def _test_units(self, units: List[DCFFChannelUnit], model): + for unit in units: + unit.prepare_for_pruning(model) + mutable_units = [unit for unit in units if unit.is_mutable] + self.assertGreaterEqual(len(mutable_units), 1) + for unit in mutable_units: + choice = unit.sample_choice() + unit.current_choice = choice + for channel in unit.output_related: + if isinstance(channel.module, FuseConv2d): + layeri_softmaxp = channel.module.get_pooled_channel(1.0) + # update fuseconv op's selected layeri_softmax + channel.module.set_forward_args(choice=layeri_softmaxp) + x = torch.rand([2, 3, 224, 224]).to(DEVICE) + y = model(x) + self.assertSequenceEqual(y.shape, [2, 1000]) diff --git a/tests/test_models/test_mutables/test_sequential_mutable_channel.py b/tests/test_models/test_mutables/test_sequential_mutable_channel.py new file mode 100644 index 000000000..f7f4bb91e --- /dev/null +++ b/tests/test_models/test_mutables/test_sequential_mutable_channel.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from mmrazor.models.mutables import SquentialMutableChannel + + +class TestSquentialMutableChannel(TestCase): + + def test_mul_float(self): + channel = SquentialMutableChannel(10) + new_channel = channel * 0.5 + self.assertEqual(new_channel.current_choice, 5) + channel.current_choice = 5 + self.assertEqual(new_channel.current_choice, 2) diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py index b4a702bdf..a22f5c3c8 100644 --- a/tests/test_models/test_mutators/test_channel_mutator.py +++ b/tests/test_models/test_mutators/test_channel_mutator.py @@ -6,6 +6,7 @@ import torch +# from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutables.mutable_channel import ( L1MutableChannelUnit, SequentialMutableChannelUnit) from mmrazor.models.mutators.channel_mutator import ChannelMutator diff --git a/tests/test_models/test_mutators/test_dcff_mutator.py b/tests/test_models/test_mutators/test_dcff_mutator.py new file mode 100644 index 000000000..52e0bfd72 --- /dev/null +++ b/tests/test_models/test_mutators/test_dcff_mutator.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
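+# Toy concat/residual models used to check that ``DCFFChannelMutator`` builds
+# channel units with the backward tracer and updates the fusion matrices of
+# ``FuseConv2d`` via ``calc_information``.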
+import torch +from mmcls.models import * # noqa: F401,F403 +from torch import Tensor, nn +from torch.nn import Module + +from mmrazor.models.mutators import DCFFChannelMutator + + +class MultiConcatModel(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(16, 8, 1) + self.op4 = nn.Conv2d(3, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + cat1 = torch.cat([x1, x2], dim=1) + x3 = self.op3(cat1) + x4 = self.op4(x) + output = torch.cat([x3, x4], dim=1) + + return output + + +class MultiConcatModel2(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(3, 8, 1) + self.op4 = nn.Conv2d(24, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + x3 = self.op3(x) + cat1 = torch.cat([x1, x2], dim=1) + cat2 = torch.cat([cat1, x3], dim=1) + output = self.op4(cat2) + + return output + + +class ConcatModel(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.bn1 = nn.BatchNorm2d(8) + self.op2 = nn.Conv2d(3, 8, 1) + self.bn2 = nn.BatchNorm2d(8) + self.op3 = nn.Conv2d(16, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.bn1(self.op1(x)) + x2 = self.bn2(self.op2(x)) + cat1 = torch.cat([x1, x2], dim=1) + x3 = self.op3(cat1) + + return x3 + + +class ResBlock(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.bn1 = nn.BatchNorm2d(8) + self.op2 = nn.Conv2d(8, 8, 1) + self.bn2 = nn.BatchNorm2d(8) + self.op3 = nn.Conv2d(8, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.bn1(self.op1(x)) + x2 = self.bn2(self.op2(x1)) + x3 = self.op3(x2 + x1) + return x3 + + +def test_DCFF_channel_mutator() -> None: + imgs = torch.randn(16, 3, 224, 224) + + # ResBlock + mutator = DCFFChannelMutator( + channel_unit_cfg=dict(type='DCFFChannelUnit'), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss'))) + + target_pruning_ratio = { + 0: 0.5, + } + + model = ResBlock() + mutator.prepare_from_supernet(model) + mutator.set_choices(target_pruning_ratio) + mutator.calc_information(1.0) + out3 = model(imgs) + + assert out3.shape == (16, 8, 224, 224)
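
For reference, the filter-importance computation performed by `FuseConv2d.get_pooled_channel` (see the `FuseConvMixin` hunk above) can be reproduced with plain PyTorch. The sketch below is illustrative only: the function name `dcff_softmax_pooling` and the toy tensor are made up for this example, but the math mirrors `get_pooled_channel` and returns the (kept_channels, out_channels) fusion matrix that `set_forward_args` later consumes.

import torch


def dcff_softmax_pooling(weight: torch.Tensor, keep: int, tau: float):
    """Standalone sketch of the math in ``FuseConv2d.get_pooled_channel``.

    Args:
        weight: conv weight of shape (c_out, c_in, k, k).
        keep: number of filters to reserve (activated out channels).
        tau: softmax temperature.

    Returns:
        Fusion matrix of shape (keep, c_out) applied to the flattened weight.
    """
    c_out = weight.shape[0]
    flat = weight.detach().reshape(c_out, -1)               # (c_out, d)
    neg_dist = -torch.cdist(flat, flat, p=2)                # (c_out, c_out)
    softmaxp = torch.softmax(neg_dist / tau, dim=1)         # row-stochastic

    eps = 1e-7                                              # avoid log(0)
    p = softmaxp[:, None, :]                                # (c_out, 1, c_out)
    log_p = p * (p + eps).log()
    log_q = p * torch.mean((softmaxp + eps).log(), dim=1)   # broadcast on rows
    kl = torch.mean(log_p - log_q, dim=2)                   # (c_out, 1)
    importance = kl.sum(dim=1)                              # (c_out,)

    _, top_ids = torch.topk(importance, keep, sorted=False)
    return softmaxp[top_ids, :]


fusion = dcff_softmax_pooling(torch.randn(8, 3, 3, 3), keep=4, tau=1.0)
assert fusion.shape == (4, 8)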