From b4b7e2432a37d23ea83a73ea618a4faad73436be Mon Sep 17 00:00:00 2001 From: LKJacky <108643365+LKJacky@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:30:25 +0800 Subject: [PATCH] merge pruning into dev-1.x (#312) * add ChannelGroup (#250) * rebase new dev-1.x * modification for adding config_template * add docstring to channel_group.py * add docstring to mutable_channel_group.py * rm channel_group_cfg from Graph2ChannelGroups * change choice type of SequentialChannelGroup from float to int * add a warning about group-wise conv * restore __init__ of dynamic op * in_channel_mutable -> mutable_in_channel * rm abstractproperty * add a comment about VT * rm registry for ChannelGroup * MUTABLECHANNELGROUP -> ChannelGroupType * refine docstring of IndexDict * update docstring * update docstring * is_prunable -> is_mutable * update docstring * fix error in pre-commit * update unittest * add return type * unify init_xxx apit * add unitest about init of MutableChannelGroup * update according to reviews * sequential_channel_group -> sequential_mutable_channel_group Co-authored-by: liukai * Add BaseChannelMutator and refactor Autoslim (#289) * add BaseChannelMutator * add autoslim * tmp * make SequentialMutableChannelGroup accpeted both of num and ratio as choice. and supports divisior * update OneShotMutableChannelGroup * pass supernet training of autoslim * refine autoslim * fix bug in OneShotMutableChannelGroup * refactor make_divisible * fix spell error: channl -> channel * init_using_backward_tracer -> init_from_backward_tracer init_from_fx_tracer -> init_from_fx_tracer * refine SequentialMutableChannelGroup * let mutator support models with dynamicop * support define search space in model * tracer_cfg -> parse_cfg * refine * using -> from * update docstring * update docstring Co-authored-by: liukai * refactor slimmable and add l1-norm (#291) * refactor slimmable and add l1-norm * make l1-norm support convnd * update get_channel_groups * add l1-norm_resnet34_8xb32_in1k.py * add pretrained to resnet34-l1 * remove old channel mutator * BaseChannelMutator -> ChannelMutator * update according to reviews * add readme to l1-norm * MBV2_slimmable -> MBV2_slimmable_config Co-authored-by: liukai * Clean old codes. 
(#296) * remove old dynamic ops * move dynamic ops * clean old mutable_channels * rm OneShotMutableChannel * rm MutableChannel * refine * refine * use SquentialMutableChannel to replace OneshotMutableChannel * refactor dynamicops folder * let SquentialMutableChannel support float Co-authored-by: liukai * Add channel-flow (#301) * base_channel_mutator -> channel_mutator * init * update docstring * allow omitting redundant configs for channel * add register_mutable_channel_to_a_module to MutableChannelContainer * update according to reviews 1 * update according to reviews 2 * update according to reviews 3 * remove old docstring * fix error * using->from * update according to reviews * support self-define input channel number * update docstring * chanenl -> channel_elem Co-authored-by: liukai Co-authored-by: jacky * Rename: ChannelGroup -> ChannelUnit (#302) * refine repr of MutableChannelGroup * rename folder name * ChannelGroup -> ChannelUnit * filename in units folder * channel_group -> channel_unit * groups -> units * group -> unit * update * get_mutable_channel_groups -> get_mutable_channel_units * fix bug * refine docstring * fix ci * fix bug in tracer Co-authored-by: liukai * Merge dev-1.x to pruning (#311) * [feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281) * init * TD: CRDLoss * complete UT * fix docstrings * fix ci * update * fix CI * DONE * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * add UT: CRD_ClsDataset * init * TODO: UT test formatting. * init * crd dataset wrapper * update docstring Co-authored-by: huangpengsheng * [Improvement] Update estimator with api revision (#277) * update estimator usage and fix bugs * refactor api of estimator & add inner check methods * fix docstrings * update search loop and config * fix lint * update unittest * decouple mmdet dependency and fix lint Co-authored-by: humu789 * [Fix] Fix tracer (#273) * test image_classifier_loss_calculator * fix backward tracer * update SingleStageDetectorPseudoLoss * merge * [Feature] Add Dsnas Algorithm (#226) * [tmp] Update Dsnas * [tmp] refactor arch_loss & flops_loss * Update Dsnas & MMRAZOR_EVALUATOR: 1. finalized compute_loss & handle_grads in algorithm; 2. add MMRAZOR_EVALUATOR; 3. fix bugs. * Update lr scheduler & fix a bug: 1. update param_scheduler & lr_scheduler for dsnas; 2. fix a bug of switching to finetune stage. 
* remove old evaluators * remove old evaluators * update param_scheduler config * merge dev-1.x into gy/estimator * add flops_loss in Dsnas using ResourcesEstimator * get resources before mutator.prepare_from_supernet * delete unness broadcast api from gml * broadcast spec_modules_resources when estimating * update early fix mechanism for Dsnas * fix merge * update units in estimator * minor change * fix data_preprocessor api * add flops_loss_coef * remove DsnasOptimWrapper * fix bn eps and data_preprocessor * fix bn weight decay bug * add betas for mutator optimizer * set diff_rank_seed=True for dsnas * fix start_factor of lr when warm up * remove .module in non-ddp mode * add GlobalAveragePoolingWithDropout * add UT for dsnas * remove unness channel adjustment for shufflenetv2 * update supernet configs * delete unness dropout * delete unness part with minor change on dsnas * minor change on the flag of search stage * update README and subnet configs * add UT for OneHotMutableOP * [Feature] Update train (#279) * support auto resume * add enable auto_scale_lr in train.py * support '--amp' option * [Fix] Fix darts metafile (#278) fix darts metafile * fix ci (#284) * fix ci for circle ci * fix bug in test_metafiles * add pr_stage_test for github ci * add multiple version * fix ut * fix lint * Temporarily skip dataset UT * update github ci * add github lint ci * install wheel * remove timm from requirements * install wheel when test on windows * fix error * fix bug * remove github windows ci * fix device error of arch_params when DsnasDDP * fix CRD dataset ut * fix scope error * rm test_cuda in workflows of github * [Doc] fix typos in en/usr_guides Co-authored-by: liukai Co-authored-by: pppppM Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: huangpengsheng Co-authored-by: SheffieldCao <1751899@tongji.edu.cn> Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: huangpengsheng Co-authored-by: Yang Gao Co-authored-by: humu789 Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: pppppM Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: SheffieldCao <1751899@tongji.edu.cn> * Refine pruning branch (#307) * [feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281) * init * TD: CRDLoss * complete UT * fix docstrings * fix ci * update * fix CI * DONE * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * add UT: CRD_ClsDataset * init * TODO: UT test formatting. * init * crd dataset wrapper * update docstring Co-authored-by: huangpengsheng * [Improvement] Update estimator with api revision (#277) * update estimator usage and fix bugs * refactor api of estimator & add inner check methods * fix docstrings * update search loop and config * fix lint * update unittest * decouple mmdet dependency and fix lint Co-authored-by: humu789 * [Fix] Fix tracer (#273) * test image_classifier_loss_calculator * fix backward tracer * update SingleStageDetectorPseudoLoss * merge * [Feature] Add Dsnas Algorithm (#226) * [tmp] Update Dsnas * [tmp] refactor arch_loss & flops_loss * Update Dsnas & MMRAZOR_EVALUATOR: 1. finalized compute_loss & handle_grads in algorithm; 2. add MMRAZOR_EVALUATOR; 3. fix bugs. * Update lr scheduler & fix a bug: 1. update param_scheduler & lr_scheduler for dsnas; 2. fix a bug of switching to finetune stage. 
* remove old evaluators * remove old evaluators * update param_scheduler config * merge dev-1.x into gy/estimator * add flops_loss in Dsnas using ResourcesEstimator * get resources before mutator.prepare_from_supernet * delete unness broadcast api from gml * broadcast spec_modules_resources when estimating * update early fix mechanism for Dsnas * fix merge * update units in estimator * minor change * fix data_preprocessor api * add flops_loss_coef * remove DsnasOptimWrapper * fix bn eps and data_preprocessor * fix bn weight decay bug * add betas for mutator optimizer * set diff_rank_seed=True for dsnas * fix start_factor of lr when warm up * remove .module in non-ddp mode * add GlobalAveragePoolingWithDropout * add UT for dsnas * remove unness channel adjustment for shufflenetv2 * update supernet configs * delete unness dropout * delete unness part with minor change on dsnas * minor change on the flag of search stage * update README and subnet configs * add UT for OneHotMutableOP * [Feature] Update train (#279) * support auto resume * add enable auto_scale_lr in train.py * support '--amp' option * [Fix] Fix darts metafile (#278) fix darts metafile * fix ci (#284) * fix ci for circle ci * fix bug in test_metafiles * add pr_stage_test for github ci * add multiple version * fix ut * fix lint * Temporarily skip dataset UT * update github ci * add github lint ci * install wheel * remove timm from requirements * install wheel when test on windows * fix error * fix bug * remove github windows ci * fix device error of arch_params when DsnasDDP * fix CRD dataset ut * fix scope error * rm test_cuda in workflows of github * [Doc] fix typos in en/usr_guides Co-authored-by: liukai Co-authored-by: pppppM Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: huangpengsheng Co-authored-by: SheffieldCao <1751899@tongji.edu.cn> * fix bug when python=3.6 * fix lint * fix bug when test using cpu only * refine ci * fix error in ci * try ci * update repr of Channel * fix error * mv init_from_predefined_model to MutableChannelUnit * move tests * update SquentialMutableChannel * update l1 mutable channel unit * add OneShotMutableChannel * candidate_mode -> choice_mode * update docstring * change ci Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: huangpengsheng Co-authored-by: Yang Gao Co-authored-by: humu789 Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: pppppM Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: SheffieldCao <1751899@tongji.edu.cn> Co-authored-by: liukai Co-authored-by: jacky Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: huangpengsheng Co-authored-by: Yang Gao Co-authored-by: humu789 Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: pppppM Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: SheffieldCao <1751899@tongji.edu.cn> --- ..._mbv2_1.5x_slimmable_subnet_8xb256_in1k.py | 14 +- ...mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py | 3 +- ...mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py | 3 +- ...mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py | 3 +- ...autoslim_mbv2_1.5x_supernet_8xb256_in1k.py | 12 +- configs/pruning/mmcls/l1-norm/README.md | 11 + .../l1-norm/l1-norm_resnet34_8xb32_in1k.py | 56 ++ mmrazor/engine/runner/slimmable_val_loop.py | 4 +- mmrazor/models/algorithms/__init__.py | 23 +- mmrazor/models/algorithms/nas/autoslim.py | 44 +- .../algorithms/pruning/ite_prune_algorithm.py | 174 ++++++ 
.../algorithms/pruning/slimmable_network.py | 105 ++-- .../architectures/dynamic_ops/__init__.py | 19 +- .../models/architectures/dynamic_ops/base.py | 106 ---- .../dynamic_ops/bricks/__init__.py | 13 - .../dynamic_ops/bricks/dynamic_conv.py | 42 +- .../dynamic_ops/bricks/dynamic_linear.py | 2 +- .../dynamic_ops/bricks/dynamic_norm.py | 66 ++- .../dynamic_ops/default_dynamic_ops.py | 333 ------------ .../dynamic_ops/mixins/__init__.py | 9 + .../{bricks => mixins}/dynamic_conv_mixins.py | 0 .../{bricks => mixins}/dynamic_mixins.py | 0 .../dynamic_ops/slimmable_dynamic_ops.py | 83 --- mmrazor/models/mutables/__init__.py | 21 +- mmrazor/models/mutables/derived_mutable.py | 30 +- .../mutables/mutable_channel/__init__.py | 17 +- .../mutable_channel/base_mutable_channel.py | 92 ++++ .../mutable_channel/mutable_channel.py | 114 ---- .../mutable_channel_container.py | 122 +++++ .../one_shot_mutable_channel.py | 214 -------- .../oneshot_mutalbe_channel.py | 41 ++ .../sequential_mutable_channel.py | 138 +++++ .../mutable_channel/simple_mutable_channel.py | 51 ++ .../slimmable_mutable_channel.py | 96 ---- .../mutable_channel/units/__init__.py | 13 + .../mutable_channel/units/channel_unit.py | 287 ++++++++++ .../units/l1_mutable_channel_unit.py | 82 +++ .../units/mutable_channel_unit.py | 300 +++++++++++ .../units/one_shot_mutable_channel_unit.py | 135 +++++ .../units/sequential_mutable_channel_unit.py | 148 ++++++ .../units/slimmable_channel_unit.py | 59 ++ .../mutables/mutable_value/mutable_value.py | 8 +- .../mutators/channel_mutator/__init__.py | 2 +- .../channel_mutator/channel_mutator.py | 502 ++++++++++-------- .../one_shot_channel_mutator.py | 158 ++---- .../slimmable_channel_mutator.py | 180 ++----- mmrazor/models/mutators/utils/__init__.py | 16 - .../utils/default_module_converters.py | 126 ----- .../mutators/utils/slimmable_bn_converter.py | 23 - mmrazor/structures/graph/channel_graph.py | 71 +++ mmrazor/structures/graph/channel_modules.py | 372 +++++++++++++ mmrazor/structures/graph/channel_nodes.py | 378 +++++++++++++ mmrazor/structures/graph/module_graph.py | 47 +- mmrazor/structures/subnet/fix_subnet.py | 4 +- mmrazor/utils/__init__.py | 4 +- mmrazor/utils/index_dict.py | 61 +++ tests/__init__.py | 3 - tests/data/MBV2_slimmable_config.json | 392 ++++++++++++++ .../head => tests/data}/__init__.py | 0 tests/data/models.py | 281 +++++++++- tests/test_core/__init__.py | 1 + tests/test_core/test_graph/__init__.py | 1 + .../test_graph/test_channel_graph.py | 178 +++++++ tests/test_core/test_graph/test_graph.py | 132 +++-- tests/test_models/__init__.py | 1 + tests/test_models/test_algorithms/__init__.py | 1 + .../test_algorithms/test_autoslim.py | 45 +- .../test_algorithms/test_ofd_algo.py | 2 +- .../test_algorithms/test_prune_algorithm.py | 169 ++++++ .../test_single_teacher_distill.py | 2 +- .../test_algorithms/test_slimmable_network.py | 106 ++-- .../test_bricks/test_dynamic_conv.py | 31 +- .../test_bricks/test_dynamic_linear.py | 16 +- .../test_bricks/test_dynamic_norm.py | 14 +- .../test_default_dynamic_op.py | 93 ---- .../test_dynamic_op/utils.py | 5 +- tests/test_models/test_mutables/__init__.py | 1 + .../test_mutables/test_channel_mutable.py | 129 ----- .../test_mutables/test_derived_mutable.py | 50 +- .../test_mutables/test_dynamic_layer.py | 143 ----- .../test_mutable_channel/__init__.py | 1 + .../test_mutable_channels.py | 35 ++ .../test_sequential_mutable_channel.py | 43 ++ .../test_units/__init__.py | 1 + .../test_l1_mutable_channel_unit.py | 32 ++ 
.../test_units/test_mutable_channel_units.py | 147 +++++ .../test_one_shot_mutable_channel_unit.py | 16 + .../test_sequential_mutable_channel_unit.py | 41 ++ .../test_mutables/test_mutable_value.py | 7 +- .../test_mutators/test_channel_mutator.py | 404 +++++--------- .../test_mbv2_channel_mutator.py | 111 ---- .../test_mutators/test_one_shot_mutator.py | 121 ----- tests/test_models/test_mutators/utils.py | 15 - tests/test_utils/test_index_dict.py | 16 + tools/get_channel_units.py | 67 +++ 95 files changed, 4992 insertions(+), 2898 deletions(-) create mode 100644 configs/pruning/mmcls/l1-norm/README.md create mode 100644 configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py create mode 100644 mmrazor/models/algorithms/pruning/ite_prune_algorithm.py delete mode 100644 mmrazor/models/architectures/dynamic_ops/base.py delete mode 100644 mmrazor/models/architectures/dynamic_ops/default_dynamic_ops.py create mode 100644 mmrazor/models/architectures/dynamic_ops/mixins/__init__.py rename mmrazor/models/architectures/dynamic_ops/{bricks => mixins}/dynamic_conv_mixins.py (100%) rename mmrazor/models/architectures/dynamic_ops/{bricks => mixins}/dynamic_mixins.py (100%) delete mode 100644 mmrazor/models/architectures/dynamic_ops/slimmable_dynamic_ops.py create mode 100644 mmrazor/models/mutables/mutable_channel/base_mutable_channel.py delete mode 100644 mmrazor/models/mutables/mutable_channel/mutable_channel.py create mode 100644 mmrazor/models/mutables/mutable_channel/mutable_channel_container.py delete mode 100644 mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py create mode 100644 mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py create mode 100644 mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py create mode 100644 mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py delete mode 100644 mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/__init__.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/channel_unit.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py delete mode 100644 mmrazor/models/mutators/utils/__init__.py delete mode 100644 mmrazor/models/mutators/utils/default_module_converters.py delete mode 100644 mmrazor/models/mutators/utils/slimmable_bn_converter.py create mode 100644 mmrazor/structures/graph/channel_graph.py create mode 100644 mmrazor/structures/graph/channel_modules.py create mode 100644 mmrazor/structures/graph/channel_nodes.py create mode 100644 mmrazor/utils/index_dict.py create mode 100644 tests/data/MBV2_slimmable_config.json rename {mmrazor/models/architectures/dynamic_ops/head => tests/data}/__init__.py (100%) create mode 100644 tests/test_core/__init__.py create mode 100644 tests/test_core/test_graph/__init__.py create mode 100644 tests/test_core/test_graph/test_channel_graph.py create mode 100644 tests/test_models/__init__.py create mode 100644 tests/test_models/test_algorithms/__init__.py create mode 100644 tests/test_models/test_algorithms/test_prune_algorithm.py delete mode 100644 
tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py create mode 100644 tests/test_models/test_mutables/__init__.py delete mode 100644 tests/test_models/test_mutables/test_channel_mutable.py delete mode 100644 tests/test_models/test_mutables/test_dynamic_layer.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/__init__.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py create mode 100644 tests/test_models/test_mutables/test_mutable_channel/test_units/test_sequential_mutable_channel_unit.py delete mode 100644 tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py delete mode 100644 tests/test_models/test_mutators/test_one_shot_mutator.py delete mode 100644 tests/test_models/test_mutators/utils.py create mode 100644 tests/test_utils/test_index_dict.py create mode 100644 tools/get_channel_units.py diff --git a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py index bd2a415f0..4bcd0e0af 100644 --- a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py +++ b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py @@ -21,11 +21,6 @@ # !autoslim algorithm config # ========================================================================== -channel_cfg_paths = [ - 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-220M_acc-71.4_20220715-9c288f3b_subnet_cfg.yaml', # noqa: E501 - 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-320M_acc-72.73_20220715-9aa8f8ae_subnet_cfg.yaml', # noqa: E501 - 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-530M_acc-74.23_20220715-aa8754fe_subnet_cfg.yaml' # noqa: E501 -] model = dict( _delete_=True, @@ -33,11 +28,12 @@ type='SlimmableNetwork', architecture=supernet, data_preprocessor=data_preprocessor, - channel_cfg_paths=channel_cfg_paths, mutator=dict( type='SlimmableChannelMutator', - mutable_cfg=dict(type='SlimmableMutableChannel'), - tracer_cfg=dict( + channel_unit_cfg=dict( + type='SlimmableChannelUnit', + units='tests/data/MBV2_slimmable_config.json'), + parse_cfg=dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss')))) @@ -46,6 +42,6 @@ broadcast_buffers=False, find_unused_parameters=True) -optim_wrapper = dict(accumulative_counts=len(channel_cfg_paths)) +optim_wrapper = dict(accumulative_counts=3) val_cfg = dict(type='mmrazor.SlimmableValLoop') diff --git a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py index 221907c40..2ed71daf5 100644 --- a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py +++ 
b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py @@ -1,4 +1,3 @@ _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py' -_channel_cfg_paths = 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-220M_acc-71.4_20220715-9c288f3b_subnet_cfg.yaml' # noqa: E501 -model = dict(channel_cfg_paths=_channel_cfg_paths) +model = dict(deploy_index=0) diff --git a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py index b9e215196..e53aae1bc 100644 --- a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py +++ b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py @@ -1,4 +1,3 @@ _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py' -_channel_cfg_paths = 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-320M_acc-72.73_20220715-9aa8f8ae_subnet_cfg.yaml' # noqa: E501 -model = dict(channel_cfg_paths=_channel_cfg_paths) +model = dict(deploy_index=1) diff --git a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py index 964f69198..218a9b036 100644 --- a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py +++ b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py @@ -1,4 +1,3 @@ _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py' -_channel_cfg_paths = 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-530M_acc-74.23_20220715-aa8754fe_subnet_cfg.yaml' # noqa: E501 -model = dict(channel_cfg_paths=_channel_cfg_paths) +model = dict(deploy_index=2) diff --git a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py index dfcefab56..6249b0160 100644 --- a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py +++ b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py @@ -43,11 +43,13 @@ preds_T=dict(recorder='fc', from_student=False)))), mutator=dict( type='OneShotChannelMutator', - mutable_cfg=dict( - type='OneShotMutableChannel', - candidate_choices=list(i / 12 for i in range(2, 13)), - candidate_mode='ratio'), - tracer_cfg=dict( + channel_unit_cfg=dict( + type='OneShotMutableChannelUnit', + default_args=dict( + candidate_choices=list(i / 12 for i in range(2, 13)), + choice_mode='ratio', + divisor=8)), + parse_cfg=dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss')))) diff --git a/configs/pruning/mmcls/l1-norm/README.md b/configs/pruning/mmcls/l1-norm/README.md new file mode 100644 index 000000000..e19726bf5 --- /dev/null +++ b/configs/pruning/mmcls/l1-norm/README.md @@ -0,0 +1,11 @@ +# L1-norm pruning + +> [Pruning Filters for Efficient ConvNets.](https://arxiv.org/pdf/1608.08710.pdf) + + + +## Implementation + +L1-norm pruning is a classical filter pruning algorithm. It prunes filters (channels) according to the l1-norm of the weights of a conv layer. + +We use ItePruneAlgorithm and L1MutableChannelUnit to implement l1-norm pruning. Please refer to xxxx for more configuration details.
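For intuition, the ranking criterion can be sketched in a few lines of plain PyTorch. This is only an illustration of the idea, not mmrazor's actual L1MutableChannelUnit; the helper name `l1_filter_mask` is made up for this sketch.

```python
import torch
import torch.nn as nn


def l1_filter_mask(conv: nn.Conv2d, prune_ratio: float) -> torch.Tensor:
    """Boolean mask over output channels that keeps the filters with the
    largest L1 norms; the lowest `prune_ratio` fraction is marked for removal."""
    # Per-filter L1 norm: sum of |w| over (in_channels, kH, kW).
    importance = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    num_prune = int(conv.out_channels * prune_ratio)
    keep = torch.ones(conv.out_channels, dtype=torch.bool)
    if num_prune > 0:
        # Indices of the `num_prune` least important filters.
        keep[importance.argsort()[:num_prune]] = False
    return keep


conv = nn.Conv2d(3, 64, kernel_size=3)
mask = l1_filter_mask(conv, prune_ratio=0.3)
print(int(mask.sum()), 'of', conv.out_channels, 'filters kept')
```

A real pruning pass also has to keep the chosen channels consistent across layers that share them, which is what the ChannelUnit/ChannelMutator machinery introduced in this PR is responsible for.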
diff --git a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py new file mode 100644 index 000000000..89ef4138f --- /dev/null +++ b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py @@ -0,0 +1,56 @@ +_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py'] + +stage_ratio_1 = 0.7 +stage_ratio_2 = 0.7 +stage_ratio_3 = 0.7 +stage_ratio_4 = 1.0 + +# the config template of target_pruning_ratio can be got by +# python ./tools/get_channel_units.py {config_file} --choice +target_pruning_ratio = { + 'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4 +} +data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'} +architecture = _base_.model +architecture.update({ + 'init_cfg': { + 'type': + 'Pretrained', + 'checkpoint': + 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa + } +}) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='ItePruneAlgorithm', + architecture=architecture, + mutator_cfg=dict( + type='ChannelMutator', + channel_unit_cfg=dict( + type='L1MutableChannelUnit', + default_args=dict(choice_mode='ratio'))), + target_pruning_ratio=target_pruning_ratio, + step_epoch=1, + prune_times=1, +) diff --git a/mmrazor/engine/runner/slimmable_val_loop.py b/mmrazor/engine/runner/slimmable_val_loop.py index f830ffa84..d3f5e2a4e 100644 --- a/mmrazor/engine/runner/slimmable_val_loop.py +++ b/mmrazor/engine/runner/slimmable_val_loop.py @@ -43,10 +43,10 @@ def run(self): self.runner.call_hook('before_val') all_metrics = dict() - for subnet_idx in range(self._model.num_subnet): + for subnet_idx, subnet in enumerate(self._model.mutator.subnets): self.runner.call_hook('before_val_epoch') self.runner.model.eval() - self._model.mutator.switch_choices(subnet_idx) + self._model.mutator.set_choices(subnet) for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) # compute student metrics diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 2d96f3a96..e6258b012 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -5,11 +5,24 @@ SelfDistill, SingleTeacherDistill) from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP from .pruning import SlimmableNetwork, SlimmableNetworkDDP +from .pruning.ite_prune_algorithm import ItePruneAlgorithm __all__ = [ - 'SingleTeacherDistill', 
'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', - 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP', - 'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation', - 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas', - 'DsnasDDP' + 'SingleTeacherDistill', + 'BaseAlgorithm', + 'FpnTeacherDistill', + 'SPOS', + 'SlimmableNetwork', + 'SlimmableNetworkDDP', + 'AutoSlim', + 'AutoSlimDDP', + 'Darts', + 'DartsDDP', + 'SelfDistill', + 'DataFreeDistillation', + 'DAFLDataFreeDistillation', + 'OverhaulFeatureDistillation', + 'ItePruneAlgorithm', + 'Dsnas', + 'DsnasDDP', ] diff --git a/mmrazor/models/algorithms/nas/autoslim.py b/mmrazor/models/algorithms/nas/autoslim.py index 398574a36..a9e5dcdd5 100644 --- a/mmrazor/models/algorithms/nas/autoslim.py +++ b/mmrazor/models/algorithms/nas/autoslim.py @@ -14,7 +14,6 @@ from mmrazor.models.utils import (add_prefix, reinitialize_optim_wrapper_count_status) from mmrazor.registry import MODEL_WRAPPERS, MODELS -from mmrazor.utils import SingleMutatorRandomSubnet from ..base import BaseAlgorithm VALID_MUTATOR_TYPE = Union[OneShotChannelMutator, Dict] @@ -33,10 +32,24 @@ def __init__(self, data_preprocessor: Optional[Union[Dict, nn.Module]] = None, init_cfg: Optional[Dict] = None, num_samples: int = 2) -> None: + """Implementation of the AutoSlim algorithm. Please refer to + https://arxiv.org/abs/1903.11728 for more details. + + Args: + mutator (VALID_MUTATOR_TYPE): config of mutator. + distiller (VALID_DISTILLER_TYPE): config of distiller. + architecture (Union[BaseModel, Dict]): the model to be searched. + data_preprocessor (Optional[Union[Dict, nn.Module]], optional): + data preprocessor. Defaults to None. + init_cfg (Optional[Dict], optional): config of initialization. + Defaults to None. + num_samples (int, optional): number of sampled subnets. + Defaults to 2.
+ """ super().__init__(architecture, data_preprocessor, init_cfg) - self.mutator = self._build_mutator(mutator) - # `prepare_from_supernet` must be called before distiller initialized + self.mutator: OneShotChannelMutator = MODELS.build(mutator) + # prepare_from_supernet` must be called before distiller initialized self.mutator.prepare_from_supernet(self.architecture) self.distiller = self._build_distiller(distiller) @@ -49,7 +62,7 @@ def __init__(self, def _build_mutator(self, mutator: VALID_MUTATOR_TYPE) -> OneShotChannelMutator: - """build mutator.""" + """Build mutator.""" if isinstance(mutator, dict): mutator = MODELS.build(mutator) if not isinstance(mutator, OneShotChannelMutator): @@ -61,6 +74,7 @@ def _build_mutator(self, def _build_distiller( self, distiller: VALID_DISTILLER_TYPE) -> ConfigurableDistiller: + """Build distiller.""" if isinstance(distiller, dict): distiller = MODELS.build(distiller) if not isinstance(distiller, ConfigurableDistiller): @@ -70,20 +84,25 @@ def _build_distiller( return distiller - def sample_subnet(self) -> SingleMutatorRandomSubnet: + def sample_subnet(self) -> Dict: + """Sample a subnet.""" return self.mutator.sample_choices() - def set_subnet(self, subnet: SingleMutatorRandomSubnet) -> None: + def set_subnet(self, subnet) -> None: + """Set a subnet.""" self.mutator.set_choices(subnet) def set_max_subnet(self) -> None: - self.mutator.set_max_choices() + """Set max subnet.""" + self.mutator.set_choices(self.mutator.max_choices()) def set_min_subnet(self) -> None: - return self.mutator.set_min_choices() + """Set min subnet.""" + self.mutator.set_choices(self.mutator.min_choices()) def train_step(self, data: List[dict], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """Train step.""" def distill_step( batch_inputs: torch.Tensor, data_samples: List[BaseDataElement] @@ -109,7 +128,9 @@ def distill_step( accumulative_counts=self.num_samples + 2) self._optim_wrapper_count_status_reinitialized = True - batch_inputs, data_samples = self.data_preprocessor(data, True) + input_data = self.data_preprocessor(data, True) + batch_inputs = input_data['inputs'] + data_samples = input_data['data_samples'] total_losses = dict() self.set_max_subnet() @@ -136,6 +157,7 @@ def distill_step( @MODEL_WRAPPERS.register_module() class AutoSlimDDP(MMDistributedDataParallel): + """DDPwapper for autoslim.""" def __init__(self, *, @@ -175,7 +197,9 @@ def distill_step( accumulative_counts=self.module.num_samples + 2) self._optim_wrapper_count_status_reinitialized = True - batch_inputs, data_samples = self.module.data_preprocessor(data, True) + input_data = self.module.data_preprocessor(data, True) + batch_inputs = input_data['inputs'] + data_samples = input_data['data_samples'] total_losses = dict() self.module.set_max_subnet() diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py new file mode 100644 index 000000000..cca03a71f --- /dev/null +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -0,0 +1,174 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine import MessageHub, MMLogger +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement + +from mmrazor.models.mutators import ChannelMutator +from mmrazor.registry import MODELS +from ..base import BaseAlgorithm + +LossResults = Dict[str, torch.Tensor] +TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] +PredictResults = List[BaseDataElement] +ForwardResults = Union[LossResults, TensorResults, PredictResults] + + +class ItePruneConfigManager: + """ItePruneConfigManager manages the config of the structure of the model + during pruning. + + Args: + target (Dict[str, Union[int, float]]): The target structure to prune. + supernet (Dict[str, Union[int, float]]): The structure of the + supernet. + epoch_step (int, optional): The number of epochs between two pruning + steps. Defaults to 1. + times (int, optional): The total number of pruning steps. Defaults to 1. + """ + + def __init__(self, + target: Dict[str, Union[int, float]], + supernet: Dict[str, Union[int, float]], + epoch_step=1, + times=1) -> None: + + self.supernet = supernet + self.target = target + self.epoch_step = epoch_step + self.prune_times = times + + self.delta: Dict = self._get_delta_each_epoch(self.target, + self.supernet, + self.prune_times) + + def is_prune_time(self, epoch, ite): + """Whether it is time to prune at the current epoch and iteration.""" + return epoch % self.epoch_step == 0 \ + and epoch//self.epoch_step < self.prune_times \ + and ite == 0 + + def prune_at(self, epoch): + """Get the pruning structure at a given epoch.""" + times = epoch // self.epoch_step + 1 + assert times <= self.prune_times + prune_current = {} + ratio = times / self.prune_times + + for key in self.target: + prune_current[key] = (self.target[key] - self.supernet[key] + ) * ratio + self.supernet[key] + if isinstance(self.supernet[key], int): + prune_current[key] = int(prune_current[key]) + return prune_current + + def _get_delta_each_epoch(self, target: Dict, supernet: Dict, times: int): + """Get the structure change of a single pruning step.""" + delta = {} + for key in target: + one_target = target[key] + if isinstance(one_target, float): + delta[key] = (1.0 - one_target) / times + elif isinstance(one_target, int): + delta[key] = int((supernet[key] - one_target) / times) + else: + raise NotImplementedError() + return delta + + +@MODELS.register_module() +class ItePruneAlgorithm(BaseAlgorithm): + """ItePruneAlgorithm prunes a model iteratively until it reaches the + pruning target. + + Args: + architecture (Union[BaseModel, Dict]): The model to be pruned. + mutator_cfg (Union[Dict, ChannelMutator], optional): The config + of a mutator. Defaults to dict( type='ChannelMutator', + channel_unit_cfg=dict( type='SequentialMutableChannelUnit')). + data_preprocessor (Optional[Union[Dict, nn.Module]], optional): + Defaults to None. + target_pruning_ratio (dict, optional): The pruning target. Its template + can be obtained by calling + mutator.choice_template(). Defaults to {}. + step_epoch (int, optional): The step between two pruning operations. + Defaults to 1. + prune_times (int, optional): The number of times the model is pruned. + Defaults to 1. + init_cfg (Optional[Dict], optional): init config for architecture. + Defaults to None.
+ """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator_cfg: Union[Dict, ChannelMutator] = dict( + type='ChannelMutator', + channel_unit_cfg=dict( + type='SequentialMutableChannelUnit')), + data_preprocessor: Optional[Union[Dict, nn.Module]] = None, + target_pruning_ratio={}, + step_epoch=1, + prune_times=1, + init_cfg: Optional[Dict] = None) -> None: + + super().__init__(architecture, data_preprocessor, init_cfg) + + # mutator + self.mutator: ChannelMutator = MODELS.build(mutator_cfg) + self.mutator.prepare_from_supernet(self.architecture) + + # config_manager + self.check_prune_targe(target_pruning_ratio) + self.prune_config_manager = ItePruneConfigManager( + target_pruning_ratio, + self.mutator.choice_template, + step_epoch, + times=prune_times) + + def check_prune_targe(self, config: Dict): + """Check if the prune-target is supported.""" + for value in config.values(): + assert isinstance(value, int) or isinstance(value, float) + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor') -> ForwardResults: + """Forward.""" + print(self._epoch, self._iteration) + if self.prune_config_manager.is_prune_time(self._epoch, + self._iteration): + + config = self.prune_config_manager.prune_at(self._epoch) + self.mutator.set_choices(config) + logger = MMLogger.get_current_instance() + logger.info(f'The model is pruned at {self._epoch}th epoch once.') + + return super().forward(inputs, data_samples, mode) + + def init_weights(self): + return self.architecture.init_weights() + + # private methods + + @property + def _epoch(self): + """Get current epoch number.""" + message_hub = MessageHub.get_current_instance() + if 'epoch' in message_hub.runtime_info: + return message_hub.runtime_info['epoch'] + else: + return 0 + + @property + def _iteration(self): + """Get current iteration number.""" + message_hub = MessageHub.get_current_instance() + if 'iter' in message_hub.runtime_info: + iter = message_hub.runtime_info['iter'] + max_iter = message_hub.runtime_info['max_iters'] + max_epoch = message_hub.runtime_info['max_epochs'] + return iter % (max_iter // max_epoch) + else: + return 0 diff --git a/mmrazor/models/algorithms/pruning/slimmable_network.py b/mmrazor/models/algorithms/pruning/slimmable_network.py index 67a08c920..429c2c856 100644 --- a/mmrazor/models/algorithms/pruning/slimmable_network.py +++ b/mmrazor/models/algorithms/pruning/slimmable_network.py @@ -1,20 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import os from pathlib import Path from typing import Dict, List, Optional, Union import torch -from mmengine import fileio from mmengine.model import BaseModel, MMDistributedDataParallel from mmengine.optim import OptimWrapper from mmengine.structures import BaseDataElement from torch import nn +from mmrazor.models.mutables import BaseMutable from mmrazor.models.mutators import SlimmableChannelMutator from mmrazor.models.utils import (add_prefix, reinitialize_optim_wrapper_count_status) from mmrazor.registry import MODEL_WRAPPERS, MODELS +from mmrazor.structures.subnet.fix_subnet import _dynamic_to_static from ..base import BaseAlgorithm VALID_MUTATOR_TYPE = Union[SlimmableChannelMutator, Dict] @@ -32,11 +32,11 @@ class SlimmableNetwork(BaseAlgorithm): Args: mutator (dict | :obj:`SlimmableChannelMutator`): The config of :class:`SlimmableChannelMutator` or built mutator. 
+ About the config of mutator, please refer to + SlimmableChannelMutator architecture (dict | :obj:`BaseModel`): The config of :class:`BaseModel` or built model. - channel_cfg_paths (str | :obj:`Path` | list): Config of list of configs - for channel of subnet(s) searched out. If there is only one - channel_cfg, the supernet will be fixed. + deploy_index (int): index of subnet to be deployed. data_preprocessor (dict | :obj:`torch.nn.Module` | None): The pre-process config of :class:`BaseDataPreprocessor`. Defaults to None. @@ -47,31 +47,21 @@ class SlimmableNetwork(BaseAlgorithm): def __init__(self, mutator: VALID_MUTATOR_TYPE, architecture: Union[BaseModel, Dict], - channel_cfg_paths: VALID_CHANNEL_CFG_PATH_TYPE, + deploy_index=-1, data_preprocessor: Optional[Union[Dict, nn.Module]] = None, init_cfg: Optional[Dict] = None) -> None: super().__init__(architecture, data_preprocessor, init_cfg) - if not isinstance(channel_cfg_paths, list): - channel_cfg_paths = [channel_cfg_paths] - self.num_subnet = len(channel_cfg_paths) - - channel_cfgs = self._load_and_merge_channel_cfgs(channel_cfg_paths) if isinstance(mutator, dict): - assert mutator.get('channel_cfgs') is None, \ - '`channel_cfgs` should not be in channel config' - mutator = copy.deepcopy(mutator) - mutator['channel_cfgs'] = channel_cfgs - - self.mutator: SlimmableChannelMutator = self._build_mutator(mutator) + self.mutator = MODELS.build(mutator) + else: + self.mutator = mutator self.mutator.prepare_from_supernet(self.architecture) + self.num_subnet = len(self.mutator.subnets) # must after `prepare_from_supernet` - if len(channel_cfg_paths) == 1: - # Avoid circular import - from mmrazor.structures import load_fix_subnet - load_fix_subnet(self.architecture, channel_cfg_paths[0]) - self.is_deployed = True + if deploy_index != -1: + self._deploy(deploy_index) else: self.is_deployed = False @@ -81,52 +71,12 @@ def __init__(self, # in our slimmable train step. 
self._optim_wrapper_count_status_reinitialized = False - def _load_and_merge_channel_cfgs( - self, channel_cfg_paths: List[VALID_PATH_TYPE]) -> Dict: - """Load and merge channel config.""" - channel_cfgs = list() - for channel_cfg_path in channel_cfg_paths: - channel_cfg = fileio.load(channel_cfg_path) - channel_cfgs.append(channel_cfg) - - return self.merge_channel_cfgs(channel_cfgs) - - @staticmethod - def merge_channel_cfgs(channel_cfgs: List[Dict]) -> Dict: - """Merge several channel configs.""" - merged_channel_cfg = dict() - num_subnet = len(channel_cfgs) - - for module_name in channel_cfgs[0].keys(): - channels_per_layer = [ - channel_cfgs[idx][module_name] for idx in range(num_subnet) - ] - merged_channels_per_layer = dict() - for key in channels_per_layer[0].keys(): - merged_channels = [ - channels_per_layer[idx][key] for idx in range(num_subnet) - ] - merged_channels_per_layer[key] = merged_channels - merged_channel_cfg[module_name] = merged_channels_per_layer - - return merged_channel_cfg - - def _build_mutator(self, - mutator: VALID_MUTATOR_TYPE) -> SlimmableChannelMutator: - """build mutator.""" - if isinstance(mutator, dict): - mutator = MODELS.build(mutator) - if not isinstance(mutator, SlimmableChannelMutator): - raise TypeError('mutator should be a `dict` or ' - '`SlimmableChannelMutator` instance, but got ' - f'{type(mutator)}') - - return mutator - def train_step(self, data: List[dict], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: """Train step.""" - batch_inputs, data_samples = self.data_preprocessor(data, True) + input_data = self.data_preprocessor(data, True) + batch_inputs = input_data['inputs'] + data_samples = input_data['data_samples'] train_kwargs = dict( batch_inputs=batch_inputs, data_samples=data_samples, @@ -151,8 +101,8 @@ def _slimmable_train_step( self._optim_wrapper_count_status_reinitialized = True total_losses = dict() - for subnet_idx in range(self.num_subnet): - self.mutator.switch_choices(subnet_idx) + for subnet_idx, subnet in enumerate(self.mutator.subnets): + self.mutator.set_choices(subnet) with optim_wrapper.optim_context(self): losses = self(batch_inputs, data_samples, mode='loss') parsed_losses, _ = self.parse_losses(losses) @@ -176,6 +126,19 @@ def _fixed_train_step( return losses + def _fix_archtecture(self): + for module in self.architecture.modules(): + if isinstance(module, BaseMutable): + if not module.is_fixed: + module.fix_chosen(None) + + def _deploy(self, index: int): + self.mutator.set_choices(self.mutator.subnets[index]) + self.mutator.fix_channel_mutables() + self._fix_archtecture() + _dynamic_to_static(self.architecture) + self.is_deployed = True + @MODEL_WRAPPERS.register_module() class SlimmableNetworkDDP(MMDistributedDataParallel): @@ -193,7 +156,9 @@ def __init__(self, def train_step(self, data: List[dict], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: """Train step.""" - batch_inputs, data_samples = self.module.data_preprocessor(data, True) + input_data = self.module.data_preprocessor(data, True) + batch_inputs = input_data['inputs'] + data_samples = input_data['data_samples'] train_kwargs = dict( batch_inputs=batch_inputs, data_samples=data_samples, @@ -218,8 +183,8 @@ def _slimmable_train_step( self._optim_wrapper_count_status_reinitialized = True total_losses = dict() - for subnet_idx in range(self.module.num_subnet): - self.module.mutator.switch_choices(subnet_idx) + for subnet_idx, subnet in enumerate(self.module.mutator.subnets): + self.module.mutator.set_choices(subnet) with 
optim_wrapper.optim_context(self): losses = self(batch_inputs, data_samples, mode='loss') parsed_losses, _ = self.module.parse_losses(losses) diff --git a/mmrazor/models/architectures/dynamic_ops/__init__.py b/mmrazor/models/architectures/dynamic_ops/__init__.py index 6b5796688..620c9e4c8 100644 --- a/mmrazor/models/architectures/dynamic_ops/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/__init__.py @@ -1,12 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import DynamicOP -from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d, - DynamicGroupNorm, DynamicInstanceNorm, - DynamicLinear) -from .slimmable_dynamic_ops import SwitchableBatchNorm2d +from .bricks.dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d +from .bricks.dynamic_linear import DynamicLinear +from .bricks.dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, + DynamicBatchNorm3d, SwitchableBatchNorm2d) +from .mixins.dynamic_conv_mixins import DynamicConvMixin +from .mixins.dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, + DynamicLinearMixin, DynamicMixin) __all__ = [ - 'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm', - 'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d', - 'DynamicOP' + 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', + 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', + 'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', + 'DynamicLinearMixin', 'SwitchableBatchNorm2d', 'DynamicConvMixin' ] diff --git a/mmrazor/models/architectures/dynamic_ops/base.py b/mmrazor/models/architectures/dynamic_ops/base.py deleted file mode 100644 index 2a1720ea2..000000000 --- a/mmrazor/models/architectures/dynamic_ops/base.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABC, abstractmethod -from typing import Any, Optional, Set - -from torch import nn - -from mmrazor.models.mutables.base_mutable import BaseMutable - - -class DynamicOP(ABC): - """Base class for dynamic OP. A dynamic OP usually consists of a normal - static OP and mutables, where mutables are used to control the searchable - (mutable) part of the dynamic OP. - - Note: - When the dynamic OP has just been initialized, its forward propagation - logic should be the same as the corresponding static OP. Only after - the searchable part accepts the specific mutable through the - corresponding interface does the part really become dynamic. - - Note: - All subclass should implement ``to_static_op`` API. - - Args: - accepted_mutables (set): The string set of all accepted mutables. - """ - accepted_mutables: Set[str] = set() - - @abstractmethod - def to_static_op(self) -> nn.Module: - """Convert dynamic OP to static OP. - - Note: - The forward result for the same input between dynamic OP and its - corresponding static OP must be same. - - Returns: - nn.Module: Corresponding static OP. - """ - - def check_if_mutables_fixed(self) -> None: - """Check if all mutables are fixed. - - Raises: - RuntimeError: Error if a existing mutable is not fixed. - """ - - def check_fixed(mutable: Optional[BaseMutable]) -> None: - if mutable is not None and not mutable.is_fixed: - raise RuntimeError(f'Mutable {type(mutable)} is not fixed.') - - for mutable in self.accepted_mutables: - check_fixed(getattr(self, f'{mutable}')) - - @staticmethod - def get_current_choice(mutable: BaseMutable) -> Any: - """Get current choice of given mutable. - - Args: - mutable (BaseMutable): Given mutable. 
- - Raises: - RuntimeError: Error if `current_choice` is None. - - Returns: - Any: Current choice of given mutable. - """ - current_choice = mutable.current_choice - if current_choice is None: - raise RuntimeError(f'current choice of mutable {type(mutable)} ' - 'can not be None at runtime') - - return current_choice - - -class ChannelDynamicOP(DynamicOP): - """Base class for dynamic OP with mutable channels. - - Note: - All subclass should implement ``mutable_in`` and ``mutable_out`` APIs. - """ - - @property - @abstractmethod - def mutable_in(self) -> Optional[BaseMutable]: - """Mutable related to input.""" - - @property - @abstractmethod - def mutable_out(self) -> Optional[BaseMutable]: - """Mutable related to output.""" - - @staticmethod - def check_mutable_channels(mutable_channels: BaseMutable) -> None: - """Check if mutable has `currnet_mask` attribute. - - Args: - mutable_channels (BaseMutable): Mutable to be checked. - - Raises: - ValueError: Error if mutable does not have `current_mask` - attribute. - """ - if not hasattr(mutable_channels, 'current_mask'): - raise ValueError( - 'channel mutable must have attribute `current_mask`') diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py index 7eaad52fa..ef101fec6 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py @@ -1,14 +1 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d -from .dynamic_linear import DynamicLinear -from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, - DynamicLinearMixin, DynamicMixin) -from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, - DynamicBatchNorm3d) - -__all__ = [ - 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', - 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', - 'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', - 'DynamicLinearMixin' -] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py index 8b031e4b1..71fc7ab98 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py @@ -7,8 +7,10 @@ from mmrazor.models.mutables.base_mutable import BaseMutable from mmrazor.registry import MODELS -from .dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, - OFAConvMixin) +from ..mixins.dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, + OFAConvMixin) + +GroupWiseConvWarned = False @MODELS.register_module() @@ -37,16 +39,31 @@ def __init__(self, *args, **kwargs) -> None: def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d': """Convert an instance of nn.Conv2d to a new instance of DynamicConv2d.""" - return cls( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - bias=True if module.bias is not None else False, - padding_mode=module.padding_mode) + # a group-wise conv will not be converted to dynamic conv + if module.groups > 1 and not (module.groups == module.out_channels == + module.in_channels): + global GroupWiseConvWarned + if GroupWiseConvWarned is False: + from mmengine import MMLogger + logger = MMLogger.get_current_instance() + 
logger.warning( + ('Group-wise convolutional layers are not supported to be ' + 'pruned now, so they are not converted to new ' + 'DynamicConvs.')) + GroupWiseConvWarned = True + + return module + else: + return cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) @property def conv_func(self) -> Callable: @@ -146,6 +163,7 @@ def __init__(self, *args, **kwargs) -> None: def convert_from(cls, module: nn.Conv2d) -> 'OFAConv2d': """Convert an instance of `nn.Conv2d` to a new instance of `OFAConv2d`.""" + return cls( in_channels=module.in_channels, out_channels=module.out_channels, diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py index aa7bcbccc..4faa0c8b7 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py @@ -6,7 +6,7 @@ from torch import Tensor from mmrazor.models.mutables.base_mutable import BaseMutable -from .dynamic_mixins import DynamicLinearMixin +from ..mixins import DynamicLinearMixin class DynamicLinear(nn.Linear, DynamicLinearMixin): diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py index 1455340ef..e3e795fa4 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Optional +from typing import Dict, List, Optional import torch.nn as nn import torch.nn.functional as F @@ -8,7 +8,7 @@ from mmrazor.models.mutables.base_mutable import BaseMutable from mmrazor.registry import MODELS -from .dynamic_mixins import DynamicBatchNormMixin +from ..mixins.dynamic_mixins import DynamicBatchNormMixin class _DynamicBatchNorm(_BatchNorm, DynamicBatchNormMixin): @@ -128,3 +128,65 @@ def _check_input_dim(self, input: Tensor) -> None: if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)'.format( input.dim())) + + +class SwitchableBatchNorm2d(DynamicBatchNorm2d): + """A switchable DynamicBatchNorm2d. It employs independent batch + normalization for different switches in a slimmable network. + + To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all batch + normalization layers for each switch in a slimmable network. Compared with + the naive training approach, it solves the problem of feature aggregation + inconsistency between different switches by independently normalizing the + feature mean and variance during testing.
+ """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.candidate_bn = nn.ModuleDict() + + def init_candidates(self, candidates: List): + """Initialize candicates.""" + assert len(self.candidate_bn) == 0 + self._check_candidates(candidates) + for num in candidates: + self.candidate_bn[str(num)] = nn.BatchNorm2d( + num, self.eps, self.momentum, self.affine, + self.track_running_stats) + + def forward(self, input: Tensor) -> Tensor: + """Forward.""" + choice_num = self.activated_channel_num() + if choice_num == self.num_features: + return super().forward(input) + else: + assert str(choice_num) in self.candidate_bn + return self.candidate_bn[str(choice_num)](input) + + def to_static_op(self: _BatchNorm) -> nn.Module: + """Convert to a normal BatchNorm.""" + choice_num = self.activated_channel_num() + if choice_num == self.num_features: + return super().to_static_op() + else: + assert str(choice_num) in self.candidate_bn + return self.candidate_bn[str(choice_num)] + + # private methods + + def activated_channel_num(self): + """The number of activated channels.""" + mask = self._get_num_features_mask() + choice_num = (mask == 1).sum().item() + return choice_num + + def _check_candidates(self, candidates: List): + """Check if candidates aviliable.""" + for value in candidates: + assert isinstance(value, int) + assert 0 < value <= self.num_features + + @property + def static_op_factory(self): + """Return initializer of static op.""" + return nn.BatchNorm2d diff --git a/mmrazor/models/architectures/dynamic_ops/default_dynamic_ops.py b/mmrazor/models/architectures/dynamic_ops/default_dynamic_ops.py deleted file mode 100644 index 2488a49eb..000000000 --- a/mmrazor/models/architectures/dynamic_ops/default_dynamic_ops.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from typing import Optional, Tuple - -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.nn.modules import GroupNorm -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.instancenorm import _InstanceNorm - -from mmrazor.models.mutables.mutable_channel import MutableChannel -from mmrazor.registry import MODELS -from .base import ChannelDynamicOP - - -class DynamicConv2d(nn.Conv2d, ChannelDynamicOP): - """Applies a 2D convolution over an input signal composed of several input - planes according to the `mutable_in_channels` and `mutable_out_channels` - dynamically. - - Args: - in_channels_cfg (Dict): Config related to `in_channels`. - out_channels_cfg (Dict): Config related to `out_channels`. 
- """ - accepted_mutables = {'mutable_in_channels', 'mutable_out_channels'} - - def __init__(self, in_channels_cfg, out_channels_cfg, *args, **kwargs): - super(DynamicConv2d, self).__init__(*args, **kwargs) - - in_channels_cfg_ = copy.deepcopy(in_channels_cfg) - in_channels_cfg_.update(dict(num_channels=self.in_channels)) - self.mutable_in_channels = MODELS.build(in_channels_cfg_) - - out_channels_cfg_ = copy.deepcopy(out_channels_cfg) - out_channels_cfg_.update(dict(num_channels=self.out_channels)) - self.mutable_out_channels = MODELS.build(out_channels_cfg_) - - assert isinstance(self.mutable_in_channels, MutableChannel) - assert isinstance(self.mutable_out_channels, MutableChannel) - # TODO - # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d - assert self.padding_mode == 'zeros' - - @property - def mutable_in(self) -> MutableChannel: - """Mutable `in_channels`.""" - return self.mutable_in_channels - - @property - def mutable_out(self) -> MutableChannel: - """Mutable `out_channels`.""" - return self.mutable_out_channels - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_in_channels` and - `mutable_out_channels`, and forward.""" - groups = self.groups - if self.groups == self.in_channels == self.out_channels: - groups = input.size(1) - weight, bias = self._get_dynamic_params() - - return F.conv2d(input, weight, bias, self.stride, self.padding, - self.dilation, groups) - - def _get_dynamic_params(self) -> Tuple[Tensor, Optional[Tensor]]: - in_mask = self.mutable_in_channels.current_mask.to(self.weight.device) - out_mask = self.mutable_out_channels.current_mask.to( - self.weight.device) - - if self.groups == 1: - weight = self.weight[out_mask][:, in_mask] - elif self.groups == self.in_channels == self.out_channels: - # depth-wise conv - weight = self.weight[out_mask] - else: - raise NotImplementedError( - 'Current `ChannelMutator` only support pruning the depth-wise ' - '`nn.Conv2d` or `nn.Conv2d` module whose group number equals ' - f'to one, but got {self.groups}.') - - bias = self.bias[out_mask] if self.bias is not None else None - - return weight, bias - - def to_static_op(self) -> nn.Conv2d: - assert self.mutable_in.is_fixed and self.mutable_out.is_fixed - - weight, bias, = self._get_dynamic_params() - groups = self.groups - if groups == self.in_channels == self.out_channels: - groups = self.mutable_in.current_mask.sum().item() - out_channels = weight.size(0) - in_channels = weight.size(1) * groups - - static_conv2d = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - padding_mode=self.padding_mode, - dilation=self.dilation, - groups=groups, - bias=True if bias is not None else False) - - static_conv2d.weight = nn.Parameter(weight) - if bias is not None: - static_conv2d.bias = nn.Parameter(bias) - - return static_conv2d - - -class DynamicLinear(nn.Linear, ChannelDynamicOP): - """Applies a linear transformation to the incoming data according to the - `mutable_in_features` and `mutable_out_features` dynamically. - - Args: - in_features_cfg (Dict): Config related to `in_features`. - out_features_cfg (Dict): Config related to `out_features`. 
- """ - accepted_mutables = {'mutable_in_features', 'mutable_out_features'} - - def __init__(self, in_features_cfg, out_features_cfg, *args, **kwargs): - super(DynamicLinear, self).__init__(*args, **kwargs) - - in_features_cfg_ = copy.deepcopy(in_features_cfg) - in_features_cfg_.update(dict(num_channels=self.in_features)) - self.mutable_in_features = MODELS.build(in_features_cfg_) - - out_features_cfg_ = copy.deepcopy(out_features_cfg) - out_features_cfg_.update(dict(num_channels=self.out_features)) - self.mutable_out_features = MODELS.build(out_features_cfg_) - - @property - def mutable_in(self): - """Mutable `in_features`.""" - return self.mutable_in_features - - @property - def mutable_out(self): - """Mutable `out_features`.""" - return self.mutable_out_features - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_in_features` and - `mutable_out_features`, and forward.""" - in_mask = self.mutable_in_features.current_mask.to(self.weight.device) - out_mask = self.mutable_out_features.current_mask.to( - self.weight.device) - - weight = self.weight[out_mask][:, in_mask] - bias = self.bias[out_mask] if self.bias is not None else None - - return F.linear(input, weight, bias) - - # TODO - def to_static_op(self) -> nn.Module: - return self - - -class DynamicBatchNorm(_BatchNorm, ChannelDynamicOP): - """Applies Batch Normalization over an input according to the - `mutable_num_features` dynamically. - - Args: - num_features_cfg (Dict): Config related to `num_features`. - """ - accepted_mutables = {'mutable_num_features'} - - def __init__(self, num_features_cfg, *args, **kwargs): - super(DynamicBatchNorm, self).__init__(*args, **kwargs) - - num_features_cfg_ = copy.deepcopy(num_features_cfg) - num_features_cfg_.update(dict(num_channels=self.num_features)) - self.mutable_num_features = MODELS.build(num_features_cfg_) - - @property - def mutable_in(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - @property - def mutable_out(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_num_features`, and - forward.""" - if self.momentum is None: - exponential_average_factor = 0.0 - else: - exponential_average_factor = self.momentum - - if self.training and self.track_running_stats: - if self.num_batches_tracked is not None: # type: ignore - self.num_batches_tracked = \ - self.num_batches_tracked + 1 # type: ignore - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float( - self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum - - if self.training: - bn_training = True - else: - bn_training = (self.running_mean is None) and (self.running_var is - None) - - if self.affine: - out_mask = self.mutable_num_features.current_mask.to( - self.weight.device) - weight = self.weight[out_mask] - bias = self.bias[out_mask] - else: - weight, bias = self.weight, self.bias - - if self.track_running_stats: - out_mask = self.mutable_num_features.current_mask.to( - self.running_mean.device) - running_mean = self.running_mean[out_mask] \ - if not self.training or self.track_running_stats else None - running_var = self.running_var[out_mask] \ - if not self.training or self.track_running_stats else None - else: - running_mean, running_var = self.running_mean, self.running_var - - return F.batch_norm(input, running_mean, running_var, weight, 
bias, - bn_training, exponential_average_factor, self.eps) - - # TODO - def to_static_op(self) -> nn.Module: - return self - - -class DynamicInstanceNorm(_InstanceNorm, ChannelDynamicOP): - """Applies Instance Normalization over an input according to the - `mutable_num_features` dynamically. - - Args: - num_features_cfg (Dict): Config related to `num_features`. - """ - accepted_mutables = {'mutable_num_features'} - - def __init__(self, num_features_cfg, *args, **kwargs): - super(DynamicInstanceNorm, self).__init__(*args, **kwargs) - - num_features_cfg_ = copy.deepcopy(num_features_cfg) - num_features_cfg_.update(dict(num_channels=self.num_features)) - self.mutable_num_features = MODELS.build(num_features_cfg_) - - @property - def mutable_in(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - @property - def mutable_out(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_num_features`, and - forward.""" - if self.affine: - out_mask = self.mutable_num_features.current_mask.to( - self.weight.device) - weight = self.weight[out_mask] - bias = self.bias[out_mask] - else: - weight, bias = self.weight, self.bias - - if self.track_running_stats: - out_mask = self.mutable_num_features.current_mask.to( - self.running_mean.device) - running_mean = self.running_mean[out_mask] - running_var = self.running_var[out_mask] - else: - running_mean, running_var = self.running_mean, self.running_var - - return F.instance_norm(input, running_mean, running_var, weight, bias, - self.training or not self.track_running_stats, - self.momentum, self.eps) - - # TODO - def to_static_op(self) -> nn.Module: - return self - - -class DynamicGroupNorm(GroupNorm, ChannelDynamicOP): - """Applies Group Normalization over a mini-batch of inputs according to the - `mutable_num_channels` dynamically. - - Args: - num_channels_cfg (Dict): Config related to `num_channels`. - """ - accepted_mutables = {'mutable_num_features'} - - def __init__(self, num_channels_cfg, *args, **kwargs): - super(DynamicGroupNorm, self).__init__(*args, **kwargs) - - num_channels_cfg_ = copy.deepcopy(num_channels_cfg) - num_channels_cfg_.update(dict(num_channels=self.num_channels)) - self.mutable_num_channels = MODELS.build(num_channels_cfg_) - - @property - def mutable_in(self): - """Mutable `num_channels`.""" - return self.mutable_num_channels - - @property - def mutable_out(self): - """Mutable `num_channels`.""" - return self.mutable_num_channels - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_num_channels`, and - forward.""" - if self.affine: - out_mask = self.mutable_num_channels.current_mask.to( - self.weight.device) - weight = self.weight[out_mask] - bias = self.bias[out_mask] - else: - weight, bias = self.weight, self.bias - - return F.group_norm(input, self.num_groups, weight, bias, self.eps) - - # TODO - def to_static_op(self) -> nn.Module: - return self diff --git a/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py b/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py new file mode 100644 index 000000000..7a5097bc5 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .dynamic_conv_mixins import DynamicConvMixin +from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, + DynamicLinearMixin, DynamicMixin) + +__all__ = [ + 'DynamicChannelMixin', 'DynamicBatchNormMixin', 'DynamicLinearMixin', + 'DynamicMixin', 'DynamicConvMixin' +] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py similarity index 100% rename from mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv_mixins.py rename to mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py similarity index 100% rename from mmrazor/models/architectures/dynamic_ops/bricks/dynamic_mixins.py rename to mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py diff --git a/mmrazor/models/architectures/dynamic_ops/slimmable_dynamic_ops.py b/mmrazor/models/architectures/dynamic_ops/slimmable_dynamic_ops.py deleted file mode 100644 index a85e39af3..000000000 --- a/mmrazor/models/architectures/dynamic_ops/slimmable_dynamic_ops.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from typing import Dict - -import torch.nn as nn - -from mmrazor.models.mutables.mutable_channel import MutableChannel -from mmrazor.registry import MODELS -from .base import DynamicOP - - -class SwitchableBatchNorm2d(nn.Module, DynamicOP): - """Employs independent batch normalization for different switches in a - slimmable network. - - To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all - batch normalization layers for each switch in a slimmable network. - Compared with the naive training approach, it solves the problem of feature - aggregation inconsistency between different switches by independently - normalizing the feature mean and variance during testing. - - Args: - module_name (str): Name of this `SwitchableBatchNorm2d`. - num_features_cfg (Dict): Config related to `num_features`. - eps (float): A value added to the denominator for numerical stability. - Same as that in :obj:`torch.nn._BatchNorm`. Default: 1e-5 - momentum (float): The value used for the running_mean and running_var - computation. Can be set to None for cumulative moving average - (i.e. simple average). Same as that in :obj:`torch.nn._BatchNorm`. - Default: 0.1 - affine (bool): A boolean value that when set to True, this module has - learnable affine parameters. Same as that in - :obj:`torch.nn._BatchNorm`. Default: True - track_running_stats (bool): A boolean value that when set to True, this - module tracks the running mean and variance, and when set to False, - this module does not track such statistics, and initializes - statistics buffers running_mean and running_var as None. When these - buffers are None, this module always uses batch statistics. in both - training and eval modes. Same as that in - :obj:`torch.nn._BatchNorm`. 
Default: True - """ - - def __init__(self, - num_features_cfg: Dict, - eps: float = 1e-5, - momentum: float = 0.1, - affine: bool = True, - track_running_stats: bool = True): - super(SwitchableBatchNorm2d, self).__init__() - - num_features_cfg = copy.deepcopy(num_features_cfg) - candidate_choices = num_features_cfg.pop('candidate_choices') - num_features_cfg.update(dict(num_channels=max(candidate_choices))) - - bns = [ - nn.BatchNorm2d(num_features, eps, momentum, affine, - track_running_stats) - for num_features in candidate_choices - ] - self.bns = nn.ModuleList(bns) - - self.mutable_num_features = MODELS.build(num_features_cfg) - - @property - def mutable_in(self) -> MutableChannel: - """Mutable `num_features`.""" - return self.mutable_num_features - - @property - def mutable_out(self) -> MutableChannel: - """Mutable `num_features`.""" - return self.mutable_num_features - - def forward(self, input): - """Forward computation according to the current switch of the slimmable - networks.""" - idx = self.mutable_num_features.current_choice - return self.bns[idx](input) - - def to_static_op(self) -> nn.Module: - bn_idx = self.mutable_num_features.current_choice - - return self.bns[bn_idx] diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 074eda445..abfceab7f 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -1,7 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .base_mutable import BaseMutable from .derived_mutable import DerivedMutable -from .mutable_channel import (MutableChannel, OneShotMutableChannel, - SlimmableMutableChannel) +from .mutable_channel import (BaseMutableChannel, MutableChannelContainer, + OneShotMutableChannel, SimpleMutableChannel, + SquentialMutableChannel) +from .mutable_channel.units import (ChannelUnitType, L1MutableChannelUnit, + MutableChannelUnit, + OneShotMutableChannelUnit, + SequentialMutableChannelUnit, + SlimmableChannelUnit) from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, OneHotMutableOP, OneShotMutableModule, OneShotMutableOP) @@ -9,7 +16,11 @@ __all__ = [ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', - 'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel', - 'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable', - 'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP' + 'DiffChoiceRoute', 'DiffMutableModule', 'DerivedMutable', 'MutableValue', + 'OneShotMutableValue', 'SequentialMutableChannelUnit', + 'L1MutableChannelUnit', 'OneShotMutableChannelUnit', + 'SimpleMutableChannel', 'MutableChannelUnit', 'SlimmableChannelUnit', + 'BaseMutableChannel', 'MutableChannelContainer', 'ChannelUnitType', + 'SquentialMutableChannel', 'OneHotMutableOP', 'OneShotMutableChannel', + 'BaseMutable' ] diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 5e991e9fe..98f680ee9 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -41,25 +41,27 @@ def current_mask(self) -> Tensor: """Current mask.""" -def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: +def _expand_choice_fn(mutable: MutableProtocol, + expand_ratio: Union[int, float]) -> Callable: """Helper function to build `choice_fn` for expand derived mutable.""" def fn(): - return mutable.current_choice * expand_ratio + return int(mutable.current_choice * expand_ratio) return fn -def _expand_mask_fn(mutable: MutableProtocol, - 
expand_ratio: int) -> Callable: # pragma: no cover +def _expand_mask_fn( + mutable: MutableProtocol, + expand_ratio: Union[int, float]) -> Callable: # pragma: no cover """Helper function to build `mask_fn` for expand derived mutable.""" if not hasattr(mutable, 'current_mask'): raise ValueError('mutable must have attribute `currnet_mask`') def fn(): mask = mutable.current_mask - expand_num_channels = mask.size(0) * expand_ratio - expand_choice = mutable.current_choice * expand_ratio + expand_num_channels = int(mask.size(0) * expand_ratio) + expand_choice = int(mutable.current_choice * expand_ratio) expand_mask = torch.zeros(expand_num_channels).bool() expand_mask[:expand_choice] = True @@ -131,8 +133,9 @@ def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable': """Derive same mutable as the source.""" return self.derive_expand_mutable(expand_ratio=1) - def derive_expand_mutable(self: MutableProtocol, - expand_ratio: int) -> 'DerivedMutable': + def derive_expand_mutable( + self: MutableProtocol, + expand_ratio: Union[int, float]) -> 'DerivedMutable': """Derive expand mutable, usually used with `expand_ratio`.""" choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) @@ -198,21 +201,18 @@ class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE], and `Pretrained`. Defaults to None. Examples: - >>> from mmrazor.models.mutables import OneShotMutableChannel - >>> mutable_channel = OneShotMutableChannel( - ... num_channels=3, - ... candidate_choices=[1, 2, 3], - ... candidate_mode='number') + >>> from mmrazor.models.mutables import SquentialMutableChannel + >>> mutable_channel = SquentialMutableChannel(num_channels=3) >>> # derive expand mutable >>> derived_mutable_channel = mutable_channel * 2 >>> # source mutables will be traced automatically >>> derived_mutable_channel.source_mutables - {OneShotMutableChannel(name=unbind, num_channels=3, current_choice=3, choices=[1, 2, 3], activated_channels=3, concat_mutable_name=[])} # noqa: E501 + {SquentialMutableChannel(name=unbind, num_channels=3, current_choice=3)} # noqa: E501 >>> # modify `current_choice` of `mutable_channel` >>> mutable_channel.current_choice = 2 >>> # `current_choice` and `current_mask` of derived mutable will be modified automatically # noqa: E501 >>> derived_mutable_channel - DerivedMutable(current_choice=4, activated_channels=4, source_mutables={OneShotMutableChannel(name=unbind, num_channels=3, current_choice=2, choices=[1, 2, 3], activated_channels=2, concat_mutable_name=[])}, is_fixed=False) # noqa: E501 + DerivedMutable(current_choice=4, activated_channels=4, source_mutables={SquentialMutableChannel(name=unbind, num_channels=3, current_choice=2)}, is_fixed=False) # noqa: E501 """ def __init__(self, diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index b3bbd3ab3..0ef09dc78 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -1,8 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .mutable_channel import MutableChannel -from .one_shot_mutable_channel import OneShotMutableChannel -from .slimmable_mutable_channel import SlimmableMutableChannel +from .base_mutable_channel import BaseMutableChannel +from .mutable_channel_container import MutableChannelContainer +from .oneshot_mutalbe_channel import OneShotMutableChannel +from .sequential_mutable_channel import SquentialMutableChannel +from .simple_mutable_channel import SimpleMutableChannel +from .units import (ChannelUnitType, L1MutableChannelUnit, MutableChannelUnit, + OneShotMutableChannelUnit, SequentialMutableChannelUnit, + SlimmableChannelUnit) __all__ = [ - 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel' + 'SimpleMutableChannel', 'L1MutableChannelUnit', + 'SequentialMutableChannelUnit', 'MutableChannelUnit', + 'OneShotMutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel', + 'MutableChannelContainer', 'SquentialMutableChannel', 'ChannelUnitType', + 'OneShotMutableChannel' ] diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py new file mode 100644 index 000000000..28f1e4854 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""""" +from abc import abstractmethod + +import torch + +from ..base_mutable import BaseMutable +from ..derived_mutable import DerivedMethodMixin + + +class BaseMutableChannel(BaseMutable, DerivedMethodMixin): + """BaseMutableChannel works as a channel mask for DynamicOps to select + channels. + + |---------------------------------------| + |mutable_in_channel(BaseMutableChannel) | + |---------------------------------------| + | DynamicOp | + |---------------------------------------| + |mutable_out_channel(BaseMutableChannel)| + |---------------------------------------| + + All subclasses should implement the following APIs: + + - ``current_choice`` + - ``current_mask`` + + Args: + num_channels (int): number(dimension) of channels(mask). + """ + + def __init__(self, num_channels: int, **kwargs): + super().__init__(**kwargs) + self.name = '' + self.num_channels = num_channels + + # choice + + @property # type: ignore + @abstractmethod + def current_choice(self): + """get current choice.""" + raise NotImplementedError() + + @current_choice.setter # type: ignore + @abstractmethod + def current_choice(self): + """set current choice.""" + raise NotImplementedError() + + @property # type: ignore + @abstractmethod + def current_mask(self) -> torch.Tensor: + """Return a mask indicating the channel selection.""" + raise NotImplementedError() + + @property + def activated_channels(self) -> int: + """Number of activated channels.""" + return (self.current_mask == 1).sum().item() + + # implementation of abstract methods + + def fix_chosen(self, chosen=None): + """Fix the mutable with chosen.""" + if chosen is not None: + self.current_choice = chosen + + if self.is_fixed: + raise AttributeError( + 'The mode of current MUTABLE is `fixed`. 
' + 'Please do not call `fix_chosen` function again.') + + self.is_fixed = True + + def dump_chosen(self): + """dump current choice to a dict.""" + raise NotImplementedError() + + def num_choices(self) -> int: + """Number of available choices.""" + raise NotImplementedError() + + # others + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += '(' + repr_str += f'num_channels={self.num_channels}, ' + repr_str += f'activated_channels={self.activated_channels}' + repr_str += ')' + return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel.py b/mmrazor/models/mutables/mutable_channel/mutable_channel.py deleted file mode 100644 index af2bf2188..000000000 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import abstractmethod -from typing import List - -import torch - -from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE, BaseMutable -from ..derived_mutable import DerivedMethodMixin - - -class MutableChannel(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE], - DerivedMethodMixin): - """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In - single path supernet, each module only has one choice invoked at the same - time. A path is obtained by sampling all the available choices. It is the - base class for one shot channel mutables. - - Args: - num_channels (int): The raw number of channels. - init_cfg (dict, optional): initialization configuration dict for - ``BaseModule``. OpenMMLab has implement 5 initializer including - `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, - and `Pretrained`. - """ - - def __init__(self, num_channels: int, **kwargs): - super().__init__(**kwargs) - - self.num_channels = num_channels - self._same_mutables: List[MutableChannel] = list() - - # If the input of a module is a concatenation of several modules' - # outputs, we add the mutable out of these modules to the - # `concat_parent_mutables` of this module. - self.concat_parent_mutables: List[MutableChannel] = list() - self.name = 'unbind' - - @property - def same_mutables(self): - """Mutables in `same_mutables` and the current mutable should change - Synchronously.""" - return self._same_mutables - - def register_same_mutable(self, mutable): - """Register the input mutable in `same_mutables`.""" - if isinstance(mutable, list): - # Add a concatenation of mutables to `concat_parent_mutables`. - self.concat_parent_mutables = mutable - return - - if self == mutable: - return - if mutable in self._same_mutables: - return - - self._same_mutables.append(mutable) - for s_mutable in self._same_mutables: - s_mutable.register_same_mutable(mutable) - mutable.register_same_mutable(s_mutable) - - @abstractmethod - def convert_choice_to_mask(self, choice: CHOICE_TYPE) -> torch.Tensor: - """Get the mask according to the input choice.""" - pass - - @property - def current_mask(self): - """The current mask. - - We slice the registered parameters and buffers of a ``nn.Module`` - according to the mask of the corresponding channel mutable. - """ - if len(self.concat_parent_mutables) > 0: - # If the input of a module is a concatenation of several modules' - # outputs, the in_mask of this module is the concatenation of - # these modules' out_mask. 
- return torch.cat([ - mutable.current_mask for mutable in self.concat_parent_mutables - ]) - else: - return self.convert_choice_to_mask(self.current_choice) - - def bind_mutable_name(self, name: str) -> None: - """Bind a MutableChannel to its name. - - Args: - name (str): Name of this `MutableChannel`. - """ - self.name = name - - def fix_chosen(self, chosen: CHOSEN_TYPE) -> None: - """Fix mutable with subnet config. This operation would convert - `unfixed` mode to `fixed` mode. The :attr:`is_fixed` will be set to - True and only the selected operations can be retained. - - Args: - chosen (str): The chosen key in ``MUTABLE``. Defaults to None. - """ - if self.is_fixed: - raise AttributeError( - 'The mode of current MUTABLE is `fixed`. ' - 'Please do not call `fix_chosen` function again.') - - self.is_fixed = True - - def __repr__(self): - concat_mutable_name = [ - mutable.name for mutable in self.concat_parent_mutables - ] - repr_str = self.__class__.__name__ - repr_str += f'(name={self.name}, ' - repr_str += f'num_channels={self.num_channels}, ' - repr_str += f'concat_mutable_name={concat_mutable_name})' - return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py new file mode 100644 index 000000000..9292d64c8 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch + +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin +from mmrazor.registry import MODELS +from mmrazor.utils import IndexDict +from .base_mutable_channel import BaseMutableChannel +from .simple_mutable_channel import SimpleMutableChannel + + +@MODELS.register_module() +class MutableChannelContainer(BaseMutableChannel): + """MutableChannelContainer inherits from BaseMutableChannel. However, + it's not a single BaseMutableChannel, but a container for + BaseMutableChannel. The mask of MutableChannelContainer consists of + all masks of stored MutableChannels. + + ----------------------------------------------------------- + | MutableChannelContainer | + ----------------------------------------------------------- + |MutableChannel1| MutableChannel2 |MutableChannel3| + ----------------------------------------------------------- + + Important interfaces: + register_mutable: register/store BaseMutableChannel in the + MutableChannelContainer + """ + + def __init__(self, num_channels: int, **kwargs): + super().__init__(num_channels, **kwargs) + self.mutable_channels = IndexDict() + + # choice + + @property + def current_choice(self) -> torch.Tensor: + """Get current choices.""" + if len(self.mutable_channels) == 0: + return torch.ones([self.num_channels]).bool() + else: + self._fill_unregistered_range() + self._assert_mutables_valid() + mutable_channels = list(self.mutable_channels.values()) + masks = [mutable.current_mask for mutable in mutable_channels] + mask = torch.cat(masks) + return mask.bool() + + @current_choice.setter + def current_choice(self, choice): + """Set current choices. + + However, MutableChannelContainer doesn't support directly set mask. You + can change the mask of MutableChannelContainer by changing its stored + BaseMutableChannel. 
+ """ + raise NotImplementedError() + + @property + def current_mask(self) -> torch.Tensor: + """Return current mask.""" + return self.current_choice.bool() + + # basic extension + + def register_mutable(self, mutable_channel: BaseMutableChannel, start: int, + end: int): + """Register/Store BaseMutableChannel in the MutableChannelContainer in + the range [start,end)""" + + self.mutable_channels[(start, end)] = mutable_channel + + @classmethod + def register_mutable_channel_to_module(cls, + module: DynamicChannelMixin, + mutable: BaseMutableChannel, + is_to_output_channel=True, + start=0, + end=-1): + """Register a BaseMutableChannel to a module with + MutableChannelContainers.""" + if end == -1: + end = mutable.num_channels + start + if is_to_output_channel: + container: MutableChannelContainer = module.get_mutable_attr( + 'out_channels') + else: + container = module.get_mutable_attr('in_channels') + assert isinstance(container, MutableChannelContainer) + container.register_mutable(mutable, start, end) + + # private methods + + def _assert_mutables_valid(self): + """Assert the current stored BaseMutableChannels are valid to generate + mask.""" + assert len(self.mutable_channels) > 0 + last_end = 0 + for start, end in self.mutable_channels: + assert start == last_end + last_end = end + assert last_end == self.num_channels + + def _fill_unregistered_range(self): + """Fill with SimpleMutableChannels in the range without any stored + BaseMutableChannel. + + For example, if a MutableChannelContainer has 10 channels, and only the + [0,5) is registered with BaseMutableChannels, this method will + automatically register BaseMutableChannels in the range [5,10). + """ + last_end = 0 + for start, end in copy.copy(self.mutable_channels): + if last_end < start: + self.register_mutable( + SimpleMutableChannel(last_end - start), last_end, start) + last_end = end + if last_end < self.num_channels: + self.register_mutable( + SimpleMutableChannel(self.num_channels - last_end), last_end, + self.num_channels) diff --git a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py deleted file mode 100644 index 7f6eea3ad..000000000 --- a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict, List, Optional, Union - -import numpy as np -import torch - -from mmrazor.registry import MODELS -from ..derived_mutable import DerivedMutable -from .mutable_channel import MutableChannel - - -@MODELS.register_module() -class OneShotMutableChannel(MutableChannel[int, Dict]): - """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In - single path supernet, each module only has one choice invoked at the same - time. A path is obtained by sampling all the available choices. It is the - base class for one shot mutable channel. - - Args: - num_channels (int): The raw number of channels. - candidate_choices (List): If `candidate_mode` is "ratio", - candidate_choices is a list of candidate width ratios. If - `candidate_mode` is "number", candidate_choices is a list of - candidate channel number. We note that the width ratio is the ratio - between the number of reserved channels and that of all channels in - a layer. - For example, if `ratios` is [0.25, 0.5], there are 2 cases - for us to choose from when we sample from a layer with 12 channels. 
- One is sampling the very first 3 channels in this layer, another is - sampling the very first 6 channels in this layer. - candidate_mode (str): One of "ratio" or "number". - init_cfg (dict, optional): initialization configuration dict for - ``BaseModule``. OpenMMLab has implement 5 initializer including - `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, - and `Pretrained`. - """ - - def __init__(self, - num_channels: int, - candidate_choices: List[Union[int, float]], - candidate_mode: str = 'ratio', - init_cfg: Optional[Dict] = None): - super(OneShotMutableChannel, self).__init__( - num_channels=num_channels, init_cfg=init_cfg) - - self._current_choice = num_channels - assert len(candidate_choices) > 0, \ - f'Number of candidate choices must be greater than 0, ' \ - f'but got: {len(candidate_choices)}' - self._candidate_choices = candidate_choices - assert candidate_mode in ['ratio', 'number'] - self._candidate_mode = candidate_mode - - self._check_candidate_choices() - - def _check_candidate_choices(self): - """Check if the input `candidate_choices` is valid.""" - if self._candidate_mode == 'number': - assert all([num > 0 and num <= self.num_channels - for num in self._candidate_choices]), \ - f'The candidate channel numbers should be in ' \ - f'range(0, {self.num_channels}].' - assert all([isinstance(num, int) - for num in self._candidate_choices]), \ - 'Type of `candidate_choices` should be int.' - else: - assert all([ - ratio > 0 and ratio <= 1 for ratio in self._candidate_choices - ]), 'The candidate ratio should be in range(0, 1].' - - def sample_choice(self) -> int: - """Sample an arbitrary selection from candidate choices. - - Returns: - int: The chosen number of channels. - """ - assert len(self.concat_parent_mutables) == 0 - num_channels = np.random.choice(self.choices) - assert num_channels > 0, \ - f'Sampled number of channels in `Mutable` {self.name}' \ - f' should be a positive integer.' - return num_channels - - @property - def min_choice(self) -> int: - """Minimum number of channels.""" - assert len(self.concat_parent_mutables) == 0 - min_channels = min(self.choices) - assert min_channels > 0, \ - f'Minimum number of channels in `Mutable` {self.name}' \ - f' should be a positive integer.' - return min_channels - - @property - def max_choice(self) -> int: - """Maximum number of channels.""" - return max(self.choices) - - @property - def current_choice(self): - """The current choice of the mutable.""" - assert len(self.concat_parent_mutables) == 0 - return self._current_choice - - @current_choice.setter - def current_choice(self, choice: int): - """Set the current choice of the mutable.""" - assert choice in self.choices - self._current_choice = choice - - @property - def choices(self) -> List: - """list: all choices. 
""" - if self._candidate_mode == 'number': - return self._candidate_choices - candidate_choices = [ - round(ratio * self.num_channels) - for ratio in self._candidate_choices - ] - return candidate_choices - - @property - def num_choices(self) -> int: - return len(self.choices) - - def convert_choice_to_mask(self, choice: int) -> torch.Tensor: - """Get the mask according to the input choice.""" - num_channels = choice - mask = torch.zeros(self.num_channels).bool() - mask[:num_channels] = True - return mask - - def dump_chosen(self) -> Dict: - assert self.current_choice is not None - - return dict( - current_choice=self.current_choice, - origin_channels=self.num_channels) - - def fix_chosen(self, dumped_chosen: Dict) -> None: - if self.is_fixed: - raise RuntimeError('OneShotMutableChannel can not be fixed twice') - - current_choice = dumped_chosen['current_choice'] - origin_channels = dumped_chosen['origin_channels'] - - assert current_choice <= origin_channels - assert origin_channels == self.num_channels - - self.current_choice = current_choice - self.is_fixed = True - - def __repr__(self): - concat_mutable_name = [ - mutable.name for mutable in self.concat_parent_mutables - ] - repr_str = self.__class__.__name__ - repr_str += f'(name={self.name}, ' - repr_str += f'num_channels={self.num_channels}, ' - repr_str += f'current_choice={self.current_choice}, ' - repr_str += f'choices={self.choices}, ' - repr_str += f'activated_channels={self.current_mask.sum().item()}, ' - repr_str += f'concat_mutable_name={concat_mutable_name})' - return repr_str - - def __rmul__(self, other) -> DerivedMutable: - return self * other - - def __mul__(self, other) -> DerivedMutable: - if isinstance(other, int): - return self.derive_expand_mutable(other) - - from ..mutable_value import OneShotMutableValue - - def expand_choice_fn(mutable1: 'OneShotMutableChannel', - mutable2: OneShotMutableValue) -> Callable: - - def fn(): - return mutable1.current_choice * mutable2.current_choice - - return fn - - def expand_mask_fn(mutable1: 'OneShotMutableChannel', - mutable2: OneShotMutableValue) -> Callable: - - def fn(): - mask = mutable1.current_mask - max_expand_ratio = mutable2.max_choice - current_expand_ratio = mutable2.current_choice - expand_num_channels = mask.size(0) * max_expand_ratio - - expand_choice = mutable1.current_choice * current_expand_ratio - expand_mask = torch.zeros(expand_num_channels).bool() - expand_mask[:expand_choice] = True - - return expand_mask - - return fn - - if isinstance(other, OneShotMutableValue): - return DerivedMutable( - choice_fn=expand_choice_fn(self, other), - mask_fn=expand_mask_fn(self, other)) - - raise TypeError(f'Unsupported type {type(other)} for mul!') - - def __floordiv__(self, other) -> DerivedMutable: - if isinstance(other, int): - return self.derive_divide_mutable(other) - if isinstance(other, tuple): - assert len(other) == 2 - return self.derive_divide_mutable(*other) - - raise TypeError(f'Unsupported type {type(other)} for div!') diff --git a/mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py b/mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py new file mode 100644 index 000000000..61d36fd18 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Union + +from .sequential_mutable_channel import SquentialMutableChannel + + +class OneShotMutableChannel(SquentialMutableChannel): + """OneShotMutableChannel is a subclass of SquentialMutableChannel. The + difference is that a OneShotMutableChannel limits the candidates of the + choice. + + Args: + num_channels (int): number of channels. + candidate_choices (List[Union[float, int]], optional): A list of + candidate width ratios. Each candidate indicates how many + channels are to be reserved. Defaults to []. + choice_mode (str, optional): Mode of choices. Defaults to 'number'. + """ + + def __init__(self, + num_channels: int, + candidate_choices: List[Union[float, int]] = [], + choice_mode='number', + **kwargs): + super().__init__(num_channels, choice_mode, **kwargs) + self.candidate_choices = list(candidate_choices) + if self.candidate_choices == []: + self.candidate_choices.append( + num_channels if self.is_num_mode else 1.0) + + @property + def current_choice(self) -> Union[int, float]: + """Get current choice.""" + return super().current_choice + + @current_choice.setter + def current_choice(self, choice: Union[int, float]): + """Set current choice.""" + assert choice in self.candidate_choices + SquentialMutableChannel.current_choice.fset( # type: ignore + self, # type: ignore + choice) # type: ignore diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py new file mode 100644 index 000000000..eae559d41 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Union + +import torch + +from mmrazor.registry import MODELS +from ..derived_mutable import DerivedMutable +from .simple_mutable_channel import SimpleMutableChannel + +# TODO discuss later + + +@MODELS.register_module() +class SquentialMutableChannel(SimpleMutableChannel): +    """SquentialMutableChannel defines a BaseMutableChannel which switches off + the channel mask from right to left sequentially, like '11111000'. + + A choice of SquentialMutableChannel is an integer, which indicates how many + channels are activated from left to right. + + Args: + num_channels (int): number of channels.
+ """ + + def __init__(self, num_channels: int, choice_mode='number', **kwargs): + + super().__init__(num_channels, **kwargs) + assert choice_mode in ['ratio', 'number'] + self.choice_mode = choice_mode + self.mask = torch.ones([self.num_channels]).bool() + + @property + def is_num_mode(self): + """Get if the choice is number mode.""" + return self.choice_mode == 'number' + + @property + def current_choice(self) -> Union[int, float]: + """Get current choice.""" + int_choice = (self.mask == 1).sum().item() + if self.is_num_mode: + return int_choice + else: + return self._num2ratio(int_choice) + + @current_choice.setter + def current_choice(self, choice: Union[int, float]): + """Set choice.""" + if isinstance(choice, float): + int_choice = self._ratio2num(choice) + else: + int_choice = choice + mask = torch.zeros([self.num_channels], device=self.mask.device) + mask[0:int_choice] = 1 + self.mask = mask.bool() + + @property + def current_mask(self) -> torch.Tensor: + """Return current mask.""" + return self.mask + + # methods for + + def fix_chosen(self, chosen=...): + """Fix chosen.""" + if chosen is ...: + chosen = self.current_choice + assert self.is_fixed is False + self.current_choice = chosen + self.is_fixed = True + + def dump_chosen(self): + """Dump chosen.""" + return self.current_choice + + def __rmul__(self, other) -> DerivedMutable: + return self * other + + def __mul__(self, other) -> DerivedMutable: + if isinstance(other, int) or isinstance(other, float): + return self.derive_expand_mutable(other) + + from ..mutable_value import OneShotMutableValue + + def expand_choice_fn(mutable1: 'SquentialMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + return mutable1.current_choice * mutable2.current_choice + + return fn + + def expand_mask_fn(mutable1: 'SquentialMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + mask = mutable1.current_mask + max_expand_ratio = mutable2.max_choice + current_expand_ratio = mutable2.current_choice + expand_num_channels = mask.size(0) * max_expand_ratio + + expand_choice = mutable1.current_choice * current_expand_ratio + expand_mask = torch.zeros(expand_num_channels).bool() + expand_mask[:expand_choice] = True + + return expand_mask + + return fn + + if isinstance(other, OneShotMutableValue): + return DerivedMutable( + choice_fn=expand_choice_fn(self, other), + mask_fn=expand_mask_fn(self, other)) + + raise TypeError(f'Unsupported type {type(other)} for mul!') + + def __floordiv__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_divide_mutable(other) + if isinstance(other, tuple): + assert len(other) == 2 + return self.derive_divide_mutable(*other) + + raise TypeError(f'Unsupported type {type(other)} for div!') + + def _num2ratio(self, choice: Union[int, float]) -> float: + """Convert the a number choice to a ratio choice.""" + if isinstance(choice, float): + return choice + else: + return choice / self.num_channels + + def _ratio2num(self, choice: Union[int, float]) -> int: + """Convert the a ratio choice to a number choice.""" + if isinstance(choice, int): + return choice + else: + return max(1, int(self.num_channels * choice)) diff --git a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py new file mode 100644 index 000000000..7f949890c --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +import torch + +from mmrazor.registry import MODELS +from ..derived_mutable import DerivedMutable +from .base_mutable_channel import BaseMutableChannel + + +@MODELS.register_module() +class SimpleMutableChannel(BaseMutableChannel): + """SimpleMutableChannel is a simple BaseMutableChannel, it directly take a + mask as a choice. + + Args: + num_channels (int): number of channels. + """ + + def __init__(self, num_channels: int, **kwargs) -> None: + super().__init__(num_channels, **kwargs) + self.mask = torch.ones(num_channels).bool() + + # choice + + @property + def current_choice(self) -> torch.Tensor: + """Get current choice.""" + return self.mask.bool() + + @current_choice.setter + def current_choice(self, choice: torch.Tensor): + """Set current choice.""" + self.mask = choice.to(self.mask.device).bool() + + @property + def current_mask(self) -> torch.Tensor: + """Get current mask.""" + return self.current_choice.bool() + + # basic extension + + def expand_mutable_channel(self, expand_ratio: int) -> DerivedMutable: + """Get a derived SimpleMutableChannel with expanded mask.""" + + def _expand_mask(): + mask = self.current_mask + mask = torch.unsqueeze( + mask, -1).expand(list(mask.shape) + [expand_ratio]).flatten(-2) + return mask + + return DerivedMutable(_expand_mask, _expand_mask, [self]) diff --git a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py deleted file mode 100644 index 5d4dec0e7..000000000 --- a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional - -import torch - -from mmrazor.registry import MODELS -from .mutable_channel import MutableChannel - - -@MODELS.register_module() -class SlimmableMutableChannel(MutableChannel[int, Dict[str, int]]): - """A type of ``MUTABLES`` to train several subnet together, such as the - retraining stage in AutoSlim. - - Notes: - We need to set `candidate_choices` after the instantiation of a - `SlimmableMutableChannel` by ourselves. - - Args: - num_channels (int): The raw number of channels. - init_cfg (dict, optional): initialization configuration dict for - ``BaseModule``. OpenMMLab has implement 5 initializer including - `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, - and `Pretrained`. - """ - - def __init__(self, num_channels: int, init_cfg: Optional[Dict] = None): - super(SlimmableMutableChannel, self).__init__( - num_channels=num_channels, init_cfg=init_cfg) - - self.num_channels = num_channels - - @property - def candidate_choices(self) -> List: - """A list of candidate channel numbers.""" - return self._candidate_choices - - @candidate_choices.setter - def candidate_choices(self, choices): - """Set the candidate channel numbers.""" - assert getattr(self, '_candidate_choices', None) is None, \ - f'candidate_choices can be set only when candidate_choices is ' \ - f'None, got: candidate_choices = {self._candidate_choices}' - - assert all([num > 0 and num <= self.num_channels - for num in choices]), \ - f'The candidate channel numbers should be in ' \ - f'range(0, {self.num_channels}].' - assert all([isinstance(num, int) for num in choices]), \ - 'Type of `candidate_choices` should be int.' 
- - self._candidate_choices = list(choices) - - @property - def choices(self) -> List[int]: - """Return all subnet indexes.""" - assert self._candidate_choices is not None - return list(range(len(self.candidate_choices))) - - def dump_chosen(self) -> Dict: - assert self.current_choice is not None - - return dict( - current_choice=self._candidate_choices[self.current_choice], - origin_channels=self.num_channels) - - def fix_chosen(self, dumped_chosen: Dict) -> None: - chosen = dumped_chosen['current_choice'] - origin_channels = dumped_chosen['origin_channels'] - - assert chosen <= origin_channels - - # TODO - # remove after remove `current_choice` - self.current_choice = self.candidate_choices.index(chosen) - self._chosen = chosen - - super().fix_chosen(chosen) - - @property - def num_choices(self) -> int: - return len(self.choices) - - def convert_choice_to_mask(self, choice: int) -> torch.Tensor: - """Get the mask according to the input choice.""" - if self.is_fixed: - num_channels = self._chosen - elif not hasattr(self, '_candidate_choices'): - # todo: we trace the supernet before set_candidate_choices. - # It's hacky - num_channels = self.num_channels - else: - num_channels = self.candidate_choices[choice] - mask = torch.zeros(self.num_channels).bool() - mask[:num_channels] = True - return mask diff --git a/mmrazor/models/mutables/mutable_channel/units/__init__.py b/mmrazor/models/mutables/mutable_channel/units/__init__.py new file mode 100644 index 000000000..a61816718 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/units/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .l1_mutable_channel_unit import L1MutableChannelUnit +from .mutable_channel_unit import ChannelUnitType, MutableChannelUnit +from .one_shot_mutable_channel_unit import OneShotMutableChannelUnit +from .sequential_mutable_channel_unit import SequentialMutableChannelUnit +from .slimmable_channel_unit import SlimmableChannelUnit + +__all__ = [ + 'L1MutableChannelUnit', 'MutableChannelUnit', + 'SequentialMutableChannelUnit', 'OneShotMutableChannelUnit', + 'SlimmableChannelUnit', 'ChannelUnitType' +] diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py new file mode 100644 index 000000000..576412ec0 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py @@ -0,0 +1,287 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List + +import torch.nn as nn +from mmengine.model import BaseModule + +from mmrazor.structures.graph import ModuleGraph +from mmrazor.structures.graph.channel_graph import ChannelGraph +from mmrazor.structures.graph.channel_modules import (BaseChannel, + BaseChannelUnit) +from mmrazor.structures.graph.channel_nodes import \ + default_channel_node_converter + + +class Channel(BaseModule): + """Channel records information about channels for pruning. + + Args: + name (str): The name of the channel. When the channel is related with + a module, the name should be the name of the module in the model. + module (Any): Module of the channel. + index (Tuple[int,int]): Index(start,end) of the Channel in the Module + node (ChannelNode, optional): A ChannelNode corresponding to the + Channel. Defaults to None. + is_output_channel (bool, optional): Is the channel output channel. + Defaults to True. + expand_ratio (int, optional): Expand ratio of the mask. Defaults to 1. 
+ """ + + # init + + def __init__(self, + name, + module, + index, + node=None, + is_output_channel=True, + expand_ratio=1) -> None: + super().__init__() + self.name = name + self.module = module + self.index = index + self.start = index[0] + self.end = index[1] + + self.node = node + + self.is_output_channel = is_output_channel + self.expand_ratio = expand_ratio + + @classmethod + def init_from_cfg(cls, model: nn.Module, config: Dict): + """init a Channel using a config which can be generated by + self.config_template()""" + name = config['name'] + start = config['start'] + end = config['end'] + expand_ratio = config['expand_ratio'] \ + if 'expand_ratio' in config else 1 + is_output_channel = config['is_output_channel'] + + name2module = dict(model.named_modules()) + name2module.pop('') + module = name2module[name] if name in name2module else None + return Channel( + name, + module, (start, end), + is_output_channel=is_output_channel, + expand_ratio=expand_ratio) + + @classmethod + def init_from_base_channel(cls, base_channel: BaseChannel): + """Init from a BaseChannel object.""" + return cls( + base_channel.name, + base_channel.module, + base_channel.index, + node=None, + is_output_channel=base_channel.is_output_channel, + expand_ratio=base_channel.expand_ratio) + + # config template + + def config_template(self): + """Generate a config template which can be used to initialize a Channel + by cls.init_from_cfg(**kwargs)""" + return { + 'name': self.name, + 'start': self.start, + 'end': self.end, + 'expand_ratio': self.expand_ratio, + 'is_output_channel': self.is_output_channel + } + + # basic properties + + @property + def num_channels(self) -> int: + """The number of channels in the Channel.""" + return self.index[1] - self.index[0] + + @property + def is_mutable(self) -> bool: + """If the channel is prunable.""" + if isinstance(self.module, nn.Conv2d): + # group-wise conv + if self.module.groups != 1 and not (self.module.groups == + self.module.in_channels == + self.module.out_channels): + return False + return True + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'{self.name}, index={self.index}, ' + f'is_output_channel=' + f'{"true" if self.is_output_channel else "false"}, ' + f'expand_ratio={self.expand_ratio}' + ')') + + def __eq__(self, obj: object) -> bool: + if isinstance(obj, BaseChannel): + return self.name == obj.name \ + and self.module == obj.module \ + and self.index == obj.index \ + and self.is_output_channel == obj.is_output_channel \ + and self.expand_ratio == obj.expand_ratio \ + and self.node == obj.node + else: + return False + + +# Channel && ChannelUnit + + +class ChannelUnit(BaseModule): + """A unit of Channels. + + A ChannelUnit has two list, input_related and output_related, to store + the Channels. These Channels are dependent on each other, and have to + have the same number of activated number of channels. + + Args: + num_channels (int): the number of channels of Channel object. 
+ """ + + # init methods + + def __init__(self, num_channels: int, **kwargs): + super().__init__() + self.num_channels = num_channels + self.output_related: nn.ModuleList = nn.ModuleList() + self.input_related: nn.ModuleList = nn.ModuleList() + self.init_args: Dict = { + } # is used to generate new channel unit with same args + + @classmethod + def init_from_cfg(cls, model: nn.Module, config: Dict) -> 'ChannelUnit': + """init a ChannelUnit using a config which can be generated by + self.config_template()""" + + def auto_fill_channel_config(channel_config: Dict, + is_output_channel: bool, + unit_config: Dict = config): + """Fill channel config with default values.""" + if 'start' not in channel_config: + channel_config['start'] = 0 + if 'end' not in channel_config: + channel_config['end'] = unit_config['init_args'][ + 'num_channels'] + channel_config['is_output_channel'] = is_output_channel + + config = copy.deepcopy(config) + if 'channels' in config: + channels = config.pop('channels') + else: + channels = None + unit = cls(**(config['init_args'])) + if channels is not None: + for channel_config in channels['input_related']: + auto_fill_channel_config(channel_config, False) + unit.add_input_related( + Channel.init_from_cfg(model, channel_config)) + for channel_config in channels['output_related']: + auto_fill_channel_config(channel_config, True) + unit.add_ouptut_related( + Channel.init_from_cfg(model, channel_config)) + return unit + + @classmethod + def init_from_channel_unit(cls, + unit: 'ChannelUnit', + args: Dict = {}) -> 'ChannelUnit': + """Initial a object of current class from a ChannelUnit object.""" + args['num_channels'] = unit.num_channels + mutable_unit = cls(**args) + mutable_unit.input_related = unit.input_related + mutable_unit.output_related = unit.output_related + return mutable_unit + + @classmethod + def init_from_graph(cls, + graph: ModuleGraph, + unit_args={}, + num_input_channel=3) -> List['ChannelUnit']: + """Parse a module-graph and get ChannelUnits.""" + + def init_from_base_channel_unit(base_channel_unit: BaseChannelUnit): + unit = cls(len(base_channel_unit.channel_elems), **unit_args) + unit.input_related = nn.ModuleList([ + Channel.init_from_base_channel(channel) + for channel in base_channel_unit.input_related + ]) + unit.output_related = nn.ModuleList([ + Channel.init_from_base_channel(channel) + for channel in base_channel_unit.output_related + ]) + return unit + + unit_graph = ChannelGraph.copy_from(graph, + default_channel_node_converter) + unit_graph.forward(num_input_channel) + units = unit_graph.collect_units() + units = [init_from_base_channel_unit(unit) for unit in units] + return units + + # tools + + @property + def name(self) -> str: + """str: name of the unit""" + if len(self.output_related) + len(self.input_related) > 0: + first_module = (list(self.output_related) + + list(self.input_related))[0] + first_module_name = f'{first_module.name}_{first_module.index}' + else: + first_module_name = 'unitx' + name = f'{first_module_name}_{self.num_channels}' + return name + + def config_template(self, + with_init_args=False, + with_channels=False) -> Dict: + """Generate a config template which can be used to initialize a + ChannelUnit by cls.init_from_cfg(**kwargs)""" + config = {} + if with_init_args: + config['init_args'] = {'num_channels': self.num_channels} + if with_channels: + config['channels'] = self._channel_dict() + return config + + # node operations + + def add_ouptut_related(self, channel: Channel): + """Add a Channel which is output 
related."""
+        assert channel.is_output_channel
+        assert self.num_channels == channel.num_channels
+        if channel not in self.output_related:
+            self.output_related.append(channel)
+
+    def add_input_related(self, channel: Channel):
+        """Add a Channel which is input related."""
+        assert channel.is_output_channel is False
+        assert self.num_channels == channel.num_channels
+        if channel not in self.input_related:
+            self.input_related.append(channel)
+
+    # others
+
+    def extra_repr(self) -> str:
+        s = super().extra_repr()
+        s += f'name={self.name}'
+        return s
+
+    # private methods
+
+    def _channel_dict(self) -> Dict:
+        """Return channel config."""
+        info = {
+            'input_related':
+            [channel.config_template() for channel in self.input_related],
+            'output_related':
+            [channel.config_template() for channel in self.output_related],
+        }
+        return info
diff --git a/mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py
new file mode 100644
index 000000000..8b3c258ad
--- /dev/null
+++ b/mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from mmrazor.registry import MODELS
+from ..simple_mutable_channel import SimpleMutableChannel
+from .sequential_mutable_channel_unit import SequentialMutableChannelUnit
+
+
+@MODELS.register_module()
+class L1MutableChannelUnit(SequentialMutableChannelUnit):
+    """Implementation of the L1-norm pruning algorithm. It computes the
+    l1-norm of modules and preferentially prunes modules with smaller l1-norm.
+
+    Please refer to the paper `https://arxiv.org/pdf/1608.08710.pdf` for more
+    detail.
+    """
+
+    def __init__(self,
+                 num_channels: int,
+                 choice_mode='number',
+                 divisor=1,
+                 min_value=1,
+                 min_ratio=0.9) -> None:
+        super().__init__(num_channels, choice_mode, divisor, min_value,
+                         min_ratio)
+        self.mutable_channel = SimpleMutableChannel(num_channels)
+
+    # choices
+
+    @property
+    def current_choice(self) -> Union[int, float]:
+        num = self.mutable_channel.activated_channels
+        if self.is_num_mode:
+            return num
+        else:
+            return self._num2ratio(num)
+
+    @current_choice.setter
+    def current_choice(self, choice: Union[int, float]):
+        int_choice = self._get_valid_int_choice(choice)
+        mask = self._generate_mask(int_choice).bool()
+        self.mutable_channel.current_choice = mask
+
+    # private methods
+
+    def _generate_mask(self, choice: int) -> torch.Tensor:
+        """Generate mask using choice."""
+        norm = self._get_unit_norm()
+        idx = norm.topk(choice)[1]
+        mask = torch.zeros([self.num_channels]).to(idx.device)
+        mask.scatter_(0, idx, 1)
+        return mask
+
+    def _get_l1_norm(self, module: Union[nn.modules.conv._ConvNd, nn.Linear],
+                     start, end):
+        """Get l1-norm of a module."""
+        if isinstance(module, nn.modules.conv._ConvNd):
+            weight = module.weight.flatten(1)  # out_c * in_c * k * k
+        elif isinstance(module, nn.Linear):
+            weight = module.weight  # out_c * in_c
+        weight = weight[start:end]
+        norm = weight.abs().mean(dim=[1])
+        return norm
+
+    def _get_unit_norm(self):
+        """Get the l1-norm of the unit by averaging the l1-norm of the
+        modules in the unit."""
+        avg_norm = 0
+        module_num = 0
+        for channel in self.output_related:
+            if isinstance(channel.module,
+                          nn.modules.conv._ConvNd) or isinstance(
+                              channel.module, nn.Linear):
+                norm = self._get_l1_norm(channel.module, channel.start,
+                                         channel.end)
+                avg_norm += norm
+                module_num += 1
+        avg_norm = avg_norm / module_num
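+        # avg_norm now holds the mean per-channel l1-norm over all counted
+        # conv/linear modules; larger values mean more important channels.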
+ return avg_norm diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py new file mode 100644 index 000000000..59039cd83 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py @@ -0,0 +1,300 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This module defines MutableChannelUnit.""" +import abc +from collections import Set +from typing import Dict, List, Type, TypeVar + +import torch.nn as nn + +from mmrazor.models.architectures import dynamic_ops +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin +from mmrazor.models.mutables import DerivedMutable +from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel, + MutableChannelContainer) +from .channel_unit import Channel, ChannelUnit + + +class MutableChannelUnit(ChannelUnit): + + # init methods + def __init__(self, num_channels: int, **kwargs) -> None: + """MutableChannelUnit inherits from ChannelUnit, which manages channels + with channel-dependency. + + Compared with ChannelUnit, MutableChannelUnit defines the core + interfaces for pruning. By inheriting MutableChannelUnit, + we can implement a variant pruning and nas algorithm. + + These apis includes + - basic property + - name + - is_mutable + - before pruning + - prepare_for_pruning + - pruning stage + - current_choice + - sample_choice + - after pruning + - fix_chosen + + Args: + num_channels (int): dimension of the channels of the Channel + objects in the unit. + """ + + super().__init__(num_channels) + + @classmethod + def init_from_mutable_channel(cls, mutable_channel: BaseMutableChannel): + unit = cls(mutable_channel.num_channels) + return unit + + @classmethod + def init_from_predefined_model(cls, model: nn.Module): + """Initialize units using the model with pre-defined dynamicops and + mutable-channels.""" + + def process_container(contanier: MutableChannelContainer, + module, + module_name, + mutable2units, + is_output=True): + for index, mutable in contanier.mutable_channels.items(): + if isinstance(mutable, DerivedMutable): + source_mutables: Set = \ + mutable._trace_source_mutables() + source_channel_mutables = [ + mutable for mutable in source_mutables + if isinstance(mutable, BaseMutableChannel) + ] + assert len(source_channel_mutables) == 1, ( + 'only support one mutable channel ' + 'used in DerivedMutable') + mutable = list(source_channel_mutables)[0] + + if mutable not in mutable2units: + mutable2units[mutable] = cls.init_from_mutable_channel( + mutable) + + unit: MutableChannelUnit = mutable2units[mutable] + if is_output: + unit.add_ouptut_related( + Channel( + module_name, + module, + index, + is_output_channel=is_output)) + else: + unit.add_input_related( + Channel( + module_name, + module, + index, + is_output_channel=is_output)) + + mutable2units: Dict = {} + for name, module in model.named_modules(): + if isinstance(module, DynamicChannelMixin): + in_container: MutableChannelContainer = \ + module.get_mutable_attr( + 'in_channels') + out_container: MutableChannelContainer = \ + module.get_mutable_attr( + 'out_channels') + process_container(in_container, module, name, mutable2units, + False) + process_container(out_container, module, name, mutable2units, + True) + units = list(mutable2units.values()) + return units + + # properties + + @property + def is_mutable(self) -> bool: + """If the channel-unit is prunable.""" + + def traverse(channels: List[Channel]): + has_dynamic_op = False 
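+            # has_dynamic_op: at least one module in the list is a dynamic op.
+            # all_channel_prunable: every channel in the list is mutable.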
+ all_channel_prunable = True + for channel in channels: + if channel.is_mutable is False: + all_channel_prunable = False + break + if isinstance(channel.module, dynamic_ops.DynamicChannelMixin): + has_dynamic_op = True + return has_dynamic_op, all_channel_prunable + + input_has_dynamic_op, input_all_prunable = traverse(self.input_related) + output_has_dynamic_op, output_all_prunable = traverse( + self.output_related) + + return len(self.output_related) > 0 \ + and len(self.input_related) > 0 \ + and input_has_dynamic_op \ + and input_all_prunable \ + and output_has_dynamic_op \ + and output_all_prunable + + def config_template(self, + with_init_args=False, + with_channels=False) -> Dict: + """Return the config template of this unit. By default, the config + template only includes a key 'choice'. + + Args: + with_init_args (bool): if the config includes args for + initialization. + with_channels (bool): if the config includes info about + channels. the config with info about channels can used to + parse channel units without tracer. + """ + config = super().config_template(with_init_args, with_channels) + config['choice'] = self.current_choice + return config + + # before pruning: prepare a model + + @abc.abstractmethod + def prepare_for_pruning(self, model): + """Post process after parse units. + + For example, we need to register mutables to dynamic-ops. + """ + raise NotImplementedError + + # pruning: choice-related + + @property + def current_choice(self): + """Choice of this unit.""" + raise NotImplementedError() + + @current_choice.setter + def current_choice(self, choice) -> None: + """setter of current_choice.""" + raise NotImplementedError() + + @abc.abstractmethod + def sample_choice(self): + """Randomly sample a valid choice and return.""" + raise NotImplementedError() + + # after pruning + + def fix_chosen(self, choice=None): + """Make the channels in this unit fixed.""" + if choice is not None: + self.current_choice = choice + + # private methods + + def _replace_with_dynamic_ops( + self, model: nn.Module, + dynamicop_map: Dict[Type[nn.Module], Type[DynamicChannelMixin]]): + """Replace torch modules with dynamic-ops.""" + + def replace_op(model: nn.Module, name: str, module: nn.Module): + names = name.split('.') + for sub_name in names[:-1]: + model = getattr(model, sub_name) + + setattr(model, names[-1], module) + + def get_module(model, name): + names = name.split('.') + for sub_name in names: + model = getattr(model, sub_name) + return model + + for channel in list(self.input_related) + list(self.output_related): + if isinstance(channel.module, nn.Module): + module = get_module(model, channel.name) + if type(module) in dynamicop_map: + new_module = dynamicop_map[type(module)].convert_from( + module) + replace_op(model, channel.name, new_module) + channel.module = new_module + else: + channel.module = module + + @staticmethod + def _register_channel_container( + model: nn.Module, container_class: Type[MutableChannelContainer]): + """register channel container for dynamic ops.""" + for module in model.modules(): + if isinstance(module, dynamic_ops.DynamicChannelMixin): + if module.get_mutable_attr('in_channels') is None: + in_channels = 0 + if isinstance(module, nn.Conv2d): + in_channels = module.in_channels + elif isinstance(module, nn.modules.batchnorm._BatchNorm): + in_channels = module.num_features + elif isinstance(module, nn.Linear): + in_channels = module.in_features + else: + raise NotImplementedError() + module.register_mutable_attr('in_channels', + 
container_class(in_channels)) + if module.get_mutable_attr('out_channels') is None: + out_channels = 0 + if isinstance(module, nn.Conv2d): + out_channels = module.out_channels + elif isinstance(module, nn.modules.batchnorm._BatchNorm): + out_channels = module.num_features + elif isinstance(module, nn.Linear): + out_channels = module.out_features + else: + raise NotImplementedError() + module.register_mutable_attr('out_channels', + container_class(out_channels)) + + def _register_mutable_channel(self, mutable_channel: BaseMutableChannel): + # register mutable_channel + for channel in list(self.input_related) + list(self.output_related): + module = channel.module + if isinstance(module, dynamic_ops.DynamicChannelMixin): + container: MutableChannelContainer + if channel.is_output_channel and module.get_mutable_attr( + 'out_channels') is not None: + container = module.get_mutable_attr('out_channels') + elif channel.is_output_channel is False \ + and module.get_mutable_attr('in_channels') is not None: + container = module.get_mutable_attr('in_channels') + else: + raise NotImplementedError() + + if channel.num_channels == self.num_channels: + mutable_channel_ = mutable_channel + start = channel.start + end = channel.end + elif channel.num_channels > self.num_channels: + if channel.num_channels % self.num_channels == 0: + mutable_channel_ = \ + mutable_channel.expand_mutable_channel( + channel.num_channels // self.num_channels) + start = channel.start + end = channel.end + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + if (start, end) in container.mutable_channels: + existed = container.mutable_channels[(start, end)] + if not isinstance(existed, DerivedMutable): + assert mutable_channel is existed + else: + source_mutables = list( + existed._trace_source_mutables()) + is_same = [ + mutable_channel is mutable + for mutable in source_mutables + ] + assert any(is_same) + + else: + container.register_mutable(mutable_channel_, start, end) + + +ChannelUnitType = TypeVar('ChannelUnitType', bound=MutableChannelUnit) diff --git a/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py new file mode 100644 index 000000000..235978cfa --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import random +from typing import Dict, List, Union + +import torch.nn as nn + +from mmrazor.registry import MODELS +from ..oneshot_mutalbe_channel import OneShotMutableChannel +from .sequential_mutable_channel_unit import SequentialMutableChannelUnit + + +@MODELS.register_module() +class OneShotMutableChannelUnit(SequentialMutableChannelUnit): + """OneShotMutableChannelUnit is for single path supernet such as AutoSlim. + In single path supernet, each module only has one choice invoked at the + same time. A path is obtained by sampling all the available choices. It is + the base class for one shot mutable channel. + + Args: + num_channels (_type_): The raw number of channels. + candidate_choices (List[Union[int, float]], optional): + A list of candidate width ratios. Each + candidate indicates how many channels to be reserved. + Defaults to [0.5, 1.0](choice_mode='ratio'). + choice_mode (str, optional): Mode of candidates. + One of "ratio" or "number". Defaults to 'ratio'. + divisor (int): Used to make choice divisible. 
+ min_value (int): the minimal value used when make divisible. + min_ratio (float): the minimal ratio used when make divisible. + """ + + def __init__(self, + num_channels: int, + candidate_choices: List[Union[int, float]] = [0.5, 1.0], + choice_mode='ratio', + divisor=1, + min_value=1, + min_ratio=0.9) -> None: + super().__init__(num_channels, choice_mode, divisor, min_value, + min_ratio) + candidate_choices = copy.copy(candidate_choices) + if candidate_choices == []: + candidate_choices.append( + self.num_channels if self.is_num_mode else 1.0) + self.candidate_choices = self._prepare_candidate_choices( + candidate_choices, choice_mode) + + self.mutable_channel = OneShotMutableChannel(num_channels, + self.candidate_choices, + choice_mode) + + @classmethod + def init_from_mutable_channel(cls, mutable_channel: OneShotMutableChannel): + unit = cls(mutable_channel.num_channels, + mutable_channel.candidate_choices, + mutable_channel.choice_mode) + mutable_channel.candidate_choices = unit.candidate_choices + unit.mutable_channel = mutable_channel + return unit + + def prepare_for_pruning(self, model: nn.Module): + """Prepare for pruning.""" + super().prepare_for_pruning(model) + self.current_choice = self.max_choice + + # ~ + + def config_template(self, + with_init_args=False, + with_channels=False) -> Dict: + """Config template of the OneShotMutableChannelUnit.""" + config = super().config_template(with_init_args, with_channels) + if with_init_args: + init_cfg = config['init_args'] + init_cfg.pop('choice_mode') + init_cfg.update({ + 'candidate_choices': self.candidate_choices, + 'choice_mode': self.choice_mode + }) + return config + + # choice + + @property + def current_choice(self) -> Union[int, float]: + """Get current choice.""" + return super().current_choice + + @current_choice.setter + def current_choice(self, choice: Union[int, float]): + """Set current choice.""" + assert choice in self.candidate_choices + int_choice = self._get_valid_int_choice(choice) + choice_ = int_choice if self.is_num_mode else self._num2ratio( + int_choice) + self.mutable_channel.current_choice = choice_ + + def sample_choice(self) -> Union[int, float]: + """Sample a valid choice.""" + rand_idx = random.randint(0, len(self.candidate_choices) - 1) + return self.candidate_choices[rand_idx] + + @property + def min_choice(self) -> Union[int, float]: + """Get Minimal choice.""" + return self.candidate_choices[0] + + @property + def max_choice(self) -> Union[int, float]: + """Get Maximal choice.""" + return self.candidate_choices[-1] + + # private methods + + def _prepare_candidate_choices(self, candidate_choices: List, + choice_mode) -> List: + """Process candidate_choices.""" + choice_type = int if choice_mode == 'number' else float + for choice in candidate_choices: + assert isinstance(choice, choice_type) + if self.is_num_mode: + candidate_choices_ = [ + self._make_divisible(choice) for choice in candidate_choices + ] + else: + candidate_choices_ = [ + self._num2ratio(self._make_divisible(self._ratio2num(choice))) + for choice in candidate_choices + ] + if candidate_choices_ != candidate_choices: + self._make_divisible_info(candidate_choices, candidate_choices_) + + candidate_choices_ = sorted(candidate_choices_) + return candidate_choices_ diff --git a/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py new file mode 100644 index 000000000..89a25d236 --- /dev/null +++ 
b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import Dict, Union
+
+import torch.nn as nn
+from mmengine import MMLogger
+
+from mmrazor.models.architectures import dynamic_ops
+from mmrazor.models.utils import make_divisible
+from mmrazor.registry import MODELS
+from ..mutable_channel_container import MutableChannelContainer
+from ..sequential_mutable_channel import SquentialMutableChannel
+from .mutable_channel_unit import MutableChannelUnit
+
+
+# TODO change the name of SequentialMutableChannelUnit
+@MODELS.register_module()
+class SequentialMutableChannelUnit(MutableChannelUnit):
+    """SequentialMutableChannelUnit accepts an integer (number) or a float
+    (ratio) as its choice, which indicates how many of the channels are
+    retained from left to right, like 11110000.
+
+    Args:
+        num_channels (int): number of channels.
+        choice_mode (str): mode of choice, which is one of 'number' or 'ratio'.
+        divisor (int): Used to make choice divisible.
+        min_value (int): the minimal value used when making divisible.
+        min_ratio (float): the minimal ratio used when making divisible.
+    """
+
+    def __init__(
+            self,
+            num_channels: int,
+            choice_mode='number',
+            # args for make divisible
+            divisor=1,
+            min_value=1,
+            min_ratio=0.9) -> None:
+        super().__init__(num_channels)
+        assert choice_mode in ['ratio', 'number']
+        self.choice_mode = choice_mode
+
+        self.mutable_channel: SquentialMutableChannel = \
+            SquentialMutableChannel(num_channels, choice_mode=choice_mode)
+
+        # for make_divisible
+        self.divisor = divisor
+        self.min_value = min_value
+        self.min_ratio = min_ratio
+
+    @classmethod
+    def init_from_mutable_channel(cls,
+                                  mutable_channel: SquentialMutableChannel):
+        unit = cls(mutable_channel.num_channels, mutable_channel.choice_mode)
+        unit.mutable_channel = mutable_channel
+        return unit
+
+    def prepare_for_pruning(self, model: nn.Module):
+        """Prepare for pruning, including registering mutable channels."""
+        # register MutableMask
+        self._replace_with_dynamic_ops(
+            model, {
+                nn.Conv2d: dynamic_ops.DynamicConv2d,
+                nn.BatchNorm2d: dynamic_ops.DynamicBatchNorm2d,
+                nn.Linear: dynamic_ops.DynamicLinear
+            })
+        self._register_channel_container(model, MutableChannelContainer)
+        self._register_mutable_channel(self.mutable_channel)
+
+    # ~
+
+    @property
+    def is_num_mode(self):
+        return self.choice_mode == 'number'
+
+    def fix_chosen(self, choice=None):
+        """Fix the chosen choice."""
+        super().fix_chosen(choice)
+        self.mutable_channel.fix_chosen()
+
+    def config_template(self,
+                        with_init_args=False,
+                        with_channels=False) -> Dict:
+        """Template of config."""
+        config = super().config_template(with_init_args, with_channels)
+        if with_init_args:
+            init_args: Dict = config['init_args']
+            init_args.update(
+                dict(
+                    choice_mode=self.choice_mode,
+                    divisor=self.divisor,
+                    min_value=self.min_value,
+                    min_ratio=self.min_ratio))
+        return config
+
+    # choice
+
+    @property
+    def current_choice(self) -> Union[int, float]:
+        """Return the current choice."""
+        return self.mutable_channel.current_choice
+
+    @current_choice.setter
+    def current_choice(self, choice: Union[int, float]):
+        """Set the current choice."""
+        choice_num_ = self._get_valid_int_choice(choice)
+        self.mutable_channel.current_choice = choice_num_
+
+    def sample_choice(self) -> Union[int, float]:
+        """Sample a random valid choice."""
+        num_choice = random.randint(1, self.num_channels)
+        num_choice = self._make_divisible(num_choice)
+        if self.is_num_mode:
+            return
num_choice + else: + return self._num2ratio(num_choice) + + # private methods + def _get_valid_int_choice(self, choice: Union[float, int]) -> int: + choice_num = self._ratio2num(choice) + choice_num_ = self._make_divisible(choice_num) + if choice_num != choice_num_: + self._make_divisible_info(choice, self.current_choice) + return choice_num_ + + def _make_divisible(self, choice_int: int): + """Make the choice divisible.""" + return make_divisible(choice_int, self.divisor, self.min_value, + self.min_ratio) + + def _num2ratio(self, choice: Union[int, float]) -> float: + """Convert the a number choice to a ratio choice.""" + if isinstance(choice, float): + return choice + else: + return choice / self.num_channels + + def _ratio2num(self, choice: Union[int, float]) -> int: + """Convert the a ratio choice to a number choice.""" + if isinstance(choice, int): + return choice + else: + return max(1, int(self.num_channels * choice)) + + def _make_divisible_info(self, choice, new_choice): + logger = MMLogger.get_current_instance() + logger.info(f'The choice={choice}, which is set to {self.name}, ' + f'is changed to {new_choice} for a divisible choice.') diff --git a/mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py new file mode 100644 index 000000000..a51dce80b --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import List, Union + +import torch.nn as nn + +from mmrazor.models.architectures import dynamic_ops +from mmrazor.registry import MODELS +from ..mutable_channel_container import MutableChannelContainer +from .one_shot_mutable_channel_unit import OneShotMutableChannelUnit + + +@MODELS.register_module() +class SlimmableChannelUnit(OneShotMutableChannelUnit): + """A type of ``MutableChannelUnit`` to train several subnets together. + + Args: + num_channels (int): The raw number of channels. + candidate_choices (List[Union[int, float]], optional): + A list of candidate width ratios. Each + candidate indicates how many channels to be reserved. + Defaults to [0.5, 1.0](choice_mode='ratio'). + choice_mode (str, optional): Mode of candidates. + One of 'ratio' or 'number'. Defaults to 'number'. + divisor (int, optional): Used to make choice divisible. + min_value (int, optional): The minimal value used when make divisible. + min_ratio (float, optional): The minimal ratio used when make + divisible. 
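+
+    Example (an illustrative sketch; the candidate widths below are made-up
+    values rather than ones taken from a config in this repo):
+        >>> unit = SlimmableChannelUnit(64, candidate_choices=[32, 48, 64])
+        >>> unit.candidate_choices
+        [32, 48, 64]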
+ """ + + def __init__(self, + num_channels: int, + candidate_choices: List[Union[int, float]] = [], + choice_mode='number', + divisor=1, + min_value=1, + min_ratio=0.9) -> None: + super().__init__(num_channels, candidate_choices, choice_mode, divisor, + min_value, min_ratio) + + def prepare_for_pruning(self, model: nn.Module): + """Prepare for pruning.""" + self._replace_with_dynamic_ops( + model, { + nn.Conv2d: dynamic_ops.DynamicConv2d, + nn.BatchNorm2d: dynamic_ops.SwitchableBatchNorm2d, + nn.Linear: dynamic_ops.DynamicLinear + }) + self.alter_candidates_of_switchbn(self.candidate_choices) + self._register_channel_container(model, MutableChannelContainer) + self._register_mutable_channel(self.mutable_channel) + + def alter_candidates_of_switchbn(self, candidates: List): + """Change candidates of SwitchableBatchNorm2d.""" + for channel in list(self.output_related) + list(self.input_related): + if isinstance(channel.module, dynamic_ops.SwitchableBatchNorm2d) \ + and len(channel.module.candidate_bn) == 0: + channel.module.init_candidates(candidates) + self.current_choice = self.max_choice diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py index 748d83e78..49a0c870f 100644 --- a/mmrazor/models/mutables/mutable_value/mutable_value.py +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -222,15 +222,15 @@ def __mul__(self, other) -> DerivedMutable: """Overload `*` operator. Args: - other (int, OneShotMutableChannel): Expand ratio or - OneShotMutableChannel. + other (int, SquentialMutableChannel): Expand ratio or + SquentialMutableChannel. Returns: DerivedMutable: Derived expand mutable. """ - from ..mutable_channel import OneShotMutableChannel + from ..mutable_channel import SquentialMutableChannel - if isinstance(other, OneShotMutableChannel): + if isinstance(other, SquentialMutableChannel): return other * self return super().__mul__(other) diff --git a/mmrazor/models/mutators/channel_mutator/__init__.py b/mmrazor/models/mutators/channel_mutator/__init__.py index 0a50e03d5..dc4b1c86d 100644 --- a/mmrazor/models/mutators/channel_mutator/__init__.py +++ b/mmrazor/models/mutators/channel_mutator/__init__.py @@ -4,5 +4,5 @@ from .slimmable_channel_mutator import SlimmableChannelMutator __all__ = [ - 'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator' + 'SlimmableChannelMutator', 'ChannelMutator', 'OneShotChannelMutator' ] diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index b6e2d06f8..c4ce92e96 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -1,261 +1,315 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import copy -from abc import abstractmethod -from typing import Dict, List, Optional +from typing import Dict, Generic, List, Optional, Tuple, Type, Union +from mmengine import fileio from torch.nn import Module -from mmrazor.registry import MODELS, TASK_UTILS -from ...mutables import MutableChannel -from ...task_modules import PathConcatNode, PathDepthWiseConvNode, PathList +from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin +from mmrazor.models.mutables import (ChannelUnitType, MutableChannelUnit, + SequentialMutableChannelUnit) +from mmrazor.models.mutables.mutable_channel.units.channel_unit import \ + ChannelUnit +from mmrazor.registry import MODELS +from mmrazor.structures.graph import ModuleGraph from ..base_mutator import BaseMutator -from ..utils import DEFAULT_MODULE_CONVERTERS + + +def is_dynamic_op_for_fx_tracer(module, name): + return isinstance(module, DynamicChannelMixin) @MODELS.register_module() -class ChannelMutator(BaseMutator): - """Base class for channel-based mutators. +class ChannelMutator(BaseMutator, Generic[ChannelUnitType]): + """ChannelMutator manages the pruning structure of a model. Args: - mutable_cfg (dict): The config for the channel mutable. - tracer_cfg (dict | Optional): The config for the model tracer. - We Trace the topology of a given model with the tracer. - skip_prefixes (List[str] | Optional): The module whose name start with - a string in skip_prefixes will not be pruned. - init_cfg (dict, optional): The config to control the initialization. - - Attributes: - search_groups (Dict[int, List]): Search group of supernet. Note that - the search group of a mutable based channel mutator is composed of - corresponding mutables. Mutables in the same search group should - be pruned together. - name2module (Dict[str, :obj:`torch.nn.Module`]): The mapping from - a module name to the module. - - Notes: - # To avoid ambiguity, we only allow the following two cases: - # 1. None of the parent nodes of a node is a `ConcatNode` - # 2. A node has only one parent node which is a `ConcatNode` + channel_unit_cfg (Union[ dict, Type[MutableChannelUnit]], optional): + The config of ChannelUnits. When the channel_unit_cfg + is a dict, it should follow the template below: + channel_unit_cfg = dict( + # type of used MutableChannelUnit + type ='XxxMutableChannelUnit', + # default args for MutableChananelUnit + default_args={}, + units = { + # config of a unit + "xxx_unit_name": {}, + ... + } + ), + The config template of 'units' can be got using + MutableChannelUnit.config_template() + Defaults to SequentialMutableChannelUnit. + + parse_cfg (Dict, optional): + The config to parse the model. + Defaults to + dict( type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')). + + init_cfg (dict, optional): initialization configuration dict for + BaseModule. + + Note: + There are three ways used in ChannelMutator to parse a model and + get MutableChannelUnits. + 1. Using tracer. It needs parse_cfg to be the config of a tracer. + 2. Using config. When parse_cfg['type']='Config'. It needs that + channel_unit_cfg['unit']['xxx_unit_name] has a key 'channels'. + 3. Using the model with pre-defined dynamic-ops and mutablechannels: + When parse_cfg['type']='Predefined'. 
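+
+    Example (a minimal sketch rather than a config from this repo; it assumes
+    the default BackwardTracer can trace the toy model below):
+        >>> import torch.nn as nn
+        >>> model = nn.Sequential(
+        ...     nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
+        ...     nn.Conv2d(8, 16, 3), nn.AdaptiveAvgPool2d(1),
+        ...     nn.Flatten(), nn.Linear(16, 10))
+        >>> mutator = ChannelMutator()
+        >>> mutator.prepare_from_supernet(model)
+        >>> mutator.set_choices(mutator.sample_choices())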
""" - def __init__( - self, - mutable_cfg: Dict, - tracer_cfg: Optional[Dict] = None, - skip_prefixes: Optional[List[str]] = None, - init_cfg: Optional[Dict] = None, - ) -> None: - super().__init__(init_cfg) + # init - self.mutable_cfg = mutable_cfg - if tracer_cfg: - self.tracer = TASK_UTILS.build(tracer_cfg) - else: - self.tracer = None - self.skip_prefixes = skip_prefixes - self._search_groups: Optional[Dict[int, List[Module]]] = None - - def add_link(self, path_list: PathList) -> None: - """Establish the relationship between the current nodes and their - parents.""" - for path in path_list: - pre_node = None - for node in path: - if isinstance(node, PathDepthWiseConvNode): - module = self.name2module[node.name] - # The in_channels and out_channels of a depth-wise conv - # should be the same - module.mutable_out.register_same_mutable(module.mutable_in) - module.mutable_in.register_same_mutable(module.mutable_out) - - if isinstance(node, PathConcatNode): - if pre_node is not None: - module_names = node.get_module_names() - concat_modules = [ - self.name2module[name] for name in module_names - ] - concat_mutables = [ - module.mutable_out for module in concat_modules - ] - pre_module = self.name2module[pre_node.name] - pre_module.mutable_in.register_same_mutable( - concat_mutables) - - for sub_path_list in node: - self.add_link(sub_path_list) - - # ConcatNode is the last node in a path - break - - if pre_node is None: - pre_node = node - continue - - pre_module = self.name2module[pre_node.name] - cur_module = self.name2module[node.name] - pre_module.mutable_in.register_same_mutable( - cur_module.mutable_out) - cur_module.mutable_out.register_same_mutable( - pre_module.mutable_in) - - pre_node = node + def __init__(self, + channel_unit_cfg: Union[ + dict, + Type[MutableChannelUnit]] = SequentialMutableChannelUnit, + parse_cfg: Dict = dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')), + init_cfg: Optional[Dict] = None) -> None: - def prepare_from_supernet(self, supernet: Module) -> None: - """Do some necessary preparations with supernet. + super().__init__(init_cfg) - We support the following two cases: + # tracer + if isinstance(parse_cfg, dict): + assert parse_cfg['type'] in [ + 'RazorFxTracer', 'BackwardTracer', 'Config', 'Predefined' + ] + self.parse_cfg = parse_cfg - Case 1: The input is the original nn.Module. We first replace the - conv/linear/norm modules in the input supernet with dynamic ops. - And trace the topology of the supernet. Finally, `search_groups` can be - built based on the topology. + # units + self._name2unit: Dict[str, ChannelUnitType] = {} + self.units: List[ChannelUnitType] = [] - Case 2: The input supernet is made up of dynamic ops. In this case, - relationship between nodes and their parents must have been - established and topology of the supernet is available for us. Then - `search_groups` can be built based on the topology. + # unit config + self.channel_unit_cfg = channel_unit_cfg + self.unit_class, self.unit_default_args, self.units_cfg = \ + self._parse_channel_unit_cfg( + channel_unit_cfg) - Args: - supernet (:obj:`torch.nn.Module`): The supernet to be searched - in your algorithm. - """ + def prepare_from_supernet(self, supernet: Module) -> None: + """Prepare from a model for pruning. - if self.tracer is not None: - self.convert_dynamic_module(supernet, self.module_converters) - # The mapping from a module name to the module - self._name2module = dict(supernet.named_modules()) + It includes two steps: + 1. 
parse the model and get MutableChannelUnits. + 2. call unit.prepare_for_pruning for each unit. + """ - assert self.tracer is not None - module_path_list: PathList = self.tracer.trace(supernet) + self._name2module = dict(supernet.named_modules()) - self.add_link(module_path_list) + if 'Tracer' in self.parse_cfg['type']: + units = self._prepare_from_tracer(supernet, self.parse_cfg) + elif self.parse_cfg['type'] == 'Config': + units = self._prepare_from_cfg(supernet, self.units_cfg) + elif self.parse_cfg['type'] == 'Predefined': + units = self._prepare_from_predefined_model(supernet) else: - self._name2module = dict(supernet.named_modules()) - - self.bind_mutable_name(supernet) - self._search_groups = self.build_search_groups(supernet) - - @staticmethod - def find_same_mutables(supernet) -> Dict: - """The mutables in the same group should be pruned together.""" - visited = [] - groups = {} - group_idx = 0 - for name, module in supernet.named_modules(): - if isinstance(module, MutableChannel): - same_mutables = module.same_mutables - if module not in visited and len(same_mutables) > 0: - groups[group_idx] = [module] + same_mutables - visited.extend(groups[group_idx]) - group_idx += 1 - return groups - - def bind_mutable_name(self, supernet: Module): - """Bind a MutableChannel to its name. - - Args: - supernet (:obj:`torch.nn.Module`): The supernet to be searched - in your algorithm. - """ + raise NotImplementedError() - def traverse(module, prefix): - for name, child in module.named_children(): - module_name = f'{prefix}.{name}' if prefix else name + for unit in units: + unit.prepare_for_pruning(supernet) + self._name2unit[unit.name] = unit + self.units = units - if isinstance(child, MutableChannel): - child.bind_mutable_name(prefix) - else: - traverse(child, module_name) + # ~ - traverse(supernet, '') + @property + def mutable_units(self) -> List[ChannelUnitType]: + """Prunable units.""" + return [unit for unit in self.units if unit.is_mutable] - def convert_dynamic_module(self, supernet: Module, converters: Dict): - """Replace the conv/linear/norm modules in the input supernet with - dynamic ops. + def config_template(self, + only_mutable_units=False, + with_unit_init_args=False, + with_channels=False): + """Config template of the mutator. Args: - supernet (:obj:`torch.nn.Module`): The architecture to be converted - in your algorithm. - dynamic_layer (Dict): The mapping from the module type to the - corresponding dynamic layer. - """ - - def traverse(module, prefix): - for name, child in module.named_children(): - module_name = prefix + name - - if type(child) in converters: - mutable_cfg_ = copy.deepcopy(self.mutable_cfg) - converter = converters[type(child)] - layer = converter(child, mutable_cfg_, mutable_cfg_) - setattr(module, name, layer) - else: - traverse(child, module_name + '.') - - traverse(supernet, '') - - @abstractmethod - def build_search_groups(self, supernet: Module): - """Build `search_groups`. - - The mutables in the same group should be pruned together. + only_mutable_units (bool, optional): If only return config of + prunable units. Defaults to False. + with_unit_init_args (bool, optional): If return init_args of + units. Defaults to False. + with_channels (bool, optional): if return channel info. + Defaults to False. 
+ + Example: + dict( + channel_unit_cfg = dict( + # type of used MutableChannelUnit + type ='XxxMutableChannelUnit', + # default args for MutableChananelUnit + default_args={}, + # config of units + units = { + # config of a unit + "xxx_unit_name": { + 'init_args':{}, # if with_unit_init_args + 'channels':{} # if with_channels + }, + ... + } + ), + # config of tracer + parse_cfg={} + ) + + + About the detail of the config of each unit, please refer to + MutableChannelUnit.config_template() """ + # template of units + units = self.mutable_units if only_mutable_units else self.units + units_template = {} + for unit in units: + units_template[unit.name] = unit.config_template( + with_init_args=with_unit_init_args, + with_channels=with_channels) + + # template of mutator + template = dict( + type=str(self.__class__.__name__), + channel_unit_cfg=dict( + type=str(self.unit_class.__name__), + default_args=self.unit_default_args, + units=units_template), + parse_cfg=self.parse_cfg) + + return template + + def fix_channel_mutables(self): + """Fix ChannelMutables.""" + for unit in self.units: + unit.fix_chosen() + + # choice manage @property - def search_groups(self) -> Dict[int, List]: - """Search group of supernet. - - Note: - For mutable based mutator, the search group is composed of - corresponding mutables. - - Raises: - RuntimeError: Called before search group has been built. - - Returns: - Dict[int, List[MUTABLE_TYPE]]: Search group. - """ - if self._search_groups is None: - raise RuntimeError( - 'Call `search_groups` before access `build_search_groups`!') - return self._search_groups + def current_choices(self) -> Dict: + """Get current choices.""" + config = self.choice_template + for unit in self.mutable_units: + config[unit.name] = unit.current_choice + return config + + def set_choices(self, config: Dict[str, Union[int, float]]): + """Set choices.""" + for name, choice in config.items(): + unit = self._name2unit[name] + unit.current_choice = choice + + def sample_choices(self) -> Dict[str, Union[int, float]]: + """Sample choices(pruning structure).""" + template = self.choice_template + for key in template: + template[key] = self._name2unit[key].sample_choice() + return template @property - def name2module(self): - """The mapping from a module name to the module. - - Returns: - dict: The name to module mapping. + def choice_template(self) -> Dict: + """Get the chocie template of the Mutator. + + Example: + { + 'xxx_unit_name': xx_choice_value, + ... 
+ } """ - if hasattr(self, '_name2module'): - return self._name2module + template = {} + for unit in self.mutable_units: + template[unit.name] = unit.current_choice + return template + + # implementation of abstract functions + + def search_groups(self) -> Dict: + return self._name2unit + + def mutable_class_type(self) -> Type[ChannelUnitType]: + return self.unit_class + + # private methods + + def _convert_channel_unit_to_mutable(self, units: List[ChannelUnit]): + """Convert ChannelUnits to MutableChannelUnits.""" + mutable_units = [] + for unit in units: + args = copy.copy(self.unit_default_args) + if unit.name in self.units_cfg and \ + 'init_args' in self.units_cfg[unit.name]: + args = self.units_cfg[unit.name]['init_args'] + mutable_unit = self.unit_class.init_from_channel_unit(unit, args) + mutable_units.append(mutable_unit) + return mutable_units + + def _parse_channel_unit_cfg( + self, + channel_unit_cfg) -> Tuple[Type[ChannelUnitType], Dict, Dict]: + """Parse channel_unit_cfg.""" + if isinstance(channel_unit_cfg, dict): + unit_class = MODELS.module_dict[channel_unit_cfg['type']] + + default_unit_args = channel_unit_cfg[ + 'default_args'] if 'default_args' in channel_unit_cfg else {} + + unit_init_cfg = channel_unit_cfg[ + 'units'] if 'units' in channel_unit_cfg else {} + if isinstance(unit_init_cfg, str): + # load config file + unit_init_cfg = fileio.load(unit_init_cfg) + elif issubclass(channel_unit_cfg, MutableChannelUnit): + unit_class = channel_unit_cfg + default_unit_args = {} + unit_init_cfg = {} else: - raise RuntimeError('Called before access `prepare_from_supernet`!') + raise NotImplementedError() + return unit_class, default_unit_args, unit_init_cfg - @property - def module_converters(self) -> Dict: - """The mapping from a type to the corresponding dynamic layer. It is - called in `prepare_from_supernet`. - - Returns: - dict: The mapping dict. - """ - return DEFAULT_MODULE_CONVERTERS - - def is_skip_pruning(self, module_name: str, - skip_prefixes: Optional[List[str]]) -> bool: - """Judge if the module with the input `module_name` should not be - pruned. - - Args: - module_name (str): Module name. - skip_prefixes (list or None): The module whose name start with - a string in skip_prefixes will not be prune. 
- """ - skip_pruning = False - if skip_prefixes: - for prefix in skip_prefixes: - if module_name.startswith(prefix): - skip_pruning = True - break - return skip_pruning + def _prepare_from_tracer(self, model: Module, parse_cfg: Dict): + """Initialize units using a tracer.""" + if 'num_input_channel' in parse_cfg: + num_input_channel = parse_cfg.pop('num_input_channel') + else: + num_input_channel = 3 + if self.parse_cfg['type'] == 'BackwardTracer': + graph = ModuleGraph.init_from_backward_tracer(model, parse_cfg) + elif self.parse_cfg['type'] == 'RazorFxTracer': + graph = ModuleGraph.init_from_fx_tracer(model, fx_tracer=parse_cfg) + else: + raise NotImplementedError() + self._graph = graph + # get ChannelUnits + units = ChannelUnit.init_from_graph( + graph, num_input_channel=num_input_channel) + # convert to MutableChannelUnits + units = self._convert_channel_unit_to_mutable(units) + return units + + def _prepare_from_cfg(self, model, config: Dict): + """Initialize units using config dict.""" + assert isinstance(self.channel_unit_cfg, dict) + assert 'units' in self.channel_unit_cfg + config = self.channel_unit_cfg['units'] + if isinstance(config, str): + config = fileio.load(config) + assert isinstance(config, dict) + units = [] + for unit_key in config: + init_args = copy.deepcopy(self.unit_default_args) + if 'init_args' in config[unit_key]: + init_args.update(config[unit_key]['init_args']) + config[unit_key]['init_args'] = init_args + unit = self.unit_class.init_from_cfg(model, config[unit_key]) + units.append(unit) + return units + + def _prepare_from_predefined_model(self, model: Module): + """Initialize units using the model with pre-defined dynamicops and + mutable-channels.""" + + units = self.unit_class.init_from_predefined_model(model) + + return units diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py index 56c773667..a5350ab2b 100644 --- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -1,139 +1,41 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings -from typing import Any, Dict, List, Optional - -from torch.nn import Module +from typing import Dict, Type, Union +from mmrazor.models.mutables import OneShotMutableChannelUnit from mmrazor.registry import MODELS -from ...mutables import OneShotMutableChannel -from .channel_mutator import ChannelMutator +from .channel_mutator import ChannelMutator, ChannelUnitType @MODELS.register_module() -class OneShotChannelMutator(ChannelMutator): - """One-shot channel mutable based channel mutator. +class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit]): + """OneShotChannelMutator based on ChannelMutator. It use + OneShotMutableChannelUnit by default. Args: - mutable_cfg (dict): The config for the channel mutable. - tracer_cfg (dict): The config for the model tracer. We Trace the - topology of a given model with the tracer. - skip_prefixes (List[str] | Optional): The module whose name start with - a string in skip_prefixes will not be pruned. - init_cfg (dict, optional): The config to control the initialization. + channel_unit_cfg (Union[dict, Type[ChannelUnitType]], optional): + Config of MutableChannelUnits. Defaults to + dict( type='OneShotMutableChannelUnit', + default_args=dict( num_blocks=8, min_blocks=2 ) ). 
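+
+    Example (an illustrative sketch; the candidate ratios are assumed values
+    and ``model`` stands for any nn.Module the default tracer can parse):
+        >>> mutator = OneShotChannelMutator(
+        ...     channel_unit_cfg=dict(
+        ...         type='OneShotMutableChannelUnit',
+        ...         default_args=dict(candidate_choices=[0.25, 0.5, 1.0])))
+        >>> mutator.prepare_from_supernet(model)
+        >>> mutator.set_choices(mutator.max_choices())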
""" def __init__(self, - mutable_cfg: Dict, - tracer_cfg: Optional[Dict] = None, - skip_prefixes: Optional[List[str]] = None, - init_cfg: Optional[Dict] = None) -> None: - super().__init__(mutable_cfg, tracer_cfg, skip_prefixes, init_cfg) - - def sample_choices(self): - """Sample a choice that records a selection from the search space. - - Returns: - dict: Record the information to build the subnet from the supernet. - Its keys are the properties ``group_idx`` in the channel - mutator's ``search_groups``, and its values are the sampled - choice. - """ - choice_dict = dict() - for group_idx, mutables in self.search_groups.items(): - choice_dict[group_idx] = mutables[0].sample_choice() - return choice_dict - - def set_choices(self, choice_dict: Dict[int, Any]) -> None: - """Set current subnet according to ``choice_dict``. - - Args: - choice_dict (Dict[int, Any]): Choice dict. - """ - for group_idx, choice in choice_dict.items(): - mutables = self.search_groups[group_idx] - for mutable in mutables: - mutable.current_choice = choice - - def set_max_choices(self) -> None: - """Set the channel numbers of each layer to maximum.""" - for mutables in self.search_groups.values(): - for mutable in mutables: - mutable.current_choice = mutable.max_choice - - def set_min_choices(self) -> None: - """Set the channel numbers of each layer to minimum.""" - for mutables in self.search_groups.values(): - for mutable in mutables: - mutable.current_choice = mutable.min_choice - - # todo: check search gorups - def build_search_groups(self, supernet: Module): - """Build `search_groups`. The mutables in the same group should be - pruned together. - - Examples: - >>> class ResBlock(nn.Module): - ... def __init__(self) -> None: - ... super().__init__() - ... - ... self.op1 = nn.Conv2d(3, 8, 1) - ... self.bn1 = nn.BatchNorm2d(8) - ... self.op2 = nn.Conv2d(8, 8, 1) - ... self.bn2 = nn.BatchNorm2d(8) - ... self.op3 = nn.Conv2d(8, 8, 1) - ... - ... def forward(self, x): - ... x1 = self.bn1(self.op1(x)) - ... x2 = self.bn2(self.op2(x1)) - ... x3 = self.op3(x2 + x1) - ... return x3 - - >>> class ToyPseudoLoss: - ... - ... def __call__(self, model): - ... pseudo_img = torch.rand(2, 3, 16, 16) - ... pseudo_output = model(pseudo_img) - ... return pseudo_output.sum() - - >>> mutator = OneShotChannelMutator( - ... tracer_cfg=dict(type='BackwardTracer', - ... loss_calculator=ToyPseudoLoss()), - ... mutable_cfg=dict(type='OneShotMutableChannel', - ... candidate_choices=[4 / 8, 1.0], candidate_mode='ratio') - - >>> model = ResBlock() - >>> mutator.prepare_from_supernet(model) - >>> mutator.search_groups - {0: [OneShotMutableChannel(name=op2, ...), # mutable out - OneShotMutableChannel(name=op1, ...), # mutable out - OneShotMutableChannel(name=op3, ...), # mutable in - OneShotMutableChannel(name=op2, ...), # mutable in - OneShotMutableChannel(name=bn2, ...), # mutable out - OneShotMutableChannel(name=bn1, ...)] # mutable out - } - """ - groups = self.find_same_mutables(supernet) - - search_groups = dict() - group_idx = 0 - for group in groups.values(): - is_skip = False - for mutable in group: - if self.is_skip_pruning(mutable.name, self.skip_prefixes): - warnings.warn(f'Group {group} is not searchable due to' - f' skip_prefixes: {self.skip_prefixes}') - is_skip = True - break - if not is_skip: - search_groups[group_idx] = group - group_idx += 1 - - return search_groups - - def mutable_class_type(self): - """One-shot channel mutable class type. - - Returns: - Type[OneShotMutableModule]: Class type of one-shot mutable. 
- """ - return OneShotMutableChannel + channel_unit_cfg: Union[dict, Type[ChannelUnitType]] = dict( + type='OneShotMutableChannelUnit', + default_args=dict(num_blocks=8, min_blocks=2)), + **kwargs) -> None: + + super().__init__(channel_unit_cfg, **kwargs) + + def min_choices(self) -> Dict: + """Return the minimal pruning subnet(structure).""" + template = self.choice_template + for key in template: + template[key] = self._name2unit[key].min_choice + return template + + def max_choices(self) -> Dict: + """Return the maximal pruning subnet(structure).""" + template = self.choice_template + for key in template: + template[key] = self._name2unit[key].max_choice + return template diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py index 7002245ea..7c0d24fa6 100644 --- a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py @@ -1,150 +1,72 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy -from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional -import torch.nn as nn -from torch.nn import Module -from torch.nn.modules.batchnorm import _BatchNorm - -from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm -from mmrazor.models.mutables import SlimmableMutableChannel +from mmrazor.models.mutables import SlimmableChannelUnit from mmrazor.registry import MODELS -from ...task_modules import PathList -from ..utils import switchable_bn_converter from .channel_mutator import ChannelMutator -NONPASS_MODULES = (nn.Conv2d, nn.Linear) -PASS_MODULES = (_BatchNorm, ) - -VALID_PATH_TYPE = Union[str, Path] - @MODELS.register_module() -class SlimmableChannelMutator(ChannelMutator): - """Slimmable channel mutable based channel mutator. +class SlimmableChannelMutator(ChannelMutator[SlimmableChannelUnit]): + """SlimmableChannelMutator is the default ChannelMutator for + SlimmableNetwork algorithm. Args: - channel_cfgs (list[Dict]): A list of candidate channel configs. - mutable_cfg (dict): The config for the channel mutable. - skip_prefixes (List[str] | Optional): The module whose name start with - a string in skip_prefixes will not be pruned. - init_cfg (dict, optional): The config to control the initialization. + channel_unit_cfg (Dict): The config of ChannelUnits. Defaults to + dict( type='SlimmableChannelUnit', units={}). + parse_cfg (Dict): The config of the tracer to parse the model. + Defaults to dict( type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')). + init_cfg (dict, optional): initialization configuration dict for + BaseModule. """ def __init__(self, - channel_cfgs: Dict, - mutable_cfg: Dict, - tracer_cfg: Dict, - skip_prefixes: Optional[List[str]] = None, - init_cfg: Optional[Dict] = None): - super(SlimmableChannelMutator, self).__init__( - mutable_cfg=mutable_cfg, - tracer_cfg=tracer_cfg, - skip_prefixes=skip_prefixes, - init_cfg=init_cfg) - - self.channel_cfgs = channel_cfgs - - def prepare_from_supernet(self, supernet: Module) -> None: - """Do some necessary preparations with supernet. - - Note: - Different from `ChannelMutator`, we only support Case 1 in - `ChannelMutator`. The input supernet should be made up of original - nn.Module. And we replace the conv/linear/bn modules in the input - supernet with dynamic ops first. 
Then we trace the topology of - the supernet to get the `concat_parent_mutables` of a certain - mutable, if the input of a module is a concatenation of several - modules' outputs. Then we convert the ``DynamicBatchNorm`` in - supernet with ``SwitchableBatchNorm2d``, and set the candidate - channel numbers to the corresponding `SlimmableChannelMutable`. - Finally, we establish the relationship between the current nodes - and their parents. - - Args: - supernet (:obj:`torch.nn.Module`): The supernet to be searched - in your algorithm. - """ - self.convert_dynamic_module(supernet, self.module_converters) - - module_path_list: PathList = self.tracer.trace(supernet) - - self.convert_switchable_bn(supernet) - self.set_candidate_choices(supernet) - - # The mapping from a module name to the module - self._name2module = dict(supernet.named_modules()) - self.add_link(module_path_list) - self.bind_mutable_name(supernet) + channel_unit_cfg=dict(type='SlimmableChannelUnit', units={}), + parse_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')), + init_cfg: Optional[Dict] = None) -> None: - def set_candidate_choices(self, supernet): - """Set the ``candidate_choices`` of each ``SlimmableChannelMutable``. + super().__init__(channel_unit_cfg, parse_cfg, init_cfg) - Notes: - Different from other ``OneShotChannelMutable``, - ``candidate_choices`` is optional when instantiating a - ``SlimmableChannelMutable`` - """ - for name, module in supernet.named_modules(): - if isinstance(module, SlimmableMutableChannel): - candidate_choices = self.channel_cfgs[name]['current_choice'] - module.candidate_choices = candidate_choices - - def convert_switchable_bn(self, supernet): - """Replace ``DynamicBatchNorm`` in supernet with - ``SwitchableBatchNorm2d``. - - Args: - supernet (:obj:`torch.nn.Module`): The architecture to be converted - in your algorithm. - """ - - def traverse(module, prefix): - for name, child in module.named_children(): - module_name = prefix + name - if isinstance(child, DynamicBatchNorm): - mutable_cfg = copy.deepcopy(self.mutable_cfg) - key = module_name + '.mutable_num_features' - candidate_choices = self.channel_cfgs[key][ - 'current_choice'] - mutable_cfg.update( - dict(candidate_choices=candidate_choices)) - sbn = switchable_bn_converter(child, mutable_cfg, - mutable_cfg) - # TODO - # bind twice? - sbn.mutable_out.bind_mutable_name(module_name) - setattr(module, name, sbn) - else: - traverse(child, module_name + '.') + self.subnets = self._prepare_subnets(self.units_cfg) - traverse(supernet, '') + # private methods - def switch_choices(self, idx: int) -> None: - """Switch the channel config of the supernet according to input `idx`. - - If we train more than one subnet together, we need to switch the - channel_cfg from one to another during one training iteration. + def _prepare_subnets(self, unit_cfg: Dict) -> List[Dict[str, int]]: + """Prepare subnet config. Args: - idx (int): The index of the current subnet. - """ - for name, module in self.name2module.items(): - if isinstance(module, SlimmableMutableChannel): - module.current_choice = idx - - def build_search_groups(self, supernet: Module): - """Build `search_groups`. - - The mutables in the same group should be pruned together. - """ - pass - - def mutable_class_type(self): - """One-shot channel mutable class type. + unit_cfg (Dict[str, Dict[str]]): Config of the units. + unit_cfg follows the below template: + { + 'xx_unit_name':{ + 'init_args':{ + 'candidate_choices':[c1,c2,c3...],... + },... 
+ },... + } + Every unit must have the same number of candidate_choices, and + the candidate in the list of candidate_choices with the same + position compose a subnet. Returns: - Type[OneShotMutableModule]: Class type of one-shot mutable. + List[Dict[str, int]]: config of the subnets. """ - return SlimmableMutableChannel + """Prepare subnet config.""" + subnets: List[Dict[str, int]] = [] + num_subnets = 0 + for key in unit_cfg: + num_subnets = len(unit_cfg[key]['init_args']['candidate_choices']) + break + for _ in range(num_subnets): + subnets.append({}) + for key in unit_cfg: + assert num_subnets == len( + unit_cfg[key]['init_args']['candidate_choices']) + for i, value in enumerate( + unit_cfg[key]['init_args']['candidate_choices']): + subnets[i][key] = value + + return subnets diff --git a/mmrazor/models/mutators/utils/__init__.py b/mmrazor/models/mutators/utils/__init__.py deleted file mode 100644 index 33f94c667..000000000 --- a/mmrazor/models/mutators/utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# yapf: disable -from .default_module_converters import (DEFAULT_MODULE_CONVERTERS, - dynamic_bn_converter, - dynamic_conv2d_converter, - dynamic_gn_converter, - dynamic_in_converter, - dynamic_linear_converter) -# yapf: enable -from .slimmable_bn_converter import switchable_bn_converter - -__all__ = [ - 'dynamic_conv2d_converter', 'dynamic_linear_converter', - 'dynamic_bn_converter', 'dynamic_in_converter', 'dynamic_gn_converter', - 'DEFAULT_MODULE_CONVERTERS', 'switchable_bn_converter' -] diff --git a/mmrazor/models/mutators/utils/default_module_converters.py b/mmrazor/models/mutators/utils/default_module_converters.py deleted file mode 100644 index fdfa5d266..000000000 --- a/mmrazor/models/mutators/utils/default_module_converters.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict, Optional - -from torch import nn -from torch.nn.modules import GroupNorm -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.instancenorm import _InstanceNorm - -from ...architectures import (DynamicBatchNorm, DynamicConv2d, - DynamicGroupNorm, DynamicInstanceNorm, - DynamicLinear) - - -def dynamic_conv2d_converter(module: nn.Conv2d, in_channels_cfg: Dict, - out_channels_cfg: Dict) -> DynamicConv2d: - """Convert a nn.Conv2d module to a DynamicConv2d. - - Args: - module (:obj:`torch.nn.Conv2d`): The original Conv2d module. - in_channels_cfg (Dict): Config related to `in_channels`. - out_channels_cfg (Dict): Config related to `out_channels`. - """ - dynamic_conv = DynamicConv2d( - in_channels_cfg=in_channels_cfg, - out_channels_cfg=out_channels_cfg, - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - bias=True if module.bias is not None else False, - padding_mode=module.padding_mode) - return dynamic_conv - - -def dynamic_linear_converter(module: nn.Linear, in_channels_cfg: Dict, - out_channels_cfg: Dict) -> DynamicLinear: - """Convert a nn.Linear module to a DynamicLinear. - - Args: - module (:obj:`torch.nn.Linear`): The original Linear module. - in_features_cfg (Dict): Config related to `in_features`. - out_features_cfg (Dict): Config related to `out_features`. 
- """ - dynamic_linear = DynamicLinear( - in_features_cfg=in_channels_cfg, - out_features_cfg=out_channels_cfg, - in_features=module.in_features, - out_features=module.out_features, - bias=True if module.bias is not None else False) - return dynamic_linear - - -def dynamic_bn_converter( - module: _BatchNorm, - in_channels_cfg: Dict, - out_channels_cfg: Optional[Dict] = None) -> DynamicBatchNorm: - """Convert a _BatchNorm module to a DynamicBatchNorm. - - Args: - module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module. - num_features_cfg (Dict): Config related to `num_features`. - """ - dynamic_bn = DynamicBatchNorm( - num_features_cfg=in_channels_cfg, - num_features=module.num_features, - eps=module.eps, - momentum=module.momentum, - affine=module.affine, - track_running_stats=module.track_running_stats) - return dynamic_bn - - -def dynamic_in_converter( - module: _InstanceNorm, - in_channels_cfg: Dict, - out_channels_cfg: Optional[Dict] = None) -> DynamicInstanceNorm: - """Convert a _InstanceNorm module to a DynamicInstanceNorm. - - Args: - module (:obj:`torch.nn._InstanceNorm`): The original InstanceNorm - module. - num_features_cfg (Dict): Config related to `num_features`. - """ - dynamic_in = DynamicInstanceNorm( - num_features_cfg=in_channels_cfg, - num_features=module.num_features, - eps=module.eps, - momentum=module.momentum, - affine=module.affine, - track_running_stats=module.track_running_stats) - return dynamic_in - - -def dynamic_gn_converter( - module: GroupNorm, - in_channels_cfg: Dict, - out_channels_cfg: Optional[Dict] = None) -> DynamicGroupNorm: - """Convert a GroupNorm module to a DynamicGroupNorm. - - Args: - module (:obj:`torch.nn.GroupNorm`): The original GroupNorm module. - num_channels_cfg (Dict): Config related to `num_channels`. - """ - dynamic_gn = DynamicGroupNorm( - num_channels_cfg=in_channels_cfg, - num_channels=module.num_channels, - num_groups=module.num_groups, - eps=module.eps, - affine=module.affine) - return dynamic_gn - - -DEFAULT_MODULE_CONVERTERS: Dict[Callable, Callable] = { - nn.Conv2d: dynamic_conv2d_converter, - nn.Linear: dynamic_linear_converter, - nn.BatchNorm1d: dynamic_bn_converter, - nn.BatchNorm2d: dynamic_bn_converter, - nn.BatchNorm3d: dynamic_bn_converter, - nn.InstanceNorm1d: dynamic_in_converter, - nn.InstanceNorm2d: dynamic_in_converter, - nn.InstanceNorm3d: dynamic_in_converter, - nn.GroupNorm: dynamic_gn_converter -} diff --git a/mmrazor/models/mutators/utils/slimmable_bn_converter.py b/mmrazor/models/mutators/utils/slimmable_bn_converter.py deleted file mode 100644 index bef3077c1..000000000 --- a/mmrazor/models/mutators/utils/slimmable_bn_converter.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict - -from torch.nn.modules.batchnorm import _BatchNorm - -from mmrazor.models.architectures import SwitchableBatchNorm2d - - -def switchable_bn_converter(module: _BatchNorm, in_channels_cfg: Dict, - out_channels_cfg: Dict) -> SwitchableBatchNorm2d: - """Convert a _BatchNorm module to a SwitchableBatchNorm2d. - - Args: - module (:obj:`torch.nn.GroupNorm`): The original BatchNorm module. - num_channels_cfg (Dict): Config related to `num_features`. 
- """ - switchable_bn = SwitchableBatchNorm2d( - num_features_cfg=in_channels_cfg, - eps=module.eps, - momentum=module.momentum, - affine=module.affine, - track_running_stats=module.track_running_stats) - return switchable_bn diff --git a/mmrazor/structures/graph/channel_graph.py b/mmrazor/structures/graph/channel_graph.py new file mode 100644 index 000000000..a1629c587 --- /dev/null +++ b/mmrazor/structures/graph/channel_graph.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Dict, List + +from torch.nn import Module + +from .base_graph import BaseGraph +from .channel_modules import BaseChannelUnit, ChannelTensor +from .channel_nodes import ChannelNode, default_channel_node_converter +from .module_graph import ModuleGraph + + +class ChannelGraph(ModuleGraph[ChannelNode]): + """ChannelGraph is used to trace the channel dependency of a model. + + A ChannelGraph generates a ChannelTensor as the input to the model. Then, + the tensor can forward through all nodes and collect channel dependency. + """ + + @classmethod + def copy_from(cls, + graph: 'BaseGraph', + node_converter: Callable = default_channel_node_converter): + """Copy from a ModuleGraph.""" + assert isinstance(graph, ModuleGraph) + return super().copy_from(graph, node_converter) + + def collect_units(self) -> List[BaseChannelUnit]: + """Collect channel units in the graph.""" + units = list() + for node in self.topo_traverse(): + node.register_channel_to_units() + for node in self.topo_traverse(): + for unit in node.in_channel_tensor.unit_list + \ + node.out_channel_tensor.unit_list: + if unit not in units: + units.append(unit) + return units + + def forward(self, num_input_channel=3): + """Generate a ChanneelTensor and let it forwards through the graph.""" + for node in self.topo_traverse(): + node.reset_channel_tensors() + self._merge_same_module() + for i, node in enumerate(self.topo_traverse()): + node: ChannelNode + if len(node.prev_nodes) == 0: + channel_list = ChannelTensor(num_input_channel) + node.forward([channel_list]) + else: + node.forward() + + def _merge_same_module(self): + """Union all nodes with the same module to the same unit.""" + module2node: Dict[Module, List[ChannelNode]] = dict() + for node in self: + if isinstance(node.val, Module): + if node.val not in module2node: + module2node[node.val] = [] + if node not in module2node[node.val]: + module2node[node.val].append(node) + + for module in module2node: + if len(module2node[module]) > 1: + nodes = module2node[module] + input_channel_tensor = ChannelTensor(nodes[0].in_channels) + out_channel_tensor = ChannelTensor(nodes[0].out_channels) + for node in nodes: + ChannelTensor.union(input_channel_tensor, + node.in_channel_tensor) + ChannelTensor.union(out_channel_tensor, + node.out_channel_tensor) diff --git a/mmrazor/structures/graph/channel_modules.py b/mmrazor/structures/graph/channel_modules.py new file mode 100644 index 000000000..1cfa2d5ff --- /dev/null +++ b/mmrazor/structures/graph/channel_modules.py @@ -0,0 +1,372 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Tuple, Union + +# Channels + + +class BaseChannel: + """BaseChannel records information about channels for pruning. + + Args: + name (str): The name of the channel. When the channel is related with + a module, the name should be the name of the module in the model. + module (Any): Module of the channel. 
+ index (Tuple[int,int]): Index(start,end) of the Channel in the Module + node (ChannelNode, optional): A ChannelNode corresponding to the + Channel. Defaults to None. + is_output_channel (bool, optional): Is the channel output channel. + Defaults to True. + expand_ratio (int, optional): Expand ratio of the mask. Defaults to 1. + """ + + # init + + def __init__(self, + name, + module, + index, + node=None, + is_output_channel=True, + expand_ratio=1) -> None: + self.name = name + self.module = module + self.index = index + self.start = index[0] + self.end = index[1] + + self.node = node + + self.is_output_channel = is_output_channel + self.expand_ratio = expand_ratio + + @property + def num_channels(self) -> int: + """The number of channels in the Channel.""" + return self.index[1] - self.index[0] + + # others + + def __repr__(self) -> str: + return f'{self.name}\t{self.index}\t \ + {"out" if self.is_output_channel else "in"}\t\ + expand:{self.expand_ratio}' + + def __eq__(self, obj: object) -> bool: + if isinstance(obj, BaseChannel): + return self.name == obj.name \ + and self.module == obj.module \ + and self.index == obj.index \ + and self.is_output_channel == obj.is_output_channel \ + and self.expand_ratio == obj.expand_ratio \ + and self.node == obj.node + else: + return False + + +class BaseChannelUnit: + """BaseChannelUnit is a collection of BaseChannel. + + All BaseChannels are saved in two lists: self.input_related and + self.output_related. + """ + + def __init__(self) -> None: + + self.channel_elems: Dict[int, List[ChannelElement]] = {} + self.input_related: List[BaseChannel] = [] + self.output_related: List[BaseChannel] = [] + + # ~ + + def add_channel_elem(self, channel_elem: 'ChannelElement', index): + """Add a ChannelElement to the BaseChannelUnit.""" + self._add_channel_info(channel_elem, index) + if channel_elem.unit is not None: + channel_elem.remove_from_unit() + channel_elem._register_unit(self, index) + + # unit operations + + @classmethod + def union_units(cls, units: List['BaseChannelUnit']): + """Union units.""" + assert len(units) > 1 + union_unit = units[0] + + for unit in units[1:]: + union_unit = BaseChannelUnit.union_two_units(union_unit, unit) + return union_unit + + @classmethod + def union_two_units(cls, unit1: 'BaseChannelUnit', + unit2: 'BaseChannelUnit'): + """Union two units.""" + if unit1 is unit2: + return unit1 + else: + assert len(unit1) == len(unit2) + for i in unit1: + for channel_elem in copy.copy(unit2[i]): + unit1.add_channel_elem(channel_elem, i) + return unit1 + + @classmethod + def split_unit(cls, unit: 'BaseChannelUnit', nums: List[int]): + """Split a unit to multiple units.""" + new_units = [] + if len(nums) == 1: + return [unit] + assert sum(nums) == len(unit) + for num in nums: + new_unit = unit._split_a_new_unit(list(range(0, num))) + new_units.append(new_unit) + return new_units + + # private methods + + def _clean_channel_info(self, channel_elem: 'ChannelElement', index: int): + """Clean the info of a ChannelElement.""" + self[index].remove(channel_elem) + + def _add_channel_info(self, channel_elem: 'ChannelElement', index): + """Add the info of a ChannelElemnt.""" + assert channel_elem.unit is not self + if index not in self.channel_elems: + self.channel_elems[index] = [] + self.channel_elems[index].append(channel_elem) + + def _split_a_new_unit(self, indexes: List[int]): + """Split a part of the unit to a new unit.""" + new_unit = BaseChannelUnit() + j = 0 + for i in indexes: + for channel_elem in copy.copy(self[i]): + 
new_unit.add_channel_elem(channel_elem, j) + self.channel_elems.pop(i) + j += 1 + self._reindex() + return new_unit + + def _reindex(self): + """Re-index the owning ChannelElements.""" + j = 0 + for i in copy.copy(self.channel_elems): + if len(self.channel_elems[i]) == 0: + self.channel_elems.pop(i) + else: + if j < i: + for channel_elem in copy.copy(self.channel_elems[i]): + if channel_elem.unit is not None: + channel_elem.remove_from_unit() + self.add_channel_elem(channel_elem, j) + self.channel_elems.pop(i) + j += 1 + elif j == i: + pass + else: + raise Exception() + + # others + + def __repr__(self) -> str: + + def add_prefix(string: str, prefix=' '): + str_list = string.split('\n') + str_list = [ + prefix + line if line != '' else line for line in str_list + ] + return '\n'.join(str_list) + + def list_repr(lit: List): + s = '[\n' + for item in lit: + s += add_prefix(item.__repr__(), ' ') + '\n' + s += ']\n' + return s + + s = ('xxxxx_' + f'\t{len(self.output_related)},{len(self.input_related)}\n') + s += ' output_related:\n' + s += add_prefix(list_repr(self.output_related), ' ' * 4) + s += ' input_related\n' + s += add_prefix(list_repr(self.input_related), ' ' * 4) + return s + + def __iter__(self): + for i in self.channel_elems: + yield i + + def __len__(self): + return len(self.channel_elems) + + def __getitem__(self, key): + return self.channel_elems[key] + + +class ChannelElement: + """Each ChannelElement is the basic element of a ChannelTensor. It records + its owing ChannelTensor and BaseChannelUnit. + + Args: + index (int): The index of the ChannelElement in the ChannelTensor. + """ + + def __init__(self, index_in_tensor: int) -> None: + + self.index_in_channel_tensor = index_in_tensor + + self.unit: Union[BaseChannelUnit, None] = None + self.index_in_unit = -1 + + def remove_from_unit(self): + """Remove the ChannelElement from its owning BaseChannelUnit.""" + self.unit._clean_channel_info(self, self.index_in_unit) + self._clean_unit_info() + + # private methods + + def _register_unit(self, unit, index): + """Register the ChannelElement to a BaseChannelUnit.""" + self.unit = unit + self.index_in_unit = index + + def _clean_unit_info(self): + """Clean the unit info in the ChannelElement.""" + self.unit = None + self.index_in_unit = -1 + + +class ChannelTensor: + """A ChannelTensor is a list of ChannelElemnts. It can forward through a + ChannelGraph. + + Args: + num_channel_elems (int): Number of ChannelElements. + """ + + def __init__(self, num_channel_elems: int) -> None: + + unit = BaseChannelUnit() + self.channel_elems: List[ChannelElement] = [ + ChannelElement(i) for i in range(num_channel_elems) + ] + for channel_elem in self.channel_elems: + unit.add_channel_elem(channel_elem, + channel_elem.index_in_channel_tensor) + + # unit operations + + def align_units_with_nums(self, nums: List[int]): + """Align owning units to certain lengths.""" + i = 0 + for start, end in self.unit_dict: + start_ = start + new_nums = [] + while start_ < end: + new_nums.append(nums[i]) + start_ += nums[i] + i += 1 + BaseChannelUnit.split_unit(self.unit_dict[(start, end)], new_nums) + + @property + def unit_dict(self) -> Dict[Tuple[int, int], BaseChannelUnit]: + """Get a dict of owning units.""" + units: Dict[Tuple[int, int], BaseChannelUnit] = {} + # current_unit = ... 
+ current_unit_idx = -1 + start = 0 + for i in range(len(self)): + if i == 0: + current_unit = self[i].unit + current_unit_idx = self[i].index_in_unit + start = 0 + else: + if current_unit is not self[i].unit or \ + current_unit_idx > self[i].index_in_unit: + units[(start, i)] = current_unit + current_unit = self[i].unit + current_unit_idx = self[i].index_in_unit + start = i + current_unit_idx = self[i].index_in_unit + units[(start, len(self))] = current_unit + return units + + @property + def unit_list(self) -> List[BaseChannelUnit]: + """Get a list of owning units.""" + return list(self.unit_dict.values()) + + # tensor operations + + @classmethod + def align_tensors(cls, *tensors: 'ChannelTensor'): + """Align the lengths of the units of the tensors.""" + assert len(tensors) >= 2 + for tensor in tensors: + assert len(tensor) == len( + tensors[0]), f'{len(tensor)}!={len(tensors[0])}' + aligned_index = cls._index2points( + *[list(tenser.unit_dict.keys()) for tenser in tensors]) + nums = cls._points2num(aligned_index) + if len(nums) > 1: + for tensor in tensors: + tensor.align_units_with_nums(nums) + + def union(self, tensor1: 'ChannelTensor'): + """Union the units with the tensor1.""" + # align + ChannelTensor.align_tensors(self, tensor1) + # union + for ch1, ch2 in zip(self.channel_elems, tensor1.channel_elems): + assert ch1.unit is not None and ch2.unit is not None + for ch in copy.copy(ch2.unit.channel_elems[ch2.index_in_unit]): + ch1.unit.add_channel_elem(ch, ch1.index_in_unit) + + def expand(self, ratio) -> 'ChannelTensor': + """Get a new ChannelTensor which is expanded from this + ChannelTensor.""" + expanded_tensor = ChannelTensor(len(self) * ratio) + for i, ch in enumerate(self.channel_elems): + assert ch.unit is not None + unit = ch.unit + for j in range(0, ratio): + ex_ch = expanded_tensor[i * ratio + j] + unit.add_channel_elem(ex_ch, ch.index_in_unit) + return expanded_tensor + + # others + + def __getitem__(self, i: int): + """Get ith ChannelElement in the ChannelTensor.""" + return self.channel_elems[i] + + def __len__(self): + """Get length of the ChannelTensor.""" + return len(self.channel_elems) + + @classmethod + def _index2points(cls, *indexes: List[Tuple[int, int]]): + """Convert indexes to points.""" + new_index = [] + for index in indexes: + new_index.extend(index) + points = set() + for start, end in new_index: + points.add(start) + points.add(end) + points_list = list(points) + points_list.sort() + return points_list + + @classmethod + def _points2num(cls, indexes: List[int]): + """Convert a list of sorted points to the length of each block.""" + if len(indexes) == 0: + return [] + nums = [] + start = 0 + for end in indexes[1:]: + nums.append(end - start) + start = end + return nums diff --git a/mmrazor/structures/graph/channel_nodes.py b/mmrazor/structures/graph/channel_nodes.py new file mode 100644 index 000000000..1749b5875 --- /dev/null +++ b/mmrazor/structures/graph/channel_nodes.py @@ -0,0 +1,378 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import operator +from abc import abstractmethod +from typing import Union + +import torch +import torch.nn as nn +from mmengine import MMLogger + +from .channel_modules import BaseChannel, BaseChannelUnit, ChannelTensor +from .module_graph import ModuleNode + + +class ChannelNode(ModuleNode): + """A ChannelNode is like a torch module. It accepts a ChannelTensor and + output a ChannelTensor. 
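# A short sketch of the `ChannelTensor`/`BaseChannelUnit` operations defined
# above (split, align and union), mirroring the layouts drawn in the unit
# tests further below; it assumes this patch is applied.
from mmrazor.structures.graph.channel_modules import (BaseChannelUnit,
                                                      ChannelTensor)

t1 = ChannelTensor(8)
t2 = ChannelTensor(8)
BaseChannelUnit.split_unit(t1.unit_list[0], [2, 6])
BaseChannelUnit.split_unit(t2.unit_list[0], [4, 4])

ChannelTensor.align_tensors(t1, t2)        # both become [(0, 2), (2, 4), (4, 8)]
print(list(t1.unit_dict.keys()))
print(list(t2.unit_dict.keys()))

t1.union(t2)                               # merge corresponding units pair-wise
assert t1.unit_list[0] is t2.unit_list[0]  # they now share the same unit objects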
The difference is that the torch module transforms + a tensor, while the ChannelNode records the information of channel + dependency in the ChannelTensor. + + Args: + name (str): The name of the node. + val (Union[nn.Module, str]): value of the node. + expand_ratio (int, optional): expand_ratio compare with channel + mask. Defaults to 1. + module_name (str, optional): the module name of the module of the + node. + """ + + # init + + def __init__(self, + name: str, + val: Union[nn.Module, str], + expand_ratio: int = 1, + module_name='') -> None: + + super().__init__(name, val, expand_ratio, module_name) + self.in_channel_tensor = ChannelTensor(self.in_channels) + self.out_channel_tensor = ChannelTensor(self.out_channels) + + @classmethod + def copy_from(cls, node): + """Copy from a ModuleNode.""" + assert isinstance(node, ModuleNode) + return cls(node.name, node.val, node.expand_ratio, node.module_name) + + def reset_channel_tensors(self): + """Reset the owning ChannelTensors.""" + self.in_channel_tensor = ChannelTensor(self.in_channels) + self.out_channel_tensor = ChannelTensor(self.out_channels) + + # forward + + def forward(self, in_channel_tensor=None): + """Forward with ChannelTensors.""" + assert self.in_channel_tensor is not None and \ + self.out_channel_tensor is not None + if in_channel_tensor is None: + out_channel_tensors = [ + node.out_channel_tensor for node in self.prev_nodes + ] + + in_channel_tensor = out_channel_tensors + self.channel_forward(*in_channel_tensor) + if self.expand_ratio > 1: + self.out_channel_tensor = self.out_channel_tensor.expand( + self.expand_ratio) + + @abstractmethod + def channel_forward(self, *channel_tensors: ChannelTensor): + """Forward with ChannelTensors.""" + assert len(channel_tensors) == 1, f'{len(channel_tensors)}' + BaseChannelUnit.union_two_units( + list(self.in_channel_tensor.unit_dict.values())[0], + list(channel_tensors[0].unit_dict.values())[0]) + + if self.in_channels == self.out_channels: + BaseChannelUnit.union_two_units( + self.in_channel_tensor.unit_list[0], + self.out_channel_tensor.unit_list[0]) + + # register unit + + def register_channel_to_units(self): + """Register the module of this node to corresponding units.""" + name = self.module_name if isinstance(self.val, + nn.Module) else self.name + for index, unit in self.in_channel_tensor.unit_dict.items(): + channel = BaseChannel(name, self.val, index, None, False, + self.expand_ratio) + if channel not in unit.input_related: + unit.input_related.append(channel) + for index, unit in self.out_channel_tensor.unit_dict.items(): + channel = BaseChannel(name, self.val, index, None, True, + self.expand_ratio) + if channel not in unit.output_related: + unit.output_related.append(channel) + + # channels + + # @abstractmethod + @property + def in_channels(self) -> int: + """Get the number of input channels of the node.""" + raise NotImplementedError() + + # @abstractmethod + @property + def out_channels(self) -> int: + """Get the number of output channels of the node.""" + raise NotImplementedError() + + +# basic nodes + + +class PassChannelNode(ChannelNode): + """A PassChannelNode has the same number of input channels and output + channels. + + Besides, the corresponding input channels and output channels belong to one + channel unit. Such as BatchNorm, Relu. 
+ """ + + def channel_forward(self, *in_channel_tensor: ChannelTensor): + """Channel forward.""" + PassChannelNode._channel_forward(self, *in_channel_tensor) + + @property + def in_channels(self) -> int: + """Get the number of input channels of the node.""" + if len(self.prev_nodes) > 0: + return self.prev_nodes[0].out_channels + else: + return 0 + + @property + def out_channels(self) -> int: + """Get the number of output channels of the node.""" + return self.in_channels + + def __repr__(self) -> str: + return super().__repr__() + '_pass' + + @staticmethod + def _channel_forward(node: ChannelNode, *in_channel_tensor: ChannelTensor): + """Channel forward.""" + assert len(in_channel_tensor) == 1 and \ + node.in_channels == node.out_channels + in_channel_tensor[0].union(node.in_channel_tensor) + node.in_channel_tensor.union(node.out_channel_tensor) + + +class MixChannelNode(ChannelNode): + """A MixChannelNode has independent input channels and output channels.""" + + def channel_forward(self, *in_channel_tensor: ChannelTensor): + """Channel forward.""" + assert len(in_channel_tensor) <= 1 + if len(in_channel_tensor) == 1: + in_channel_tensor[0].union(self.in_channel_tensor) + + @property + def in_channels(self) -> int: + """Get the number of input channels of the node.""" + if len(self.prev_nodes) > 0: + return self.prev_nodes[0].in_channels + else: + return 0 + + @property + def out_channels(self) -> int: + """Get the number of output channels of the node.""" + if len(self.next_nodes) > 0: + return self.next_nodes[0].in_channels + else: + return 0 + + def __repr__(self) -> str: + return super().__repr__() + '_mix' + + +class BindChannelNode(PassChannelNode): + """A BindChannelNode has multiple inputs, and all input channels belong to + the same channel unit.""" + + def channel_forward(self, *in_channel_tensor: ChannelTensor): + """Channel forward.""" + assert len(in_channel_tensor) > 1 + # align channel_tensors + ChannelTensor.align_tensors(*in_channel_tensor) + + # union tensors + node_units = [ + channel_lis.unit_dict for channel_lis in in_channel_tensor + ] + for key in node_units[0]: + BaseChannelUnit.union_units([units[key] for units in node_units]) + super().channel_forward(in_channel_tensor[0]) + + def __repr__(self) -> str: + return super(ChannelNode, self).__repr__() + '_bind' + + +class CatChannelNode(ChannelNode): + """A CatChannelNode cat all input channels.""" + + def channel_forward(self, *in_channel_tensors: ChannelTensor): + BaseChannelUnit.union_two_units(self.in_channel_tensor.unit_list[0], + self.out_channel_tensor.unit_list[0]) + num_ch = [] + for in_ch_tensor in in_channel_tensors: + for start, end in in_ch_tensor.unit_dict: + num_ch.append(end - start) + + split_units = BaseChannelUnit.split_unit( + self.in_channel_tensor.unit_list[0], num_ch) + + i = 0 + for in_ch_tensor in in_channel_tensors: + for in_unit in in_ch_tensor.unit_dict.values(): + BaseChannelUnit.union_two_units(split_units[i], in_unit) + i += 1 + + @property + def in_channels(self) -> int: + """Get the number of input channels of the node.""" + return sum([node.out_channels for node in self.prev_nodes]) + + @property + def out_channels(self) -> int: + """Get the number of output channels of the node.""" + return self.in_channels + + def __repr__(self) -> str: + return super().__repr__() + '_cat' + + +# module nodes + + +class ConvNode(MixChannelNode): + """A ConvNode corresponds to a Conv2d module. + + It can deal with normal conv, dwconv and gwconv. 
+ """ + + def __init__(self, + name: str, + val: Union[nn.Module, str], + expand_ratio: int = 1, + module_name='') -> None: + super().__init__(name, val, expand_ratio, module_name) + assert isinstance(self.val, nn.Conv2d) + if self.val.groups == 1: + self.conv_type = 'conv' + elif self.val.in_channels == self.out_channels == self.val.groups: + self.conv_type = 'dwconv' + else: + self.conv_type = 'gwconv' + + def channel_forward(self, *in_channel_tensor: ChannelTensor): + if self.conv_type == 'conv': + return super().channel_forward(*in_channel_tensor) + elif self.conv_type == 'dwconv': + return PassChannelNode._channel_forward(self, *in_channel_tensor) + elif self.conv_type == 'gwconv': + return super().channel_forward(*in_channel_tensor) + else: + pass + + @property + def in_channels(self) -> int: + return self.val.in_channels + + @property + def out_channels(self) -> int: + return self.val.out_channels + + def __repr__(self) -> str: + return super().__repr__() + '_conv' + + +class LinearNode(MixChannelNode): + """A LinearNode corresponds to a Linear module.""" + + def __init__(self, + name: str, + val: Union[nn.Module, str], + expand_ratio: int = 1, + module_name='') -> None: + super().__init__(name, val, expand_ratio, module_name) + assert isinstance(self.val, nn.Linear) + + @property + def in_channels(self) -> int: + return self.val.in_features + + @property + def out_channels(self) -> int: + return self.val.out_features + + def __repr__(self) -> str: + return super().__repr__() + 'linear' + + +class NormNode(PassChannelNode): + """A NormNode corresponds to a BatchNorm2d module.""" + + def __init__(self, + name: str, + val: Union[nn.Module, str], + expand_ratio: int = 1, + module_name='') -> None: + super().__init__(name, val, expand_ratio, module_name) + assert isinstance(self.val, nn.BatchNorm2d) + + @property + def in_channels(self) -> int: + return self.val.num_features + + @property + def out_channels(self) -> int: + return self.val.num_features + + def __repr__(self) -> str: + return super().__repr__() + '_bn' + + +# converter + + +def default_channel_node_converter(node: ModuleNode) -> ChannelNode: + """The default node converter for ChannelNode.""" + + def warn(default='PassChannelNode'): + logger = MMLogger('mmrazor', 'mmrazor') + logger.warn((f"{node.name}({node.val}) node can't find match type of" + 'channel_nodes,' + f'replaced with {default} by default.')) + + module_mapping = { + nn.Conv2d: ConvNode, + nn.BatchNorm2d: NormNode, + nn.Linear: LinearNode, + } + function_mapping = { + torch.add: BindChannelNode, + torch.cat: CatChannelNode, + operator.add: BindChannelNode + } + name_mapping = { + 'bind_placeholder': BindChannelNode, + 'pass_placeholder': PassChannelNode, + 'cat_placeholder': CatChannelNode, + } + if isinstance(node.val, nn.Module): + # module_mapping + for module_type in module_mapping: + if isinstance(node.val, module_type): + return module_mapping[module_type].copy_from(node) + + elif isinstance(node.val, str): + for module_type in name_mapping: + if node.val == module_type: + return name_mapping[module_type].copy_from(node) + + else: + for fun_type in function_mapping: + if node.val == fun_type: + return function_mapping[fun_type].copy_from(node) + if len(node.prev_nodes) > 1: + warn('BindChannelNode') + return BindChannelNode.copy_from(node) + else: + warn('PassChannelNode') + return PassChannelNode.copy_from(node) diff --git a/mmrazor/structures/graph/module_graph.py b/mmrazor/structures/graph/module_graph.py index 8b7f9c920..bc7e90dac 100644 --- 
a/mmrazor/structures/graph/module_graph.py +++ b/mmrazor/structures/graph/module_graph.py @@ -10,11 +10,16 @@ import torch.nn as nn from torch.nn import Module -from mmrazor.models.task_modules import (BackwardTracer, Path, PathConcatNode, - PathList, PathNode) -from mmrazor.models.task_modules.tracer import ImageClassifierPseudoLoss +from mmrazor.models.task_modules.tracer.backward_tracer import BackwardTracer +from mmrazor.models.task_modules.tracer.loss_calculator import \ + ImageClassifierPseudoLoss +from mmrazor.models.task_modules.tracer.path import (Path, PathConcatNode, + PathList, PathNode) +from mmrazor.registry import TASK_UTILS from .base_graph import BaseGraph, BaseNode +# ModuleNode && ModuleGraph + class ModuleNode(BaseNode): """A node in a computation graph. @@ -30,7 +35,8 @@ class ModuleNode(BaseNode): def __init__(self, name: str, val: Union[Module, str], - expand_ratio: int = 1) -> None: + expand_ratio: int = 1, + module_name='') -> None: """ Args: name (str): the name of the node @@ -56,6 +62,7 @@ def forward(x): 'expand != 1 is only valid when val=="pass"' super().__init__(name, val) self.expand_ratio = expand_ratio + self.module_name = module_name # channel @@ -204,25 +211,29 @@ def check_type(self): class ModuleGraph(BaseGraph[MODULENODE]): """Computatation Graph.""" - def __init__(self) -> None: + def __init__(self, model=None) -> None: super().__init__() - self._model = None + self._model: nn.Module = model # functions to generate module graph. @staticmethod - def init_using_backward_tracer( + def init_from_backward_tracer( model: Module, backward_tracer=BackwardTracer( loss_calculator=ImageClassifierPseudoLoss()), ): """init module graph using backward tracer.""" + if isinstance(backward_tracer, dict): + backward_tracer = TASK_UTILS.build(backward_tracer) path_lists = backward_tracer.trace(model) converter = PathToGraphConverter(path_lists, model) + converter.graph.refresh_module_name() return converter.graph @staticmethod - def init_using_fx_tracer(model: Module, is_extra_leaf_module=None): + def init_from_fx_tracer(model: Module, + fx_tracer={'type': 'RazorFxTracer'}): """init module graph using torch fx tracer.""" pass @@ -260,6 +271,16 @@ def connect_module(pre: Module, next: Module): next._pre = set() next._pre.add(pre) + # others + def refresh_module_name(self): + module2name = {} + for name, module in self._model.named_modules(): + module2name[module] = name + + for node in self: + if isinstance(node.val, nn.Module): + node.module_name = module2name[node.val] + # Converter @@ -267,8 +288,8 @@ def connect_module(pre: Module, next: Module): class GraphConverter: """Base class for converters for ModuleGraph.""" - def __init__(self) -> None: - self.graph = ModuleGraph[ModuleNode]() + def __init__(self, model) -> None: + self.graph = ModuleGraph[ModuleNode](model) self.cat_placeholder_num = 0 self.bind_placeholder_num = 0 self.pass_placeholder_num = 0 @@ -388,15 +409,15 @@ def __init__(self, path_list: PathList, model: Module) -> None: path_list (PathList): path_list generated by backward tracer. 
model (Module): the model corresponding to the path_list """ - super().__init__() + super().__init__(model) self.path_list = path_list self.cat_dict: Dict[str, str] = {} self.name2module = dict(model.named_modules()) - self._pass(self.path_list) + self._parse(self.path_list) self._post_process() - def _pass(self, path_list: PathList): + def _parse(self, path_list: PathList): """Parse path list.""" self._parse_helper(path_list, []) diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index 2b142f6ea..4eb515371 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -10,7 +10,7 @@ def _dynamic_to_static(model: nn.Module) -> None: # Avoid circular import - from mmrazor.models.architectures.dynamic_ops.bricks import DynamicMixin + from mmrazor.models.architectures.dynamic_ops import DynamicMixin def traverse_children(module: nn.Module) -> None: # TODO @@ -37,7 +37,7 @@ def load_fix_subnet(model: nn.Module, raise TypeError('fix_mutable should be a `str` or `dict`' f'but got {type(fix_mutable)}') - from mmrazor.models.architectures.dynamic_ops.bricks import DynamicMixin + from mmrazor.models.architectures.dynamic_ops import DynamicMixin if isinstance(model, DynamicMixin): raise RuntimeError('Root model can not be dynamic op.') diff --git a/mmrazor/utils/__init__.py b/mmrazor/utils/__init__.py index 5fb8dc209..8490e8eef 100644 --- a/mmrazor/utils/__init__.py +++ b/mmrazor/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .index_dict import IndexDict from .misc import find_latest_checkpoint from .placeholder import get_placeholder from .setup_env import register_all_modules, setup_multi_processes @@ -9,5 +10,6 @@ __all__ = [ 'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules', 'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet', - 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder' + 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder', + 'IndexDict' ] diff --git a/mmrazor/utils/index_dict.py b/mmrazor/utils/index_dict.py new file mode 100644 index 000000000..8ac3661c2 --- /dev/null +++ b/mmrazor/utils/index_dict.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import Tuple, TypeVar + +VT = TypeVar('VT') # Value type + + +class IndexDict(OrderedDict): + """IndexDict inherents from OrderedDict[Tuple[int, int], VT]. Each + IndexDict object is a OrderDict object which using index(Tuple[int,int]) as + key and Any as value. + + The key type is Tuple[a: int,b: int]. It indicates a range in + the [a,b). + + IndexDict has three features: + 1. ensure a key always is a index(Tuple[int,int]). + 1. ensure the the indexes are sorted by ascending order. + 2. ensure there is no overlap among indexes. 
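# A small sketch of the IndexDict behaviour described above, assuming this
# patch is applied: keys are (start, end) ranges that are kept sorted and
# must not overlap.
from mmrazor.utils import IndexDict

d = IndexDict()
d[(4, 8)] = 'b'
d[(0, 4)] = 'a'
print(list(d.keys()))    # [(0, 4), (4, 8)] -- re-sorted on every insert
print((2, 6) in d)       # True: (2, 6) overlaps the existing ranges
try:
    d[(2, 6)] = 'c'      # overlapping index is rejected
except AssertionError as e:
    print(e)             # 'index overlap'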
+ """ + + def __setitem__(self, __k: Tuple[int, int], __v): + """set item.""" + start, end = __k + assert start < end + self._assert_no_over_lap(start, end) + super().__setitem__(__k, __v) + self._sort() + + def _sort(self): + """sort the dict accorrding to index.""" + items = sorted(self.items()) + self.clear() + for k, v in items: + super().__setitem__(k, v) + + def _assert_no_over_lap(self, start, end): + """Assert the index [start,end) has no over lav with existed + indexes.""" + assert (start, end) not in self, 'index overlap' + + def __contains__(self, __o) -> bool: + """Bool: if the index has any overlap with existed indexes""" + if super().__contains__(__o): + return True + else: + self._assert_is_index(__o) + start, end = __o + existed = False + for s, e in self.keys(): + existed = (s <= start < e or s < end < e or + (s < start and end < e)) or existed + + return existed + + def _assert_is_index(self, index): + """Assert the index is an instance of Tuple[int,int]""" + assert isinstance(index, Tuple) \ + and len(index) == 2 \ + and isinstance(index[0], int) \ + and isinstance(index[1], int) diff --git a/tests/__init__.py b/tests/__init__.py index ddce77790..ef101fec6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .test_core.test_graph.test_graph import TestGraph - -__all__ = ['TestGraph'] diff --git a/tests/data/MBV2_slimmable_config.json b/tests/data/MBV2_slimmable_config.json new file mode 100644 index 000000000..f63029872 --- /dev/null +++ b/tests/data/MBV2_slimmable_config.json @@ -0,0 +1,392 @@ +{ + "backbone.conv1.conv_(0, 48)_48": { + "init_args": { + "num_channels": 48, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 8, + 8, + 32 + ], + "choice_mode": "number" + }, + "choice": 32 + }, + "backbone.layer1.0.conv.1.conv_(0, 24)_24": { + "init_args": { + "num_channels": 24, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 8, + 8, + 16 + ], + "choice_mode": "number" + }, + "choice": 16 + }, + "backbone.layer2.0.conv.0.conv_(0, 144)_144": { + "init_args": { + "num_channels": 144, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 96, + 96, + 144 + ], + "choice_mode": "number" + }, + "choice": 144 + }, + "backbone.layer2.0.conv.2.conv_(0, 40)_40": { + "init_args": { + "num_channels": 40, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 16, + 16, + 24 + ], + "choice_mode": "number" + }, + "choice": 24 + }, + "backbone.layer2.1.conv.0.conv_(0, 240)_240": { + "init_args": { + "num_channels": 240, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 96, + 96, + 176 + ], + "choice_mode": "number" + }, + "choice": 176 + }, + "backbone.layer3.0.conv.0.conv_(0, 240)_240": { + "init_args": { + "num_channels": 240, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 96, + 96, + 192 + ], + "choice_mode": "number" + }, + "choice": 192 + }, + "backbone.layer3.0.conv.2.conv_(0, 48)_48": { + "init_args": { + "num_channels": 48, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 24, + 24, + 48 + ], + "choice_mode": "number" + }, + "choice": 48 + }, + "backbone.layer3.1.conv.0.conv_(0, 288)_288": { + "init_args": { + "num_channels": 288, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 144, + 144, + 240 + ], + "choice_mode": "number" + }, + "choice": 240 + }, + 
"backbone.layer3.2.conv.0.conv_(0, 288)_288": { + "init_args": { + "num_channels": 288, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 144, + 144, + 144 + ], + "choice_mode": "number" + }, + "choice": 144 + }, + "backbone.layer4.0.conv.0.conv_(0, 288)_288": { + "init_args": { + "num_channels": 288, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 144, + 144, + 264 + ], + "choice_mode": "number" + }, + "choice": 264 + }, + "backbone.layer4.0.conv.2.conv_(0, 96)_96": { + "init_args": { + "num_channels": 96, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 48, + 56, + 88 + ], + "choice_mode": "number" + }, + "choice": 88 + }, + "backbone.layer4.1.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 288 + ], + "choice_mode": "number" + }, + "choice": 288 + }, + "backbone.layer4.2.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 336 + ], + "choice_mode": "number" + }, + "choice": 336 + }, + "backbone.layer4.3.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 432 + ], + "choice_mode": "number" + }, + "choice": 432 + }, + "backbone.layer5.0.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 576 + ], + "choice_mode": "number" + }, + "choice": 576 + }, + "backbone.layer5.0.conv.2.conv_(0, 144)_144": { + "init_args": { + "num_channels": 144, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 64, + 96, + 144 + ], + "choice_mode": "number" + }, + "choice": 144 + }, + "backbone.layer5.1.conv.0.conv_(0, 864)_864": { + "init_args": { + "num_channels": 864, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 432, + 432, + 576 + ], + "choice_mode": "number" + }, + "choice": 576 + }, + "backbone.layer5.2.conv.0.conv_(0, 864)_864": { + "init_args": { + "num_channels": 864, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 432, + 432, + 648 + ], + "choice_mode": "number" + }, + "choice": 648 + }, + "backbone.layer6.0.conv.0.conv_(0, 864)_864": { + "init_args": { + "num_channels": 864, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 648, + 864, + 864 + ], + "choice_mode": "number" + }, + "choice": 864 + }, + "backbone.layer6.0.conv.2.conv_(0, 240)_240": { + "init_args": { + "num_channels": 240, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 176, + 240, + 240 + ], + "choice_mode": "number" + }, + "choice": 240 + }, + "backbone.layer6.1.conv.0.conv_(0, 1440)_1440": { + "init_args": { + "num_channels": 1440, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 720, + 1440, + 1440 + ], + "choice_mode": "number" + }, + "choice": 1440 + }, + "backbone.layer6.2.conv.0.conv_(0, 1440)_1440": { + "init_args": { + "num_channels": 1440, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 720, + 960, + 1440 + ], + "choice_mode": "number" + }, + "choice": 1440 + }, + "backbone.layer7.0.conv.0.conv_(0, 1440)_1440": { + "init_args": { + "num_channels": 1440, + "divisor": 1, + 
"min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 1440, + 1440, + 1440 + ], + "choice_mode": "number" + }, + "choice": 1440 + }, + "backbone.layer7.0.conv.2.conv_(0, 480)_480": { + "init_args": { + "num_channels": 480, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 280, + 480, + 480 + ], + "choice_mode": "number" + }, + "choice": 480 + }, + "backbone.conv2.conv_(0, 1920)_1920": { + "init_args": { + "num_channels": 1920, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 1920, + 1920, + 1920 + ], + "choice_mode": "number" + }, + "choice": 1920 + }, + "head.fc_(0, 1000)_1000": { + "init_args": { + "num_channels": 1000, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 1000, + 1000, + 1000 + ], + "choice_mode": "number" + }, + "choice": 1000 + } +} \ No newline at end of file diff --git a/mmrazor/models/architectures/dynamic_ops/head/__init__.py b/tests/data/__init__.py similarity index 100% rename from mmrazor/models/architectures/dynamic_ops/head/__init__.py rename to tests/data/__init__.py diff --git a/tests/data/models.py b/tests/data/models.py index dd328b516..60c8a7058 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -3,10 +3,29 @@ from torch import Tensor import torch.nn as nn import torch - +from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin +from mmrazor.models.mutables.mutable_channel import MutableChannelContainer +from mmrazor.models.mutables import MutableChannelUnit +from mmrazor.models.mutables import DerivedMutable +from mmrazor.models.mutables import BaseMutable +from mmrazor.models.mutables import OneShotMutableChannelUnit, SquentialMutableChannel, OneShotMutableChannel +from mmrazor.registry import MODELS +from mmengine.model import BaseModel # this file includes models for tesing. 
+class LinearHead(Module): + + def __init__(self, in_channel, num_class=1000) -> None: + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.linear = nn.Linear(in_channel, num_class) + + def forward(self, x): + pool = self.pool(x).flatten(1) + return self.linear(pool) + + class MultiConcatModel(Module): """ x---------------- @@ -127,7 +146,7 @@ def forward(self, x: Tensor) -> Tensor: output = self.fc(x_pool) return output - + class ResBlock(Module): """ @@ -166,7 +185,7 @@ def forward(self, x: Tensor) -> Tensor: return output -class LineModel(Module): +class LineModel(BaseModel): """ x |net0,net1 @@ -233,6 +252,20 @@ def forward(self, x): class GroupWiseConvModel(nn.Module): + """ + x + |op1,bn1 + x1 + |op2,bn2 + x2 + |op3 + x3 + |avg_pool + x_pool + |fc + y + """ + def __init__(self) -> None: super().__init__() self.op1 = nn.Conv2d(3, 8, 3, 1, 1) @@ -240,6 +273,8 @@ def __init__(self) -> None: self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2) self.bn2 = nn.BatchNorm2d(16) self.op3 = nn.Conv2d(16, 32, 3, 1, 1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(32, 1000) def forward(self, x): x1 = self.op1(x) @@ -291,6 +326,23 @@ def forward(self, x): class MultipleUseModel(nn.Module): + """ + x------------------------ + |conv0 |conv1 |conv2 |conv3 + xs.0 xs.1 xs.2 xs.3 + |convm |convm |convm |convm + xs_.0 xs_.1 xs_.2 xs_.3 + | | | | + +------------------------ + | + x_sum + |conv_last + feature + |avg_pool + pool + |linear + output + """ def __init__(self) -> None: super().__init__() @@ -299,7 +351,7 @@ def __init__(self) -> None: self.conv2 = nn.Conv2d(3, 8, 3, 1, 1) self.conv3 = nn.Conv2d(3, 8, 3, 1, 1) self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1) - self.conv_last = nn.Conv2d(16, 32, 3, 1, 1) + self.conv_last = nn.Conv2d(16 * 4, 32, 3, 1, 1) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.linear = nn.Linear(32, 1000) @@ -309,17 +361,226 @@ def forward(self, x): for conv in [self.conv0, self.conv1, self.conv2, self.conv3] ] xs_ = [self.conv_multiple_use(x_) for x_ in xs] - x_sum = 0 - for x_ in xs_: - x_sum = x_sum + x_ - feature = self.conv_last(x_sum) + x_cat = torch.cat(xs_, dim=1) + feature = self.conv_last(x_cat) pool = self.avg_pool(feature).flatten(1) return self.linear(pool) +class IcepBlock(nn.Module): + """ + x------------------------ + |op1 |op2 |op3 |op4 + x1 x2 x3 x4 + | | | | + cat---------------------- + | + y_ + """ + + def __init__(self, in_c=3, out_c=32) -> None: + super().__init__() + self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1) + self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1) + self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1) + self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1) + # self.op5 = nn.Conv2d(out_c*4, out_c, 3) + + def forward(self, x): + x1 = self.op1(x) + x2 = self.op2(x) + x3 = self.op3(x) + x4 = self.op4(x) + y_ = [x1, x2, x3, x4] + y_ = torch.cat(y_, 1) + return y_ + + +class Icep(nn.Module): + + def __init__(self, num_icep_blocks=2) -> None: + super().__init__() + self.icps = nn.Sequential(*[ + IcepBlock(32 * 4 if i != 0 else 3, 32) + for i in range(num_icep_blocks) + ]) + self.op = nn.Conv2d(32 * 4, 32, 1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(32, 1000) + + def forward(self, x): + y_ = self.icps(x) + y = self.op(y_) + pool = self.avg_pool(y).flatten(1) + return self.fc(pool) + + +class ExpandLineModel(Module): + """ + x + |net0,net1,net2 + |net3,net4 + x1 + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(), + 
nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16), + nn.AdaptiveAvgPool2d(2)) + self.linear = nn.Linear(64, 1000) + + def forward(self, x): + x1 = self.net(x) + x1 = x1.reshape([x1.shape[0], -1]) + return self.linear(x1) + + +class MultiBindModel(Module): + + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(3, 8, 3, 1, 1) + self.conv2 = nn.Conv2d(3, 8, 3, 1, 1) + self.conv3 = nn.Conv2d(8, 8, 3, 1, 1) + self.head = LinearHead(8, 1000) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + x12 = x1 + x2 + x3 = self.conv3(x12) + x123 = x12 + x3 + return self.head(x123) + + +class DwConvModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(), + nn.Conv2d(48, 48, 3, 1, 1, groups=48), nn.BatchNorm2d(48), + nn.ReLU()) + self.head = LinearHead(48, 1000) + + def forward(self, x): + return self.head(self.net(x)) + + +# models with dynamicop + + +def register_mutable(module: DynamicChannelMixin, + mutable: OneShotMutableChannelUnit, + is_out=True, + start=0, + end=-1): + if end == -1: + end = mutable.num_channels + start + if is_out: + container: MutableChannelContainer = module.get_mutable_attr( + 'out_channels') + else: + container: MutableChannelContainer = module.get_mutable_attr( + 'in_channels') + container.register_mutable(mutable, start, end) + + +class SampleExpandDerivedMutable(BaseMutable): + + def __init__(self, expand_ratio=1) -> None: + super().__init__() + self.ratio = expand_ratio + + def __mul__(self, other): + if isinstance(other, OneShotMutableChannel): + + def _expand_mask(): + mask = other.current_mask + mask = torch.unsqueeze( + mask, + -1).expand(list(mask.shape) + [self.ratio]).flatten(-2) + return mask + + return DerivedMutable(_expand_mask, _expand_mask, [self, other]) + else: + raise NotImplementedError() + + def dump_chosen(self): + return super().dump_chosen() + + def fix_chosen(self, chosen): + return super().fix_chosen(chosen) + + def num_choices(self) -> int: + return super().num_choices + + +class DynamicLinearModel(nn.Module): + """ + x + |net0,net1 + |net2 + |net3 + x1 + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + self.net = nn.Sequential( + DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8), nn.ReLU(), + DynamicConv2d(8, 16, 3, 1, 1), DynamicBatchNorm2d(16), + nn.AdaptiveAvgPool2d(1)) + self.linear = DynamicLinear(16, 1000) + + MutableChannelUnit._register_channel_container( + self, MutableChannelContainer) + self._register_mutable() + + def forward(self, x): + x1 = self.net(x) + x1 = x1.reshape([x1.shape[0], -1]) + return self.linear(x1) + + def _register_mutable(self): + mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8]) + mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16]) + mutable_value = SampleExpandDerivedMutable(1) + + MutableChannelContainer.register_mutable_channel_to_module( + self.net[0], mutable1, True) + MutableChannelContainer.register_mutable_channel_to_module( + self.net[1], mutable1.expand_mutable_channel(1), True, 0, 8) + MutableChannelContainer.register_mutable_channel_to_module( + self.net[3], mutable_value * mutable1, False, 0, 8) + + MutableChannelContainer.register_mutable_channel_to_module( + self.net[3], mutable2, True) + MutableChannelContainer.register_mutable_channel_to_module( + self.net[4], mutable2, True) + MutableChannelContainer.register_mutable_channel_to_module( + self.linear, mutable2, False) + + default_models = [ - LineModel, 
ResBlock, AddCatModel, ConcatModel, MultiConcatModel, - MultiConcatModel2, GroupWiseConvModel, Xmodel, MultipleUseModel + LineModel, + ResBlock, + AddCatModel, + ConcatModel, + MultiConcatModel, + MultiConcatModel2, + GroupWiseConvModel, + Xmodel, + MultipleUseModel, + Icep, + ExpandLineModel, + DwConvModel, ] diff --git a/tests/test_core/__init__.py b/tests/test_core/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_core/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_core/test_graph/__init__.py b/tests/test_core/test_graph/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_core/test_graph/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_core/test_graph/test_channel_graph.py b/tests/test_core/test_graph/test_channel_graph.py new file mode 100644 index 000000000..6eb3e1454 --- /dev/null +++ b/tests/test_core/test_graph/test_channel_graph.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch +from torch import nn + +from mmrazor.models.task_modules import BackwardTracer +from mmrazor.registry import TASK_UTILS +from mmrazor.structures.graph import ModuleGraph +from mmrazor.structures.graph.channel_graph import ChannelGraph +from mmrazor.structures.graph.channel_modules import (BaseChannelUnit, + ChannelTensor) +from mmrazor.structures.graph.channel_nodes import \ + default_channel_node_converter +from ...data.models import LineModel +from .test_graph import TestGraph + +NodeMap = {} + + +@TASK_UTILS.register_module() +class ImageClassifierPseudoLossWithSixChannel: + """Calculate the pseudo loss to trace the topology of a `ImageClassifier` + in MMClassification with `BackwardTracer`.""" + + def __call__(self, model) -> torch.Tensor: + pseudo_img = torch.rand(1, 6, 224, 224) + pseudo_output = model(pseudo_img) + return sum(pseudo_output) + + +class TestChannelGraph(unittest.TestCase): + + def test_init(self): + model = LineModel() + module_graph = ModuleGraph.init_from_backward_tracer(model) + + _ = ChannelGraph.copy_from(module_graph, + default_channel_node_converter) + + def test_forward(self): + for model_data in TestGraph.backward_tracer_passed_models(): + with self.subTest(model=model_data): + model = model_data() + module_graph = ModuleGraph.init_from_backward_tracer(model) + + channel_graph = ChannelGraph.copy_from( + module_graph, default_channel_node_converter) + channel_graph.forward() + + _ = channel_graph.collect_units + + def test_forward_with_config_num_in_channel(self): + + class MyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(6, 3, 3, 1, 1) + self.net = LineModel() + + def forward(self, x): + return self.net(self.conv1(x)) + + model = MyModel() + module_graph = ModuleGraph.init_from_backward_tracer( + model, + backward_tracer=BackwardTracer( + loss_calculator=ImageClassifierPseudoLossWithSixChannel())) + + channel_graph = ChannelGraph.copy_from(module_graph, + default_channel_node_converter) + channel_graph.forward(num_input_channel=6) + + _ = channel_graph.collect_units + + +class TestChannelUnit(unittest.TestCase): + + def test_union(self): + channel_tensor1 = ChannelTensor(8) + channel_tensor2 = ChannelTensor(8) + channel_tensor3 = ChannelTensor(8) + channel_tensor4 = ChannelTensor(8) + unit1 = channel_tensor1.unit_dict[(0, 8)] + unit2 = channel_tensor2.unit_dict[(0, 8)] + unit3 = 
channel_tensor3.unit_dict[(0, 8)] + unit4 = channel_tensor4.unit_dict[(0, 8)] + + unit12 = BaseChannelUnit.union_two_units(unit1, unit2) + self.assertDictEqual(channel_tensor1.unit_dict, + channel_tensor2.unit_dict) + + unit34 = BaseChannelUnit.union_two_units(unit3, unit4) + BaseChannelUnit.union_two_units(unit12, unit34) + self.assertDictEqual(channel_tensor1.unit_dict, + channel_tensor4.unit_dict) + + def test_split(self): + channel_tensor1 = ChannelTensor(8) + channel_tensor2 = ChannelTensor(8) + BaseChannelUnit.union_two_units(channel_tensor1.unit_dict[(0, 8)], + channel_tensor2.unit_dict[(0, 8)]) + unit1 = channel_tensor1.unit_dict[(0, 8)] + BaseChannelUnit.split_unit(unit1, [2, 6]) + + self.assertDictEqual(channel_tensor1.unit_dict, + channel_tensor2.unit_dict) + + +class TestChannelTensor(unittest.TestCase): + + def test_init(self): + channel_tensor = ChannelTensor(8) + self.assertIn((0, 8), channel_tensor.unit_dict) + + def test_align_with_nums(self): + channel_tensor = ChannelTensor(8) + channel_tensor.align_units_with_nums([2, 6]) + self.assertSequenceEqual( + list(channel_tensor.unit_dict.keys()), [(0, 2), (2, 8)]) + channel_tensor.align_units_with_nums([2, 2, 4]) + self.assertSequenceEqual( + list(channel_tensor.unit_dict.keys()), [(0, 2), (2, 4), (4, 8)]) + + def test_align_units(self): + channel_tensor1 = ChannelTensor(8) + channel_tensor2 = ChannelTensor(8) + channel_tensor3 = ChannelTensor(8) + + BaseChannelUnit.split_unit(channel_tensor1.unit_list[0], [2, 6]) + BaseChannelUnit.split_unit(channel_tensor2.unit_list[0], [4, 4]) + BaseChannelUnit.split_unit(channel_tensor3.unit_list[0], [6, 2]) + """ + xxoooooo + xxxxoooo + xxxxxxoo + """ + + ChannelTensor.align_tensors(channel_tensor1, channel_tensor2, + channel_tensor3) + for lst in [channel_tensor1, channel_tensor2, channel_tensor3]: + self.assertSequenceEqual( + list(lst.unit_dict.keys()), [ + (0, 2), + (2, 4), + (4, 6), + (6, 8), + ]) + + def test_expand(self): + channel_tensor = ChannelTensor(8) + expanded = channel_tensor.expand(4) + self.assertIn((0, 32), expanded.unit_dict) + + def test_union(self): + channel_tensor1 = ChannelTensor(8) + channel_tensor2 = ChannelTensor(8) + channel_tensor3 = ChannelTensor(8) + channel_tensor4 = ChannelTensor(8) + channel_tensor3.union(channel_tensor4) + + self.assertEqual( + id(channel_tensor3.unit_dict[(0, 8)]), + id(channel_tensor4.unit_dict[(0, 8)])) + + channel_tensor2.union(channel_tensor3) + channel_tensor1.union(channel_tensor2) + + self.assertEqual( + id(channel_tensor1.unit_dict[(0, 8)]), + id(channel_tensor2.unit_dict[(0, 8)])) + self.assertEqual( + id(channel_tensor2.unit_dict[(0, 8)]), + id(channel_tensor3.unit_dict[(0, 8)])) + self.assertEqual( + id(channel_tensor3.unit_dict[(0, 8)]), + id(channel_tensor4.unit_dict[(0, 8)])) diff --git a/tests/test_core/test_graph/test_graph.py b/tests/test_core/test_graph/test_graph.py index 2b7995a09..1383dccd8 100644 --- a/tests/test_core/test_graph/test_graph.py +++ b/tests/test_core/test_graph/test_graph.py @@ -1,16 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import os import sys from unittest import TestCase import torch +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin from mmrazor.structures.graph import ModuleGraph -from tests.data.models import (AddCatModel, ConcatModel, LineModel, - MultiConcatModel, MultiConcatModel2, ResBlock) +from ...data.models import Icep # noqa +from ...data.models import MultipleUseModel # noqa +from ...data.models import Xmodel # noqa +from ...data.models import (AddCatModel, ConcatModel, DwConvModel, + ExpandLineModel, GroupWiseConvModel, LineModel, + ModelLibrary, MultiBindModel, MultiConcatModel, + MultiConcatModel2, ResBlock) + +FULL_TEST = os.getenv('FULL_TEST') == 'true' sys.setrecursionlimit(int(1e8)) +def is_dynamic_op_fx(module, name): + return isinstance(module, DynamicChannelMixin) + + class ToyCNNPseudoLoss: def __call__(self, model): @@ -19,59 +32,68 @@ def __call__(self, model): return pseudo_output.sum() -DATA = [ - { - 'model': LineModel, - 'num_nodes': 5, - }, - { - 'model': ResBlock, - 'num_nodes': 7, - }, - { - 'model': ConcatModel, - 'num_nodes': 7, - }, - { - 'model': MultiConcatModel2, - 'num_nodes': 7, - }, - { - 'model': MultiConcatModel, - 'num_nodes': 7, - }, - { - 'model': AddCatModel - }, -] - - class TestGraph(TestCase): - def test_graph_init(self) -> None: - - for data in DATA: + @classmethod + def backward_tracer_passed_models(cls): + '''MultipleUseModel: backward tracer can't distinguish multiple use and + first bind then use.''' + default_models = [ + LineModel, + ResBlock, + AddCatModel, + ConcatModel, + MultiConcatModel, + MultiConcatModel2, + GroupWiseConvModel, + Xmodel, + # MultipleUseModel, # bug + # Icep, bug + ExpandLineModel, + MultiBindModel, + DwConvModel + ] + """ + googlenet return a tuple when training, so it + should trace in eval mode + """ + + torch_models_includes = [ + 'alexnet', + 'densenet', + 'efficientnet', + 'googlenet', + # 'inception', bug + 'mnasnet', + 'mobilenet', + 'regnet', + 'resnet', + 'resnext', + # 'shufflenet', # bug + 'squeezenet', + 'vgg', + 'wide_resnet', + ] + model_library = ModelLibrary(torch_models_includes) + + models = default_models + model_library.export_models( + ) if FULL_TEST else default_models + return models + + def test_init_from_backward_tracer(self) -> None: + TestData = self.backward_tracer_passed_models() + + for data in TestData: with self.subTest(data=data): - model = data['model']() - # print(model) - graphs = [ - ModuleGraph.init_using_backward_tracer(model), - ] - - unit_num = len(graphs[0].nodes) - - for graph in graphs: - - # check channels - try: - graph.check() - except Exception as e: - self.fail(str(e) + '\n' + str(graph)) - - # check number of nodes - self.assertEqual(unit_num, len(graph.nodes)) - if 'num_nodes' in data: - self.assertEqual( - len(graph), - data['num_nodes'], - msg=f'{graph.nodes}') + model = data() + model.eval() + graph = ModuleGraph.init_from_backward_tracer(model) + + # check channels + self._valid_graph(graph) + + def _valid_graph(self, graph: ModuleGraph): + try: + graph.check() + except Exception as e: + self.fail(str(e) + '\n' + str(graph)) diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
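# A rough sketch of the tracing convention the tests above follow, assuming
# this patch is applied and the tracer components ('BackwardTracer',
# 'ImageClassifierPseudoLoss') are registered in TASK_UTILS as the configs in
# this patch suggest: put the model in eval mode before tracing (some
# torchvision models return tuples in train mode), pass the tracer as a config
# dict, and call graph.check() to verify the channel bookkeeping.
from mmrazor.structures.graph import ModuleGraph
from tests.data.models import ResBlock  # toy model shipped with the tests

model = ResBlock()
model.eval()
graph = ModuleGraph.init_from_backward_tracer(
    model,
    backward_tracer=dict(
        type='BackwardTracer',
        loss_calculator=dict(type='ImageClassifierPseudoLoss')))
graph.check()  # raises if the traced channel relations are inconsistent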
diff --git a/tests/test_models/test_algorithms/__init__.py b/tests/test_models/test_algorithms/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_algorithms/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_algorithms/test_autoslim.py b/tests/test_models/test_algorithms/test_autoslim.py index f73222630..79169b3cf 100644 --- a/tests/test_models/test_algorithms/test_autoslim.py +++ b/tests/test_models/test_algorithms/test_autoslim.py @@ -23,19 +23,20 @@ backbone=dict(type='MobileNetV2', widen_factor=1.5), neck=dict(type='GlobalAveragePooling'), head=dict( - type='LinearClsHead', + type='mmcls.LinearClsHead', num_classes=1000, in_channels=1920, - loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0), topk=(1, 5))) MUTATOR_CFG = dict( type='OneShotChannelMutator', - mutable_cfg=dict( - type='OneShotMutableChannel', - candidate_choices=list(i / 12 for i in range(2, 13)), - candidate_mode='ratio'), - tracer_cfg=dict( + channel_unit_cfg=dict( + type='OneShotMutableChannelUnit', + default_args=dict( + candidate_choices=list(i / 12 for i in range(2, 13)), + choice_mode='ratio')), + parse_cfg=dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss'))) @@ -50,9 +51,17 @@ preds_S=dict(recorder='fc', from_student=True), preds_T=dict(recorder='fc', from_student=False)))) -OPTIMIZER_CFG = dict( - type='SGD', lr=0.5, momentum=0.9, nesterov=True, weight_decay=0.0001) -OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG, accumulative_counts=4) +OPTIM_WRAPPER_CFG = dict( + optimizer=dict( + type='mmcls.SGD', + lr=0.5, + momentum=0.9, + weight_decay=4e-05, + _scope_='mmrazor'), + paramwise_cfg=dict( + bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0), + clip_grad=None, + accumulative_counts=4) class FakeMutator: @@ -65,7 +74,7 @@ def forward( self, data: Dict, training: bool = True) -> Tuple[torch.Tensor, List[ClsDataSample]]: - return data['inputs'], data['data_samples'] + return data @unittest.skipIf( @@ -76,9 +85,15 @@ class TestAutoSlim(TestCase): def test_init(self) -> None: mutator_wrong_type = FakeMutator() - with pytest.raises(TypeError): + with pytest.raises(Exception): _ = self.prepare_model(mutator_wrong_type) + algo = self.prepare_model() + self.assertSequenceEqual( + algo.mutator.mutable_units[0].candidate_choices, + list(i / 12 for i in range(2, 13)), + ) + def test_autoslim_train_step(self) -> None: algo = self.prepare_model() data = self._prepare_fake_data() @@ -92,11 +107,11 @@ def test_autoslim_train_step(self) -> None: assert len(losses) == 7 assert losses['max_subnet.loss'] > 0 assert losses['min_subnet.loss'] > 0 - assert losses['min_subnet.loss_kl'] > 0 + assert losses['min_subnet.loss_kl'] + 1e-5 > 0 assert losses['random_subnet_0.loss'] > 0 - assert losses['random_subnet_0.loss_kl'] > 0 + assert losses['random_subnet_0.loss_kl'] + 1e-5 > 0 assert losses['random_subnet_1.loss'] > 0 - assert losses['random_subnet_1.loss_kl'] > 0 + assert losses['random_subnet_1.loss_kl'] + 1e-5 > 0 assert algo._optim_wrapper_count_status_reinitialized assert optim_wrapper._inner_count == 4 diff --git a/tests/test_models/test_algorithms/test_ofd_algo.py b/tests/test_models/test_algorithms/test_ofd_algo.py index ff2851584..4b7442d68 100644 --- a/tests/test_models/test_algorithms/test_ofd_algo.py +++ b/tests/test_models/test_algorithms/test_ofd_algo.py @@ -3,9 +3,9 @@ from unittest import TestCase 
from mmengine import ConfigDict -from toy_models import ToyOFDStudent from mmrazor.models import OverhaulFeatureDistillation +from .toy_models import ToyOFDStudent class TestSingleTeacherDistill(TestCase): diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py new file mode 100644 index 000000000..519407772 --- /dev/null +++ b/tests/test_models/test_algorithms/test_prune_algorithm.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +import unittest + +import torch +from mmcls.structures import ClsDataSample +from mmengine import MessageHub +from mmengine.model import BaseModel + +from mmrazor.models.algorithms.pruning.ite_prune_algorithm import ( + ItePruneAlgorithm, ItePruneConfigManager) +from mmrazor.registry import MODELS + + +# @TASK_UTILS.register_module() +class ImageClassifierPseudoLoss: + """Calculate the pseudo loss to trace the topology of a `ImageClassifier` + in MMClassification with `BackwardTracer`.""" + + def __call__(self, model) -> torch.Tensor: + pseudo_img = torch.rand(2, 3, 32, 32) + pseudo_output = model(pseudo_img) + return pseudo_output.sum() + + +MODEL_CFG = dict( + _scope_='mmcls', + type='ImageClassifier', + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5), + )) + +MUTATOR_CONFIG_NUM = dict( + type='ChannelMutator', + channel_unit_cfg={ + 'type': 'SequentialMutableChannelUnit', + 'default_args': { + 'choice_mode': 'number' + } + }) +MUTATOR_CONFIG_FLOAT = dict( + type='ChannelMutator', + channel_unit_cfg={ + 'type': 'SequentialMutableChannelUnit', + 'default_args': { + 'choice_mode': 'ratio' + } + }) + +if torch.cuda.is_available(): + DEVICE = torch.device('cuda:0') +else: + DEVICE = torch.device('cpu') + + +class TestItePruneAlgorithm(unittest.TestCase): + + def _set_epoch_ite(self, epoch, ite, max_epoch): + iter_per_epoch = 10 + message_hub = MessageHub.get_current_instance() + message_hub.update_info('epoch', epoch) + message_hub.update_info('max_epochs', max_epoch) + message_hub.update_info('max_iters', max_epoch * 10) + message_hub.update_info('iter', ite + iter_per_epoch * epoch) + + def fake_cifar_data(self): + imgs = torch.randn(16, 3, 32, 32).to(DEVICE) + data_samples = [ + ClsDataSample().set_gt_label(torch.randint(0, 10, + (16, ))).to(DEVICE) + ] + + return {'inputs': imgs, 'data_samples': data_samples} + + def test_ite_prune_config_manager(self): + float_origin, float_target = 1.0, 0.5 + int_origin, int_target = 10, 5 + for origin, target, manager in [ + (float_origin, float_target, + ItePruneConfigManager({'a': float_target}, {'a': float_origin}, 2, + 5)), + (int_origin, int_target, + ItePruneConfigManager({'a': int_target}, {'a': int_origin}, 2, 5)) + ]: + times = 1 + for e in range(1, 20): + for ite in range(1, 5): + self._set_epoch_ite(e, ite, 5) + if (e, ite) in [(0, 0), (2, 0), (4, 0), (6, 0), (8, 0)]: + self.assertTrue(manager.is_prune_time(e, ite)) + self.assertEqual( + manager.prune_at(e)['a'], + origin - (origin - target) * times / 5) + times += 1 + else: + self.assertFalse(manager.is_prune_time(e, ite)) + + def test_iterative_prune_int(self): + + data = self.fake_cifar_data() + + model = MODELS.build(MODEL_CFG) + mutator = MODELS.build(MUTATOR_CONFIG_FLOAT) + 
mutator.prepare_from_supernet(model) + prune_target = mutator.sample_choices() + + epoch = 10 + epoch_step = 2 + times = 3 + + algorithm = ItePruneAlgorithm( + MODEL_CFG, + target_pruning_ratio=prune_target, + mutator_cfg=MUTATOR_CONFIG_FLOAT, + step_epoch=epoch_step, + prune_times=times).to(DEVICE) + + for e in range(epoch): + for ite in range(5): + self._set_epoch_ite(e, ite, 5) + + algorithm.forward( + data['inputs'], data['data_samples'], mode='loss') + + current_choices = algorithm.mutator.current_choices + for key in current_choices: + self.assertAlmostEqual( + current_choices[key], prune_target[key], delta=0.1) + + def test_load_pretrained(self): + epoch_step = 2 + times = 3 + data = self.fake_cifar_data() + + # prepare checkpoint + model_cfg = copy.deepcopy(MODEL_CFG) + model: BaseModel = MODELS.build(model_cfg) + checkpoint_path = os.path.dirname(__file__) + '/checkpoint' + torch.save(model.state_dict(), checkpoint_path) + + # build algorithm + model_cfg['init_cfg'] = { + 'type': 'Pretrained', + 'checkpoint': checkpoint_path + } + algorithm = ItePruneAlgorithm( + model_cfg, + mutator_cfg=MUTATOR_CONFIG_NUM, + target_pruning_ratio={}, + step_epoch=epoch_step, + prune_times=times, + ).to(DEVICE) + algorithm.init_weights() + algorithm.forward(data['inputs'], data['data_samples'], mode='loss') + + # delete checkpoint + os.remove(checkpoint_path) diff --git a/tests/test_models/test_algorithms/test_single_teacher_distill.py b/tests/test_models/test_algorithms/test_single_teacher_distill.py index 99392e6c0..249e4878c 100644 --- a/tests/test_models/test_algorithms/test_single_teacher_distill.py +++ b/tests/test_models/test_algorithms/test_single_teacher_distill.py @@ -4,9 +4,9 @@ import torch from mmengine import ConfigDict -from toy_models import ToyStudent from mmrazor.models import SingleTeacherDistill +from .toy_models import ToyStudent class TestSingleTeacherDistill(TestCase): diff --git a/tests/test_models/test_algorithms/test_slimmable_network.py b/tests/test_models/test_algorithms/test_slimmable_network.py index a9a7bc16e..13efbcf84 100644 --- a/tests/test_models/test_algorithms/test_slimmable_network.py +++ b/tests/test_models/test_algorithms/test_slimmable_network.py @@ -9,7 +9,6 @@ import torch import torch.distributed as dist from mmcls.structures import ClsDataSample -from mmengine import fileio from mmengine.optim import build_optim_wrapper from mmrazor.models.algorithms import SlimmableNetwork, SlimmableNetworkDDP @@ -25,11 +24,12 @@ in_channels=1920, loss=dict(type='CrossEntropyLoss', loss_weight=1.0), topk=(1, 5))) +CHANNEL_CFG_PATH = 'tests/data/MBV2_slimmable_config.json' MUTATOR_CFG = dict( type='SlimmableChannelMutator', - mutable_cfg=dict(type='SlimmableMutableChannel'), - tracer_cfg=dict( + channel_unit_cfg=dict(type='SlimmableChannelUnit', units=CHANNEL_CFG_PATH), + parse_cfg=dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss'))) @@ -41,8 +41,7 @@ OPTIMIZER_CFG = dict( type='SGD', lr=0.5, momentum=0.9, nesterov=True, weight_decay=0.0001) -OPTIM_WRAPPER_CFG = dict( - optimizer=OPTIMIZER_CFG, accumulative_counts=len(CHANNEL_CFG_PATHS)) +OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG, accumulative_counts=3) class FakeMutator: @@ -55,79 +54,40 @@ def forward( self, data: Dict, training: bool = True) -> Tuple[torch.Tensor, List[ClsDataSample]]: - return data['inputs'], data['data_samples'] + return data class TestSlimmable(TestCase): device: str = 'cpu' - def test_merge_channel_cfgs(self) -> None: - channel_cfg1 = { - 'layer1': { - 
'current_choice': 2, - 'origin_channel': 4 - }, - 'layer2': { - 'current_choice': 5, - 'origin_channel': 8 - } - } - channel_cfg2 = { - 'layer1': { - 'current_choice': 1, - 'origin_channel': 4 - }, - 'layer2': { - 'current_choice': 4, - 'origin_channel': 8 - } - } - - self.assertEqual( - SlimmableNetwork.merge_channel_cfgs([channel_cfg1, channel_cfg2]), - { - 'layer1': { - 'current_choice': [2, 1], - 'origin_channel': [4, 4] - }, - 'layer2': { - 'current_choice': [5, 4], - 'origin_channel': [8, 8] - } - }) - def test_init(self) -> None: - channel_cfgs = self._load_and_merge_channel_cfgs(CHANNEL_CFG_PATHS) - mutator_with_channel_cfgs = copy.deepcopy(MUTATOR_CFG) - mutator_with_channel_cfgs['channel_cfgs'] = channel_cfgs - - with pytest.raises(AssertionError): - _ = self.prepare_model(mutator_with_channel_cfgs, MODEL_CFG, - CHANNEL_CFG_PATHS) mutator_wrong_type = FakeMutator() - with pytest.raises(TypeError): - _ = self.prepare_model(mutator_wrong_type, MODEL_CFG, - CHANNEL_CFG_PATHS) + with pytest.raises(AttributeError): + _ = self.prepare_model(mutator_wrong_type, MODEL_CFG) + + # assert has prunable units + algo = SlimmableNetwork(MUTATOR_CFG, MODEL_CFG) + self.assertGreater(len(algo.mutator.mutable_units), 0) + + # assert can generate config template + mutator_cfg = copy.deepcopy(MUTATOR_CFG) + mutator_cfg['channel_unit_cfg']['units'] = {} + algo = SlimmableNetwork(mutator_cfg, MODEL_CFG) + try: + algo.mutator.config_template() + except Exception: + self.fail() def test_is_deployed(self) -> None: slimmable_should_not_deployed = \ - SlimmableNetwork(MUTATOR_CFG, MODEL_CFG, CHANNEL_CFG_PATHS) + SlimmableNetwork(MUTATOR_CFG, MODEL_CFG) assert not slimmable_should_not_deployed.is_deployed slimmable_should_deployed = \ - SlimmableNetwork(MUTATOR_CFG, MODEL_CFG, CHANNEL_CFG_PATHS[0]) + SlimmableNetwork(MUTATOR_CFG, MODEL_CFG, deploy_index=0) assert slimmable_should_deployed.is_deployed - def _load_and_merge_channel_cfgs(self, - channel_cfg_paths: List[str]) -> Dict: - channel_cfgs = list() - for channel_cfg_path in channel_cfg_paths: - channel_cfg = fileio.load(channel_cfg_path) - channel_cfgs.append(channel_cfg) - - return SlimmableNetwork.merge_channel_cfgs(channel_cfgs) - def test_slimmable_train_step(self) -> None: algo = self.prepare_slimmable_model() data = self._prepare_fake_data() @@ -171,16 +131,17 @@ def _prepare_fake_data(self) -> Dict: return {'inputs': imgs, 'data_samples': data_samples} def prepare_slimmable_model(self) -> SlimmableNetwork: - return self.prepare_model(MUTATOR_CFG, MODEL_CFG, CHANNEL_CFG_PATHS) + return self.prepare_model(MUTATOR_CFG, MODEL_CFG) def prepare_fixed_model(self) -> SlimmableNetwork: - channel_cfg_paths = CHANNEL_CFG_PATHS[0] - return self.prepare_model(MUTATOR_CFG, MODEL_CFG, channel_cfg_paths) + return self.prepare_model(MUTATOR_CFG, MODEL_CFG, deploy=0) - def prepare_model(self, mutator_cfg: Dict, model_cfg: Dict, - channel_cfg_paths: Dict) -> SlimmableNetwork: - model = SlimmableNetwork(mutator_cfg, model_cfg, channel_cfg_paths, + def prepare_model(self, + mutator_cfg: Dict, + model_cfg: Dict, + deploy=-1) -> SlimmableNetwork: + model = SlimmableNetwork(mutator_cfg, model_cfg, deploy, ToyDataPreprocessor()) model.to(self.device) @@ -202,10 +163,11 @@ def setUpClass(cls) -> None: backend = 'gloo' dist.init_process_group(backend, rank=0, world_size=1) - def prepare_model(self, mutator_cfg: Dict, model_cfg: Dict, - channel_cfg_paths: Dict) -> SlimmableNetworkDDP: - model = super().prepare_model(mutator_cfg, model_cfg, - channel_cfg_paths) + def 
prepare_model(self, + mutator_cfg: Dict, + model_cfg: Dict, + deploy=-1) -> SlimmableNetwork: + model = super().prepare_model(mutator_cfg, model_cfg, deploy) return SlimmableNetworkDDP(module=model, find_unused_parameters=True) def test_is_deployed(self) -> None: diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py index 6082817ea..8eab78af8 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py @@ -8,10 +8,10 @@ import torch from torch import nn -from mmrazor.models.architectures.dynamic_ops.bricks import (BigNasConv2d, - DynamicConv2d, - OFAConv2d) -from mmrazor.models.mutables import OneShotMutableChannel, OneShotMutableValue +from mmrazor.models.architectures.dynamic_ops import (BigNasConv2d, + DynamicConv2d, OFAConv2d) +from mmrazor.models.mutables import (OneShotMutableValue, + SquentialMutableChannel) from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet from ..utils import fix_dynamic_op @@ -39,10 +39,8 @@ def test_dynamic_conv2d_depthwise(self) -> None: with pytest.raises(ValueError): d_conv2d.register_mutable_attr('out_channels', mock_mutable) - mutable_in_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') - mutable_out_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_in_channels = SquentialMutableChannel(10) + mutable_out_channels = SquentialMutableChannel(10) d_conv2d.register_mutable_attr('in_channels', mutable_in_channels) d_conv2d.register_mutable_attr('out_channels', mutable_out_channels) @@ -82,10 +80,8 @@ def test_dynamic_conv2d(bias: bool) -> None: x_max = torch.rand(10, 4, 224, 224) out_before_mutate = d_conv2d(x_max) - mutable_in_channels = OneShotMutableChannel( - 4, candidate_choices=[2, 3, 4], candidate_mode='number') - mutable_out_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_in_channels = SquentialMutableChannel(4) + mutable_out_channels = SquentialMutableChannel(10) d_conv2d.register_mutable_attr('in_channels', mutable_in_channels) d_conv2d.register_mutable_attr('out_channels', mutable_out_channels) @@ -128,8 +124,7 @@ def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, out_channels: int) -> None: d_conv2d = DynamicConv2d( in_channels=10, out_channels=10, kernel_size=3, stride=1, bias=True) - mutable_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 6, 10], candidate_mode='number') + mutable_channels = SquentialMutableChannel(10) if is_mutate_in_channels: d_conv2d.register_mutable_attr('in_channels', mutable_channels) @@ -172,10 +167,8 @@ def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, def test_kernel_dynamic_conv2d(dynamic_class: Type[nn.Conv2d], kernel_size_list: bool) -> None: - mutable_in_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') - mutable_out_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_in_channels = SquentialMutableChannel(10) + mutable_out_channels = SquentialMutableChannel(10) mutable_kernel_size = OneShotMutableValue(value_list=kernel_size_list) @@ -233,7 +226,7 @@ def test_kernel_dynamic_conv2d(dynamic_class: Type[nn.Conv2d], 
@pytest.mark.parametrize('dynamic_class', [OFAConv2d, BigNasConv2d]) def test_mutable_kernel_dynamic_conv2d_grad( dynamic_class: Type[nn.Conv2d]) -> None: - from mmrazor.models.architectures.dynamic_ops.bricks import \ + from mmrazor.models.architectures.dynamic_ops.mixins import \ dynamic_conv_mixins kernel_size_list = [3, 5, 7] diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py index 0cdaa20b6..ece69ddc0 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py @@ -6,19 +6,18 @@ import torch from torch import nn -from mmrazor.models.architectures.dynamic_ops.bricks import ( # noqa - DynamicLinear, DynamicLinearMixin) -from mmrazor.models.mutables import OneShotMutableChannel +from mmrazor.models.mutables import SquentialMutableChannel from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet from ..utils import fix_dynamic_op +from mmrazor.models.architectures.dynamic_ops import ( # isort:skip + DynamicLinear, DynamicLinearMixin) + @pytest.mark.parametrize('bias', [True, False]) def test_dynamic_linear(bias) -> None: - mutable_in_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') - mutable_out_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_in_features = SquentialMutableChannel(10) + mutable_out_features = SquentialMutableChannel(10) d_linear = DynamicLinear(in_features=10, out_features=10, bias=bias) @@ -77,8 +76,7 @@ def test_dynamic_linear_mutable_single_features( is_mutate_in_features: Optional[bool], in_features: int, out_features: int) -> None: d_linear = DynamicLinear(in_features=10, out_features=10, bias=True) - mutable_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 6, 10], candidate_mode='number') + mutable_channels = SquentialMutableChannel(10) if is_mutate_in_features is not None: if is_mutate_in_features: diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py index e6cfe103f..ce6ae7b36 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py @@ -6,9 +6,11 @@ import torch from torch import nn -from mmrazor.models.architectures.dynamic_ops.bricks import ( - DynamicBatchNorm1d, DynamicBatchNorm2d, DynamicBatchNorm3d, DynamicMixin) -from mmrazor.models.mutables import OneShotMutableChannel +from mmrazor.models.architectures.dynamic_ops import (DynamicBatchNorm1d, + DynamicBatchNorm2d, + DynamicBatchNorm3d, + DynamicMixin) +from mmrazor.models.mutables import SquentialMutableChannel from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet from ..utils import fix_dynamic_op @@ -22,8 +24,7 @@ def test_dynamic_bn(dynamic_class: Type[nn.modules.batchnorm._BatchNorm], input_shape: Tuple[int], affine: bool, track_running_stats: bool) -> None: - mutable_num_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_num_features = SquentialMutableChannel(10) d_bn = dynamic_class( num_features=10, @@ -87,8 +88,7 @@ def test_bn_track_running_stats( dynamic_class: 
Type[nn.modules.batchnorm._BatchNorm], input_shape: Tuple[int], ) -> None: - mutable_num_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_num_features = SquentialMutableChannel(10) mutable_num_features.current_choice = 8 d_bn = dynamic_class( num_features=10, track_running_stats=True, affine=False) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py b/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py deleted file mode 100644 index 97277d03a..000000000 --- a/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""from unittest import TestCase. - -import pytest -import torch - -from mmrazor.models.architectures import DynamicConv2d -from mmrazor.structures import export_fix_subnet, load_fix_subnet -from .utils import fix_dynamic_op - -class TestDefaultDynamicOP(TestCase): - - def test_dynamic_conv2d(self) -> None: - in_channels_cfg = dict(type='SlimmableMutableChannel', num_channels=4) - out_channels_cfg = dict( - type='SlimmableMutableChannel', num_channels=10) - - d_conv2d = DynamicConv2d( - in_channels_cfg, - out_channels_cfg, - in_channels=4, - out_channels=10, - kernel_size=3, - stride=1, - bias=True) - - d_conv2d.mutable_in.candidate_choices = [2, 3, 4] - d_conv2d.mutable_out.candidate_choices = [4, 8, 10] - - with pytest.raises(AssertionError): - d_conv2d.to_static_op() - - d_conv2d.mutable_in.current_choice = 1 - d_conv2d.mutable_out.current_choice = 0 - - x = torch.rand(10, 3, 224, 224) - out1 = d_conv2d(x) - self.assertEqual(out1.size(1), 4) - - fix_mutables = export_fix_subnet(d_conv2d) - with pytest.raises(RuntimeError): - load_fix_subnet(d_conv2d, fix_mutables) - fix_dynamic_op(d_conv2d, fix_mutables) - - out2 = d_conv2d(x) - self.assertTrue(torch.equal(out1, out2)) - - s_conv2d = d_conv2d.to_static_op() - out3 = s_conv2d(x) - - self.assertTrue(torch.equal(out1, out3)) - - def test_dynamic_conv2d_depthwise(self) -> None: - in_channels_cfg = dict(type='SlimmableMutableChannel', num_channels=10) - out_channels_cfg = dict( - type='SlimmableMutableChannel', num_channels=10) - - d_conv2d = DynamicConv2d( - in_channels_cfg, - out_channels_cfg, - in_channels=10, - out_channels=10, - groups=10, - kernel_size=3, - stride=1, - bias=True) - - d_conv2d.mutable_in.candidate_choices = [4, 8, 10] - d_conv2d.mutable_out.candidate_choices = [4, 8, 10] - - with pytest.raises(AssertionError): - d_conv2d.to_static_op() - - d_conv2d.mutable_in.current_choice = 1 - d_conv2d.mutable_out.current_choice = 1 - - x = torch.rand(10, 8, 224, 224) - out1 = d_conv2d(x) - self.assertEqual(out1.size(1), 8) - - fix_mutables = export_fix_subnet(d_conv2d) - with pytest.raises(RuntimeError): - load_fix_subnet(d_conv2d, fix_mutables) - fix_dynamic_op(d_conv2d, fix_mutables) - - out2 = d_conv2d(x) - self.assertTrue(torch.equal(out1, out2)) - - s_conv2d = d_conv2d.to_static_op() - out3 = s_conv2d(x) - - self.assertTrue(torch.equal(out1, out3)) -""" diff --git a/tests/test_models/test_architectures/test_dynamic_op/utils.py b/tests/test_models/test_architectures/test_dynamic_op/utils.py index 506fe1b5a..ceb2a5d4f 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/utils.py +++ b/tests/test_models/test_architectures/test_dynamic_op/utils.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from typing import Dict, Optional -from mmrazor.models.architectures.dynamic_ops import DynamicOP +from mmrazor.models.architectures.dynamic_ops import DynamicMixin -def fix_dynamic_op(op: DynamicOP, fix_mutables: Optional[Dict] = None) -> None: +def fix_dynamic_op(op: DynamicMixin, + fix_mutables: Optional[Dict] = None) -> None: for name, mutable in op.mutable_attrs.items(): if fix_mutables is not None: diff --git a/tests/test_models/test_mutables/__init__.py b/tests/test_models/test_mutables/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_mutables/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_mutables/test_channel_mutable.py b/tests/test_models/test_mutables/test_channel_mutable.py deleted file mode 100644 index fd808351c..000000000 --- a/tests/test_models/test_mutables/test_channel_mutable.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from unittest import TestCase - -import pytest -import torch - -from mmrazor.models import OneShotMutableChannel - - -class TestChannelMutables(TestCase): - - def test_mutable_channel_ratio(self): - with pytest.raises(AssertionError): - # Test invalid `candidate_mode` - OneShotMutableChannel( - num_channels=8, - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='xxx') - - with pytest.raises(AssertionError): - # Number of candidate choices must be greater than 0 - OneShotMutableChannel( - num_channels=8, - candidate_choices=list(), - candidate_mode='ratio') - - with pytest.raises(AssertionError): - # The candidate ratio should be in range(0, 1]. - OneShotMutableChannel( - num_channels=8, - candidate_choices=[0., 1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - with pytest.raises(AssertionError): - # Minimum number of channels should be a positive integer. 
- out_mutable = OneShotMutableChannel( - num_channels=8, - candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - out_mutable.bind_mutable_name('op') - _ = out_mutable.min_choice - - # Test mutable out - out_mutable = OneShotMutableChannel( - num_channels=8, - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - random_choice = out_mutable.sample_choice() - assert random_choice in [2, 4, 6, 8] - - max_choice = out_mutable.max_choice - assert max_choice == 8 - out_mutable.current_choice = max_choice - assert torch.equal(out_mutable.current_mask, - torch.ones_like(out_mutable.current_mask).bool()) - - min_choice = out_mutable.min_choice - assert min_choice == 2 - out_mutable.current_choice = min_choice - min_mask = torch.zeros_like(out_mutable.current_mask).bool() - min_mask[:2] = True - assert torch.equal(out_mutable.current_mask, min_mask) - - # Test mutable in with concat_mutable - in_mutable = OneShotMutableChannel( - num_channels=16, - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - out_mutable1 = copy.deepcopy(out_mutable) - out_mutable2 = copy.deepcopy(out_mutable) - in_mutable.register_same_mutable([out_mutable1, out_mutable2]) - choice1 = out_mutable1.sample_choice() - out_mutable1.current_choice = choice1 - choice2 = out_mutable2.sample_choice() - out_mutable2.current_choice = choice2 - assert torch.equal( - in_mutable.current_mask, - torch.cat([out_mutable1.current_mask, out_mutable2.current_mask])) - - with pytest.raises(AssertionError): - # The mask of this in_mutable depends on the out mask of its - # `concat_mutables`, so the `sample_choice` method should not - # be called - in_mutable.sample_choice() - - with pytest.raises(AssertionError): - # The mask of this in_mutable depends on the out mask of its - # `concat_mutables`, so the `min_choice` property should not - # be called - _ = in_mutable.min_choice - - def test_mutable_channel_number(self): - with pytest.raises(AssertionError): - # The candidate ratio should be in range(0, `num_channels`]. - OneShotMutableChannel( - num_channels=8, - candidate_choices=[0, 2, 4, 6, 8], - candidate_mode='number') - - with pytest.raises(AssertionError): - # Type of `candidate_choices` should be int. 
- OneShotMutableChannel( - num_channels=8, - candidate_choices=[0., 2, 4, 6, 8], - candidate_mode='number') - - # Test mutable out - out_mutable = OneShotMutableChannel( - num_channels=8, - candidate_choices=[2, 4, 6, 8], - candidate_mode='number') - - random_choice = out_mutable.sample_choice() - assert random_choice in [2, 4, 6, 8] - - max_choice = out_mutable.max_choice - assert max_choice == 8 - out_mutable.current_choice = max_choice - assert torch.equal(out_mutable.current_mask, - torch.ones_like(out_mutable.current_mask).bool()) - - min_choice = out_mutable.min_choice - assert min_choice == 2 - out_mutable.current_choice = min_choice - min_mask = torch.zeros_like(out_mutable.current_mask).bool() - min_mask[:2] = True - assert torch.equal(out_mutable.current_mask, min_mask) diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py index 99da8dc71..3e87b0654 100644 --- a/tests/test_models/test_mutables/test_derived_mutable.py +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -4,18 +4,15 @@ import pytest import torch -from mmrazor.models.mutables import (DerivedMutable, OneShotMutableChannel, - OneShotMutableValue) +from mmrazor.models.mutables import (DerivedMutable, OneShotMutableValue, + SquentialMutableChannel) from mmrazor.models.mutables.base_mutable import BaseMutable class TestDerivedMutable(TestCase): def test_is_fixed(self) -> None: - mc = OneShotMutableChannel( - num_channels=10, - candidate_choices=[2, 8, 10], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=10) mc.current_choice = 2 mv = OneShotMutableValue(value_list=[2, 3, 4]) @@ -46,10 +43,7 @@ def test_fix_dump_chosen(self) -> None: derived_mutable.fix_chosen(derived_mutable.dump_chosen()) def test_derived_same_mutable(self) -> None: - mc = OneShotMutableChannel( - num_channels=3, - candidate_choices=[1, 2, 3], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=3) mc_derived = mc.derive_same_mutable() assert mc_derived.source_mutables == {mc} @@ -59,10 +53,8 @@ def test_derived_same_mutable(self) -> None: torch.tensor([1, 1, 0], dtype=torch.bool)) def test_mutable_concat_derived(self) -> None: - mc1 = OneShotMutableChannel( - num_channels=3, candidate_choices=[1, 3], candidate_mode='number') - mc2 = OneShotMutableChannel( - num_channels=4, candidate_choices=[1, 4], candidate_mode='number') + mc1 = SquentialMutableChannel(num_channels=3) + mc2 = SquentialMutableChannel(num_channels=4) ms = [mc1, mc2] mc_derived = DerivedMutable.derive_concat_mutable(ms) @@ -88,10 +80,7 @@ def test_mutable_concat_derived(self) -> None: _ = DerivedMutable.derive_concat_mutable(ms) def test_mutable_channel_derived(self) -> None: - mc = OneShotMutableChannel( - num_channels=3, - candidate_choices=[1, 2, 3], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=3) mc_derived = mc * 3 assert mc_derived.source_mutables == {mc} @@ -112,10 +101,7 @@ def test_mutable_channel_derived(self) -> None: mc_derived.current_mask.size()) def test_mutable_divide(self) -> None: - mc = OneShotMutableChannel( - num_channels=128, - candidate_choices=[112, 120, 128], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=128) mc_derived = mc // 8 assert mc_derived.source_mutables == {mc} @@ -138,14 +124,15 @@ def test_mutable_divide(self) -> None: assert mv_derived.current_choice == 16 def test_source_mutables(self) -> None: - useless_fn = lambda x: x # noqa: E731 + + def useless_fn(x): + return x # noqa: 
E731 + with pytest.raises(RuntimeError): _ = DerivedMutable(choice_fn=useless_fn) - mc1 = OneShotMutableChannel( - num_channels=3, candidate_choices=[1, 3], candidate_mode='number') - mc2 = OneShotMutableChannel( - num_channels=4, candidate_choices=[1, 4], candidate_mode='number') + mc1 = SquentialMutableChannel(num_channels=3) + mc2 = SquentialMutableChannel(num_channels=4) ms = [mc1, mc2] mc_derived1 = DerivedMutable.derive_concat_mutable(ms) @@ -180,8 +167,7 @@ def fn(): mask_fn=dict_closure_fn({2: [mc1, mc2]}, {3: dd_mutable})) assert ddd_mutable.source_mutables == mc_derived1.source_mutables - mc3 = OneShotMutableChannel( - num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + mc3 = SquentialMutableChannel(num_channels=4) dddd_mutable = DerivedMutable( choice_fn=dict_closure_fn({ mc1: [2, 3], @@ -191,10 +177,8 @@ def fn(): assert dddd_mutable.source_mutables == {mc1, mc2, mc3} def test_nested_mutables(self) -> None: - source_a = OneShotMutableChannel( - num_channels=2, candidate_choices=[1, 2], candidate_mode='number') - source_b = OneShotMutableChannel( - num_channels=3, candidate_choices=[2, 3], candidate_mode='number') + source_a = SquentialMutableChannel(num_channels=2) + source_b = SquentialMutableChannel(num_channels=3) # derive from derived_c = source_a * 1 diff --git a/tests/test_models/test_mutables/test_dynamic_layer.py b/tests/test_models/test_mutables/test_dynamic_layer.py deleted file mode 100644 index 6864e5edd..000000000 --- a/tests/test_models/test_mutables/test_dynamic_layer.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -import torch -from torch import nn - -from mmrazor.models.mutators.utils import (dynamic_bn_converter, - dynamic_conv2d_converter, - dynamic_gn_converter, - dynamic_in_converter, - dynamic_linear_converter) - - -class TestDynamicLayer(TestCase): - - def test_dynamic_conv(self): - imgs = torch.rand(2, 8, 16, 16) - - in_channels_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - out_channels_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - conv = nn.Conv2d(8, 8, 1) - dynamic_conv = dynamic_conv2d_converter(conv, in_channels_cfg, - out_channels_cfg) - # test forward - dynamic_conv(imgs) - - conv = nn.Conv2d(8, 8, 1, groups=8) - dynamic_conv = dynamic_conv2d_converter(conv, in_channels_cfg, - out_channels_cfg) - # test forward - dynamic_conv(imgs) - - conv = nn.Conv2d(8, 8, 1, groups=4) - dynamic_conv = dynamic_conv2d_converter(conv, in_channels_cfg, - out_channels_cfg) - # test forward - with self.assertRaisesRegex(NotImplementedError, - 'only support pruning the depth-wise'): - dynamic_conv(imgs) - - def test_dynamic_linear(self): - imgs = torch.rand(2, 8) - - in_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - out_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - linear = nn.Linear(8, 8) - dynamic_linear = dynamic_linear_converter(linear, in_features_cfg, - out_features_cfg) - # test forward - dynamic_linear(imgs) - - def test_dynamic_batchnorm(self): - imgs = torch.rand(2, 8, 16, 16) - - num_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - bn = nn.BatchNorm2d(8) - dynamic_bn = 
dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8, momentum=0) - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8) - bn.train() - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - # test num_batches_tracked is not None - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8, affine=False) - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8, track_running_stats=False) - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - def test_dynamic_instancenorm(self): - imgs = torch.rand(2, 8, 16, 16) - - num_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - instance_norm = nn.InstanceNorm2d(8) - dynamic_in = dynamic_in_converter(instance_norm, num_features_cfg) - # test forward - dynamic_in(imgs) - - instance_norm = nn.InstanceNorm2d(8, affine=False) - dynamic_in = dynamic_in_converter(instance_norm, num_features_cfg) - # test forward - dynamic_in(imgs) - - instance_norm = nn.InstanceNorm2d(8, track_running_stats=False) - dynamic_in = dynamic_in_converter(instance_norm, num_features_cfg) - # test forward - dynamic_in(imgs) - - def test_dynamic_groupnorm(self): - imgs = torch.rand(2, 8, 16, 16) - - num_channels_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - gn = nn.GroupNorm(num_groups=4, num_channels=8) - dynamic_gn = dynamic_gn_converter(gn, num_channels_cfg) - # test forward - dynamic_gn(imgs) - - gn = nn.GroupNorm(num_groups=4, num_channels=8, affine=False) - dynamic_gn = dynamic_gn_converter(gn, num_channels_cfg) - # test forward - dynamic_gn(imgs) diff --git a/tests/test_models/test_mutables/test_mutable_channel/__init__.py b/tests/test_models/test_mutables/test_mutable_channel/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py new file mode 100644 index 000000000..c93a43842 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import unittest + +import pytest +import torch + +from mmrazor.models.mutables import (SimpleMutableChannel, + SquentialMutableChannel) + + +class TestMutableChannels(unittest.TestCase): + + def test_SquentialMutableChannel(self): + mutable_channel = SquentialMutableChannel(4) + mutable_channel.current_choice = 3 + self.assertEqual(mutable_channel.activated_channels, + mutable_channel.current_choice) + self.assertTrue( + (mutable_channel.current_mask == torch.tensor([1, 1, 1, + 0]).bool()).all()) + channel_str = mutable_channel.__repr__() + self.assertEqual( + channel_str, + 'SquentialMutableChannel(num_channels=4, activated_channels=3)') + + mutable_channel.fix_chosen() + mutable_channel.dump_chosen() + + def test_SimpleMutableChannel(self): + channel = SimpleMutableChannel(4) + channel.current_choice = torch.tensor([1, 0, 0, 0]).bool() + self.assertEqual(channel.activated_channels, 1) + channel.fix_chosen() + with pytest.raises(NotImplementedError): + channel.dump_chosen() diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py b/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py new file mode 100644 index 000000000..253084d07 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmrazor.models.mutables import SquentialMutableChannel + + +class TestSquentialMutableChannel(TestCase): + + def _test_mutable(self, + mutable: SquentialMutableChannel, + set_choice, + get_choice, + activate_channels, + mask=None): + mutable.current_choice = set_choice + assert mutable.current_choice == get_choice + assert mutable.activated_channels == activate_channels + if mask is not None: + assert (mutable.current_mask == mask).all() + + def _generate_mask(self, num: int, all: int): + mask = torch.zeros([all]) + mask[0:num] = 1 + return mask.bool() + + def test_mul_float(self): + channel = SquentialMutableChannel(10) + new_channel = channel * 0.5 + self.assertEqual(new_channel.current_choice, 5) + channel.current_choice = 5 + self.assertEqual(new_channel.current_choice, 2) + + def test_int_choice(self): + channel = SquentialMutableChannel(10) + self._test_mutable(channel, 5, 5, 5, self._generate_mask(5, 10)) + self._test_mutable(channel, 0.2, 2, 2, self._generate_mask(2, 10)) + + def test_float_choice(self): + channel = SquentialMutableChannel(10, choice_mode='ratio') + self._test_mutable(channel, 0.5, 0.5, 5, self._generate_mask(5, 10)) + self._test_mutable(channel, 2, 0.2, 2, self._generate_mask(2, 10)) diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py new file mode 100644 index 000000000..f1a0d8529 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import torch.nn as nn + +from mmrazor.models.mutables import L1MutableChannelUnit +from mmrazor.models.mutators import ChannelMutator +from .....data.models import LineModel + + +class TestL1MutableChannelUnit(TestCase): + + def test_init(self): + model = LineModel() + mutator = ChannelMutator( + channel_unit_cfg={ + 'type': 'L1MutableChannelUnit', + 'default_args': { + 'choice_mode': 'ratio' + } + }) + mutator.prepare_from_supernet(model) + mutator.set_choices(mutator.sample_choices()) + print(mutator.units) + print(mutator.mutable_units) + print(mutator.choice_template) + + def test_convnd(self): + unit = L1MutableChannelUnit(8) + conv = nn.Conv3d(3, 8, 3) + norm = unit._get_l1_norm(conv, 0, 8) + self.assertSequenceEqual(norm.shape, [8]) diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py new file mode 100644 index 000000000..ad5b5e56b --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List +from unittest import TestCase + +import torch +import torch.nn as nn + +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin +from mmrazor.models.mutables.mutable_channel import ( + L1MutableChannelUnit, MutableChannelUnit, SequentialMutableChannelUnit) +from mmrazor.models.mutables.mutable_channel.units.channel_unit import ( # noqa + Channel, ChannelUnit) +from mmrazor.structures.graph import ModuleGraph as ModuleGraph +from .....data.models import LineModel +from .....test_core.test_graph.test_graph import TestGraph + +MUTABLE_CFG = dict(type='SimpleMutablechannel') +PARSE_CFG = dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')) + +# DEVICE = torch.device('cuda:0') if torch.cuda.is_available() \ +# else torch.device('cpu') +DEVICE = torch.device('cpu') +GROUPS: List[MutableChannelUnit] = [ + L1MutableChannelUnit, SequentialMutableChannelUnit +] + +DefaultChannelUnit = SequentialMutableChannelUnit + + +class TestMutableChannelUnit(TestCase): + + def test_init_from_graph(self): + model = LineModel() + # init using tracer + graph = ModuleGraph.init_from_backward_tracer(model) + units = DefaultChannelUnit.init_from_graph(graph) + self._test_units(units, model) + + def test_init_from_cfg(self): + model = LineModel() + # init using tracer + + config = { + 'init_args': { + 'num_channels': 8 + }, + 'channels': { + 'input_related': [{ + 'name': 'net.1', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_channel': False + }, { + 'name': 'net.3', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_channel': False + }], + 'output_related': [{ + 'name': 'net.0', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_channel': True + }, { + 'name': 'net.1', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_channel': True + }] + } + } + units = [DefaultChannelUnit.init_from_cfg(model, config)] + self._test_units(units, model) + + def test_init_from_channel_unit(self): + model = LineModel() + # init using tracer + graph = ModuleGraph.init_from_backward_tracer(model) + units: List[ChannelUnit] = ChannelUnit.init_from_graph(graph) + mutable_units = [ + DefaultChannelUnit.init_from_channel_unit(unit) for unit in units + ] + self._test_units(mutable_units, model) + + def _test_units(self, 
units: List[MutableChannelUnit], model): + for unit in units: + unit.prepare_for_pruning(model) + mutable_units = [unit for unit in units if unit.is_mutable] + self.assertGreaterEqual(len(mutable_units), 1) + for unit in mutable_units: + choice = unit.sample_choice() + unit.current_choice = choice + self.assertAlmostEqual(unit.current_choice, choice, delta=0.1) + x = torch.rand([2, 3, 224, 224]).to(DEVICE) + y = model(x) + self.assertSequenceEqual(y.shape, [2, 1000]) + + def _test_a_model_from_backward_tracer(self, model): + model.eval() + model = model.to(DEVICE) + graph = ModuleGraph.init_from_backward_tracer(model) + self._test_a_graph(model, graph) + + def test_with_backward_tracer(self): + test_models = TestGraph.backward_tracer_passed_models() + for model_data in test_models: + with self.subTest(model=model_data): + model = model_data() + self._test_a_model_from_backward_tracer(model) + + def test_replace_with_dynamic_ops(self): + model_datas = TestGraph.backward_tracer_passed_models() + for model_data in model_datas: + for unit_type in GROUPS: + with self.subTest(model=model_data, unit=unit_type): + model: nn.Module = model_data() + graph = ModuleGraph.init_from_backward_tracer(model) + units: List[ + MutableChannelUnit] = unit_type.init_from_graph(graph) + for unit in units: + unit.prepare_for_pruning(model) + + for module in model.modules(): + if isinstance(module, nn.Conv2d)\ + and module.groups == module.in_channels\ + and module.groups == 1: + self.assertTrue( + isinstance(module, DynamicChannelMixin)) + if isinstance(module, nn.Linear): + self.assertTrue( + isinstance(module, DynamicChannelMixin)) + if isinstance(module, nn.BatchNorm2d): + self.assertTrue( + isinstance(module, DynamicChannelMixin)) + + def _test_a_graph(self, model, graph): + try: + units = DefaultChannelUnit.init_from_graph(graph) + self._test_units(units, model) + except Exception as e: + self.fail(f'{e}') diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py new file mode 100644 index 000000000..690382596 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from mmrazor.models.mutables import OneShotMutableChannelUnit + + +class TestSequentialMutableChannelUnit(TestCase): + + def test_init(self): + unit = OneShotMutableChannelUnit( + 48, [20, 30, 40], choice_mode='number', divisor=8) + self.assertSequenceEqual(unit.candidate_choices, [24, 32, 40]) + + unit = OneShotMutableChannelUnit( + 48, [0.3, 0.5, 0.7], choice_mode='ratio', divisor=8) + self.assertSequenceEqual(unit.candidate_choices, [1 / 3, 0.5, 2 / 3]) diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_sequential_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_sequential_mutable_channel_unit.py new file mode 100644 index 000000000..8981a8a21 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_sequential_mutable_channel_unit.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +from mmrazor.models.mutables import SequentialMutableChannelUnit + + +class TestSequentialMutableChannelUnit(TestCase): + + def test_num(self): + unit = SequentialMutableChannelUnit(48) + unit.current_choice = 24 + self.assertEqual(unit.current_choice, 24) + + unit.current_choice = 0.5 + self.assertEqual(unit.current_choice, 24) + + def test_ratio(self): + unit = SequentialMutableChannelUnit(48, choice_mode='ratio') + unit.current_choice = 0.5 + self.assertEqual(unit.current_choice, 0.5) + unit.current_choice = 24 + self.assertEqual(unit.current_choice, 0.5) + + def test_divisor(self): + unit = SequentialMutableChannelUnit( + 48, choice_mode='number', divisor=8) + unit.current_choice = 20 + self.assertEqual(unit.current_choice, 24) + self.assertTrue(unit.sample_choice() % 8 == 0) + + unit = SequentialMutableChannelUnit(48, choice_mode='ratio', divisor=8) + unit.current_choice = 0.3 + self.assertEqual(unit.current_choice, 1 / 3) + + def test_config_template(self): + unit = SequentialMutableChannelUnit(48, choice_mode='ratio', divisor=8) + config = unit.config_template(with_init_args=True) + unit2 = SequentialMutableChannelUnit.init_from_cfg(None, config) + self.assertDictEqual( + unit2.config_template(with_init_args=True)['init_args'], + config['init_args']) diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py index 0b5ed7947..d7d05b1d5 100644 --- a/tests/test_models/test_mutables/test_mutable_value.py +++ b/tests/test_models/test_mutables/test_mutable_value.py @@ -5,8 +5,8 @@ import pytest import torch -from mmrazor.models.mutables import (MutableValue, OneShotMutableChannel, - OneShotMutableValue) +from mmrazor.models.mutables import (MutableValue, OneShotMutableValue, + SquentialMutableChannel) class TestMutableValue(TestCase): @@ -87,8 +87,7 @@ def test_mul(self) -> None: _ = mv * 1.2 mv = MutableValue(value_list=[1, 2, 3], default_value=3) - mc = OneShotMutableChannel( - num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + mc = SquentialMutableChannel(num_channels=4) with pytest.raises(TypeError): _ = mc * mv diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py index 6fa3b6368..96908d807 100644 --- a/tests/test_models/test_mutators/test_channel_mutator.py +++ b/tests/test_models/test_mutators/test_channel_mutator.py @@ -1,282 +1,136 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import os +import copy +import sys import unittest -from os.path import dirname +from typing import Union -import pytest import torch -from mmcls.models import * # noqa: F401,F403 -from torch import Tensor, nn -from torch.nn import Module -from mmrazor import digit_version -from mmrazor.models.mutables import SlimmableMutableChannel -from mmrazor.models.mutators import (OneShotChannelMutator, - SlimmableChannelMutator) -from mmrazor.models.mutators.utils import (dynamic_bn_converter, - dynamic_conv2d_converter) +from mmrazor.models.mutables.mutable_channel import ( + L1MutableChannelUnit, SequentialMutableChannelUnit) +from mmrazor.models.mutators.channel_mutator import ChannelMutator from mmrazor.registry import MODELS -from .utils import load_and_merge_channel_cfgs +from ...data.models import DynamicLinearModel +from ...test_core.test_graph.test_graph import TestGraph -ONESHOT_MUTATOR_CFG = dict( - type='OneShotChannelMutator', - tracer_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss')), - mutable_cfg=dict( - type='OneShotMutableChannel', - candidate_choices=[ - 1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0 - ], - candidate_mode='ratio')) - -ONESHOT_MUTATOR_CFG_WITHOUT_TRACER = dict( - type='OneShotChannelMutator', - mutable_cfg=dict( - type='OneShotMutableChannel', - candidate_choices=[ - 1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0 - ], - candidate_mode='ratio')) - - -class MultiConcatModel(Module): - - def __init__(self) -> None: - super().__init__() - - self.op1 = nn.Conv2d(3, 8, 1) - self.op2 = nn.Conv2d(3, 8, 1) - self.op3 = nn.Conv2d(16, 8, 1) - self.op4 = nn.Conv2d(3, 8, 1) - - def forward(self, x: Tensor) -> Tensor: - x1 = self.op1(x) - x2 = self.op2(x) - cat1 = torch.cat([x1, x2], dim=1) - x3 = self.op3(cat1) - x4 = self.op4(x) - output = torch.cat([x3, x4], dim=1) - - return output - - -class MultiConcatModel2(Module): - - def __init__(self) -> None: - super().__init__() - - self.op1 = nn.Conv2d(3, 8, 1) - self.op2 = nn.Conv2d(3, 8, 1) - self.op3 = nn.Conv2d(3, 8, 1) - self.op4 = nn.Conv2d(24, 8, 1) - - def forward(self, x: Tensor) -> Tensor: - x1 = self.op1(x) - x2 = self.op2(x) - x3 = self.op3(x) - cat1 = torch.cat([x1, x2], dim=1) - cat2 = torch.cat([cat1, x3], dim=1) - output = self.op4(cat2) - - return output - - -class ConcatModel(Module): - - def __init__(self) -> None: - super().__init__() - - self.op1 = nn.Conv2d(3, 8, 1) - self.bn1 = nn.BatchNorm2d(8) - self.op2 = nn.Conv2d(3, 8, 1) - self.bn2 = nn.BatchNorm2d(8) - self.op3 = nn.Conv2d(16, 8, 1) - - def forward(self, x: Tensor) -> Tensor: - x1 = self.bn1(self.op1(x)) - x2 = self.bn2(self.op2(x)) - cat1 = torch.cat([x1, x2], dim=1) - x3 = self.op3(cat1) - - return x3 - - -class ResBlock(Module): - - def __init__(self) -> None: - super().__init__() - - self.op1 = nn.Conv2d(3, 8, 1) - self.bn1 = nn.BatchNorm2d(8) - self.op2 = nn.Conv2d(8, 8, 1) - self.bn2 = nn.BatchNorm2d(8) - self.op3 = nn.Conv2d(8, 8, 1) - - def forward(self, x: Tensor) -> Tensor: - x1 = self.bn1(self.op1(x)) - x2 = self.bn2(self.op2(x1)) - x3 = self.op3(x2 + x1) - return x3 - - -class DynamicResBlock(Module): - - def __init__(self, mutable_cfg) -> None: - super().__init__() - - self.dynamic_op1 = dynamic_conv2d_converter( - nn.Conv2d(3, 8, 1), mutable_cfg, mutable_cfg) - self.dynamic_bn1 = dynamic_bn_converter( - nn.BatchNorm2d(8), mutable_cfg, mutable_cfg) - self.dynamic_op2 = dynamic_conv2d_converter( - nn.Conv2d(8, 8, 1), mutable_cfg, mutable_cfg) - self.dynamic_bn2 = dynamic_bn_converter( 
- nn.BatchNorm2d(8), mutable_cfg, mutable_cfg) - self.dynamic_op3 = dynamic_conv2d_converter( - nn.Conv2d(8, 8, 1), mutable_cfg, mutable_cfg) - self._add_link() - - def _add_link(self): - op1_mutable_out = self.dynamic_op1.mutable_out - bn1_mutable_out = self.dynamic_bn1.mutable_out - - op2_mutable_in = self.dynamic_op2.mutable_in - op2_mutable_out = self.dynamic_op2.mutable_out - bn2_mutable_out = self.dynamic_bn2.mutable_out - - op3_mutable_in = self.dynamic_op3.mutable_in - - bn1_mutable_out.register_same_mutable(op1_mutable_out) - op1_mutable_out.register_same_mutable(bn1_mutable_out) - - op2_mutable_in.register_same_mutable(bn1_mutable_out) - bn1_mutable_out.register_same_mutable(op2_mutable_in) - - bn2_mutable_out.register_same_mutable(op2_mutable_out) - op2_mutable_out.register_same_mutable(bn2_mutable_out) - - op3_mutable_in.register_same_mutable(bn1_mutable_out) - bn1_mutable_out.register_same_mutable(op3_mutable_in) - - op3_mutable_in.register_same_mutable(bn2_mutable_out) - bn2_mutable_out.register_same_mutable(op3_mutable_in) - - def forward(self, x: Tensor) -> Tensor: - x1 = self.dynamic_bn1(self.dynamic_op1(x)) - x2 = self.dynamic_bn2(self.dynamic_op2(x1)) - x3 = self.dynamic_op3(x2 + x1) - return x3 - - -@unittest.skipIf( - digit_version(torch.__version__) == digit_version('1.8.1'), - 'PyTorch version 1.8.1 is not supported by the Backward Tracer.') -def test_oneshot_channel_mutator() -> None: - imgs = torch.randn(16, 3, 224, 224) - - def _test(model): - mutator.prepare_from_supernet(model) - assert hasattr(mutator, 'name2module') - - # test set_min_choices - mutator.set_min_choices() - for mutables in mutator.search_groups.values(): - for mutable in mutables: - # 1 / 8 is the minimum candidate ratio - assert mutable.current_choice == round(1 / 8 * - mutable.num_channels) - - # test set_max_channel - mutator.set_max_choices() - for mutables in mutator.search_groups.values(): - for mutable in mutables: - # 1.0 is the maximum candidate ratio - assert mutable.current_choice == round(1. 
* - mutable.num_channels) - - # test making groups logic - choice_dict = mutator.sample_choices() - assert isinstance(choice_dict, dict) - mutator.set_choices(choice_dict) - model(imgs) - - mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG) - with pytest.raises(RuntimeError): - _ = mutator.search_groups - with pytest.raises(RuntimeError): - _ = mutator.name2module - - _test(ResBlock()) - _test(MultiConcatModel()) - _test(MultiConcatModel2()) - _test(nn.Sequential(nn.BatchNorm2d(3))) - - mutator: OneShotChannelMutator = MODELS.build( - ONESHOT_MUTATOR_CFG_WITHOUT_TRACER) - dynamic_model = DynamicResBlock( - ONESHOT_MUTATOR_CFG_WITHOUT_TRACER['mutable_cfg']) - _test(dynamic_model) - - -def test_slimmable_channel_mutator() -> None: - imgs = torch.randn(16, 3, 224, 224) - - root_path = dirname(dirname(dirname(__file__))) - channel_cfg_paths = [ - os.path.join(root_path, 'data/subnet1.yaml'), - os.path.join(root_path, 'data/subnet2.yaml') - ] - - mutator = SlimmableChannelMutator( - mutable_cfg=dict(type='SlimmableMutableChannel'), - channel_cfgs=load_and_merge_channel_cfgs(channel_cfg_paths), - tracer_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss'))) - - model = ResBlock() - mutator.prepare_from_supernet(model) - mutator.switch_choices(0) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 0 - _ = model(imgs) - - mutator.switch_choices(1) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 1 - _ = model(imgs) - - channel_cfg_paths = [ - os.path.join(root_path, 'data/concat_subnet1.yaml'), - os.path.join(root_path, 'data/concat_subnet2.yaml') - ] - - mutator = SlimmableChannelMutator( - mutable_cfg=dict(type='SlimmableMutableChannel'), - channel_cfgs=load_and_merge_channel_cfgs(channel_cfg_paths), - tracer_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss'))) - - model = ConcatModel() - - mutator.prepare_from_supernet(model) - - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.choices == [0, 1] - - mutator.switch_choices(0) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 0 - _ = model(imgs) - - mutator.switch_choices(1) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 1 - _ = model(imgs) +sys.setrecursionlimit(2000) + + +@MODELS.register_module() +class RandomChannelUnit(SequentialMutableChannelUnit): + + def generate_mask(self, choice: Union[int, float]) -> torch.Tensor: + if isinstance(choice, float): + choice = max(1, int(self.num_channels * choice)) + assert 0 < choice <= self.num_channels + rand_imp = torch.rand([self.num_channels]) + ind = rand_imp.topk(choice)[1] + mask = torch.zeros([self.num_channels]) + mask.scatter_(-1, ind, 1) + return mask + + +DATA_UNITS = [ + SequentialMutableChannelUnit, RandomChannelUnit, L1MutableChannelUnit +] + + +class TestChannelMutator(unittest.TestCase): + + def _test_a_mutator(self, mutator: ChannelMutator, model): + choices = mutator.sample_choices() + mutator.set_choices(choices) + self.assertGreater(len(mutator.mutable_units), 0) + x = torch.rand([2, 3, 224, 224]) + y = model(x) + self.assertEqual(list(y.shape), [2, 1000]) + + def test_sample_subnet(self): + data_models = 
TestGraph.backward_tracer_passed_models() + + for i, data in enumerate(data_models): + with self.subTest(i=i, data=data): + model = data() + + mutator = ChannelMutator() + mutator.prepare_from_supernet(model) + + self.assertGreaterEqual(len(mutator.mutable_units), 1) + + self._test_a_mutator(mutator, model) + + def test_generic_support(self): + data_models = TestGraph.backward_tracer_passed_models() + + for data_model in data_models[:1]: + for unit_type in DATA_UNITS: + with self.subTest(model=data_model, unit=unit_type): + + model = data_model() + + mutator = ChannelMutator(channel_unit_cfg=unit_type) + mutator.prepare_from_supernet(model) + mutator.units + + self._test_a_mutator(mutator, model) + + def test_init_units_from_cfg(self): + ARCHITECTURE_CFG = dict( + type='mmcls.ImageClassifier', + backbone=dict(type='mmcls.MobileNetV2', widen_factor=1.5), + neck=dict(type='mmcls.GlobalAveragePooling'), + head=dict( + type='mmcls.LinearClsHead', + num_classes=1000, + in_channels=1920, + loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5))) + model = MODELS.build(ARCHITECTURE_CFG) + + # generate config + model1 = copy.deepcopy(model) + mutator = ChannelMutator() + mutator.prepare_from_supernet(model1) + config = mutator.config_template( + with_channels=True, with_unit_init_args=True) + + # test passing config + model2 = copy.deepcopy(model) + config2 = copy.deepcopy(config) + config2['parse_cfg'] = {'type': 'Config'} + mutator2 = MODELS.build(config2) + mutator2.prepare_from_supernet(model2) + self.assertEqual( + len(mutator.mutable_units), len(mutator2.mutable_units)) + self._test_a_mutator(mutator2, model2) + + def test_mix_config_tracer(self): + model = TestGraph.backward_tracer_passed_models()[0]() + + model0 = copy.deepcopy(model) + mutator0 = ChannelMutator() + mutator0.prepare_from_supernet(model0) + config = mutator0.config_template(with_unit_init_args=True) + + model1 = copy.deepcopy(model) + mutator1 = MODELS.build(config) + mutator1.prepare_from_supernet(model1) + config1 = mutator1.config_template(with_unit_init_args=True) + + self.assertDictEqual(config1, config) + self._test_a_mutator(mutator1, model1) + + def test_models_with_predefined_dynamic_op(self): + for Model in [ + DynamicLinearModel, + ]: + with self.subTest(model=Model): + model = Model() + mutator = ChannelMutator( + channel_unit_cfg={ + 'type': 'OneShotMutableChannelUnit', + 'default_args': {} + }, + parse_cfg={'type': 'Predefined'}) + mutator.prepare_from_supernet(model) + self._test_a_mutator(mutator, model) diff --git a/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py b/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py deleted file mode 100644 index 61ec34565..000000000 --- a/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import os -import unittest -from os.path import dirname - -import torch -from mmcls.models import * # noqa: F401,F403 -from mmcls.structures import ClsDataSample - -from mmrazor import digit_version -from mmrazor.models.mutables import SlimmableMutableChannel -from mmrazor.models.mutators import (OneShotChannelMutator, - SlimmableChannelMutator) -from mmrazor.registry import MODELS -from ..utils import load_and_merge_channel_cfgs - -MODEL_CFG = dict( - _scope_='mmcls', - type='mmcls.ImageClassifier', - backbone=dict(type='MobileNetV2', widen_factor=1.5), - neck=dict(type='GlobalAveragePooling'), - head=dict( - type='LinearClsHead', - num_classes=1000, - in_channels=1920, - loss=dict(type='CrossEntropyLoss', loss_weight=1.0), - topk=(1, 5))) - -ONESHOT_MUTATOR_CFG = dict( - type='OneShotChannelMutator', - skip_prefixes=['head.fc'], - tracer_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss')), - mutable_cfg=dict( - type='OneShotMutableChannel', - candidate_choices=[ - 1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0 - ], - candidate_mode='ratio')) - - -@unittest.skipIf( - digit_version(torch.__version__) == digit_version('1.8.1'), - 'PyTorch version 1.8.1 is not supported by the Backward Tracer.') -def test_oneshot_channel_mutator() -> None: - imgs = torch.randn(16, 3, 224, 224) - data_samples = [ - ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, ))) - ] - - model = MODELS.build(MODEL_CFG) - mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG) - - mutator.prepare_from_supernet(model) - assert hasattr(mutator, 'name2module') - - # test set_min_choices - mutator.set_min_choices() - for mutables in mutator.search_groups.values(): - for mutable in mutables: - # 1 / 8 is the minimum candidate ratio - assert mutable.current_choice == round(1 / 8 * - mutable.num_channels) - - # test set_max_channel - mutator.set_max_choices() - for mutables in mutator.search_groups.values(): - for mutable in mutables: - # 1.0 is the maximum candidate ratio - assert mutable.current_choice == round(1. 
* mutable.num_channels) - - # test making groups logic - choice_dict = mutator.sample_choices() - assert isinstance(choice_dict, dict) - mutator.set_choices(choice_dict) - model(imgs, data_samples=data_samples, mode='loss') - - -def test_slimmable_channel_mutator() -> None: - imgs = torch.randn(16, 3, 224, 224) - data_samples = [ - ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, ))) - ] - - root_path = dirname(dirname(dirname(dirname(__file__)))) - channel_cfg_paths = [ - os.path.join(root_path, 'data/MBV2_320M.yaml'), - os.path.join(root_path, 'data/MBV2_220M.yaml') - ] - - mutator = SlimmableChannelMutator( - mutable_cfg=dict(type='SlimmableMutableChannel'), - channel_cfgs=load_and_merge_channel_cfgs(channel_cfg_paths), - tracer_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss'))) - - model = MODELS.build(MODEL_CFG) - mutator.prepare_from_supernet(model) - mutator.switch_choices(0) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 0 - model(imgs, data_samples=data_samples, mode='loss') - - mutator.switch_choices(1) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 1 - model(imgs, data_samples=data_samples, mode='loss') diff --git a/tests/test_models/test_mutators/test_one_shot_mutator.py b/tests/test_models/test_mutators/test_one_shot_mutator.py deleted file mode 100644 index 41921a2be..000000000 --- a/tests/test_models/test_mutators/test_one_shot_mutator.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy - -import pytest -from mmcls.models import * # noqa: F401,F403 -from torch import Tensor -from torch.nn import Module - -from mmrazor.models import OneShotModuleMutator, OneShotMutableModule -from mmrazor.registry import MODELS - -MODEL_CFG = dict( - type='mmcls.ImageClassifier', - backbone=dict( - type='mmcls.ResNet', - depth=50, - num_stages=4, - out_indices=(3, ), - style='pytorch'), - neck=dict(type='mmcls.GlobalAveragePooling'), - head=dict( - type='mmcls.LinearClsHead', - num_classes=1000, - in_channels=2048, - loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0), - topk=(1, 5), - )) - -MUTATOR_CFG = dict(type='OneShotModuleMutator') - -MUTABLE_CFG = dict( - type='OneShotMutableOP', - candidates=dict( - choice1=dict( - type='MBBlock', - in_channels=3, - out_channels=3, - expand_ratio=1, - kernel_size=3), - choice2=dict( - type='MBBlock', - in_channels=3, - out_channels=3, - expand_ratio=1, - kernel_size=5), - choice3=dict( - type='MBBlock', - in_channels=3, - out_channels=3, - expand_ratio=1, - kernel_size=7))) - - -def test_one_shot_mutator_normal_model() -> None: - model = MODELS.build(MODEL_CFG) - mutator: OneShotModuleMutator = MODELS.build(MUTATOR_CFG) - - assert mutator.mutable_class_type == OneShotMutableModule - - with pytest.raises(RuntimeError): - _ = mutator.search_groups - - mutator.prepare_from_supernet(model) - assert len(mutator.search_groups) == 0 - assert len(mutator.sample_choices()) == 0 - - -class _SearchableModel(Module): - - def __init__(self) -> None: - super().__init__() - - self.op1 = MODELS.build(MUTABLE_CFG) - self.op2 = MODELS.build(MUTABLE_CFG) - self.op3 = MODELS.build(MUTABLE_CFG) - - def forward(self, x: Tensor) -> Tensor: - x = self.op1(x) - x = self.op2(x) - x = self.op3(x) - - return x - - -def test_one_shot_mutator_mutable_model() -> None: - model = _SearchableModel() - mutator: 
OneShotModuleMutator = MODELS.build(MUTATOR_CFG) - - mutator.prepare_from_supernet(model) - assert list(mutator.search_groups.keys()) == [0, 1, 2] - - random_choices = mutator.sample_choices() - assert list(random_choices.keys()) == [0, 1, 2] - for choice in random_choices.values(): - assert choice in ['choice1', 'choice2', 'choice3'] - - custom_group = [['op1', 'op2'], ['op3']] - mutator_cfg = copy.deepcopy(MUTATOR_CFG) - mutator_cfg.update({'custom_group': custom_group}) - mutator = MODELS.build(mutator_cfg) - - mutator.prepare_from_supernet(model) - assert list(mutator.search_groups.keys()) == [0, 1] - - random_choices = mutator.sample_choices() - assert list(random_choices.keys()) == [0, 1] - for choice in random_choices.values(): - assert choice in ['choice1', 'choice2', 'choice3'] - - mutator.set_choices(random_choices) - - custom_group.append(['op4']) - mutator_cfg = copy.deepcopy(MUTATOR_CFG) - mutator_cfg.update({'custom_group': custom_group}) - mutator = MODELS.build(mutator_cfg) - with pytest.raises(AssertionError): - mutator.prepare_from_supernet(model) - - -if __name__ == '__main__': - pytest.main() diff --git a/tests/test_models/test_mutators/utils.py b/tests/test_models/test_mutators/utils.py deleted file mode 100644 index 7ddede648..000000000 --- a/tests/test_models/test_mutators/utils.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List - -from mmengine import fileio - -from mmrazor.models.algorithms import SlimmableNetwork - - -def load_and_merge_channel_cfgs(channel_cfg_paths: List[str]) -> Dict: - channel_cfgs = list() - for channel_cfg_path in channel_cfg_paths: - channel_cfg = fileio.load(channel_cfg_path) - channel_cfgs.append(channel_cfg) - - return SlimmableNetwork.merge_channel_cfgs(channel_cfgs) diff --git a/tests/test_utils/test_index_dict.py b/tests/test_utils/test_index_dict.py new file mode 100644 index 000000000..767dd806c --- /dev/null +++ b/tests/test_utils/test_index_dict.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +from mmrazor.utils.index_dict import IndexDict + + +class TestIndexDict(unittest.TestCase): + + def test_dict(self): + dict = IndexDict() + dict[(4, 5)] = 2 + dict[(1, 3)] = 1 + + self.assertSequenceEqual(list(dict.keys()), [(1, 3), (4, 5)]) + with self.assertRaisesRegex(AssertionError, 'overlap'): + dict[2, 3] = 3 diff --git a/tools/get_channel_units.py b/tools/get_channel_units.py new file mode 100644 index 000000000..dc8818fbf --- /dev/null +++ b/tools/get_channel_units.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json + +import torch.nn as nn +from mmengine import MODELS +from mmengine.config import Config + +from mmrazor.models import BaseAlgorithm +from mmrazor.models.mutators import ChannelMutator + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Get channel unit of a model.') + parser.add_argument('config', help='config of the model') + parser.add_argument( + '-c', + '--with-channel', + action='store_true', + help='output with channel config') + parser.add_argument( + '-i', + '--with-init-args', + action='store_true', + help='output with init args') + parser.add_argument( + '--choice', + action='store_true', + help=('output choices template. 
When this flag is activated, ' + '-c and -i will be ignored')) + parser.add_argument( + '-o', + '--output-path', + default='', + help='the file path to store channel unit info') + return parser.parse_args() + + +def main(): + args = parse_args() + config = Config.fromfile(args.config) + model = MODELS.build(config['model']) + if isinstance(model, BaseAlgorithm): + mutator = model.mutator + elif isinstance(model, nn.Module): + mutator = ChannelMutator() + mutator.prepare_from_supernet(model) + if args.choice: + config = mutator.choice_template + else: + config = mutator.config_template( + with_channels=args.with_channel, + with_unit_init_args=args.with_init_args) + json_config = json.dumps(config, indent=4, separators=(',', ':')) + if args.output_path == '': + print('=' * 100) + print('config template') + print('=' * 100) + print(json_config) + else: + with open(args.output_path, 'w') as file: + file.write(json_config) + + +if __name__ == '__main__': + main()
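
Taken together, the new `TestChannelMutator` tests and `tools/get_channel_units.py` exercise a single workflow: parse a model into mutable channel units, sample and apply a pruning structure, then export the unit configuration. The sketch below condenses that workflow for reference; it is not part of the patch. It reuses the mmcls classifier config from `test_init_units_from_cfg` and assumes mmcls is installed and that the default backward-tracer `parse_cfg` can trace the model.

```python
# Minimal usage sketch of the ChannelMutator API introduced in this patch.
# Assumptions: mmcls is installed and registered, and the default parse_cfg
# (backward tracer) can handle the model, as in TestChannelMutator.
import torch

from mmrazor.models.mutables.mutable_channel import L1MutableChannelUnit
from mmrazor.models.mutators.channel_mutator import ChannelMutator
from mmrazor.registry import MODELS

model_cfg = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(type='mmcls.MobileNetV2', widen_factor=1.5),
    neck=dict(type='mmcls.GlobalAveragePooling'),
    head=dict(
        type='mmcls.LinearClsHead',
        num_classes=1000,
        in_channels=1920,
        loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
model = MODELS.build(model_cfg)

# Parse the model into mutable channel units. Passing a unit class directly,
# as test_generic_support does, selects how channels are ranked; here the
# L1-norm unit is used.
mutator = ChannelMutator(channel_unit_cfg=L1MutableChannelUnit)
mutator.prepare_from_supernet(model)

# Sample a pruning structure, apply it, and run a forward pass, mirroring
# TestChannelMutator._test_a_mutator.
mutator.set_choices(mutator.sample_choices())
logits = model(torch.rand(2, 3, 224, 224))  # expected shape: [2, 1000]

# Export the searchable structure, mirroring tools/get_channel_units.py.
template = mutator.config_template(with_unit_init_args=True)
print(mutator.choice_template)
```

A config exported this way can be fed back through the registry, as `test_init_units_from_cfg` does: set `template['parse_cfg'] = {'type': 'Config'}`, build a new mutator with `MODELS.build(template)`, and call `prepare_from_supernet` on a fresh copy of the model to reproduce the same units without re-tracing.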