From a725b51eb871c7adc31c9f40fec5b8cdeb087e3d Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:23:03 +0800 Subject: [PATCH] Merge dev-1.x into quantize (#430) * Fix a bug in make_divisible. (#333) fix bug in make_divisible Co-authored-by: liukai * [Fix] Fix counter mapping bug (#331) * fix counter mapping bug * move judgment into get_counter_type & update UT * [Docs]Add MMYOLO projects link (#334) * [Doc] fix typos in en/usr_guides (#299) * Update README.md * Update README_zh-CN.md Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> * [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320) * support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytest * [Feature] Add deit-base (#332) * WIP: support deit * WIP: add deithead * WIP: fix checkpoint hook * fix data preprocessor * fix cfg * WIP: add readme * reset single_teacher_distill * add metafile * add model to model-index * fix configs and readme * [Feature]Feature map visualization (#293) * WIP: vis * WIP: add visualization * WIP: add visualization hook * WIP: support razor visualizer * WIP * WIP: wrap draw_featmap * support feature map visualization * add a demo image for visualization * fix typos * change eps to 1e-6 * add pytest for visualization * fix vis hook * fix arguments' name * fix img path * support draw inference results * add visualization doc * fix figure url * move files Co-authored-by: weihan cao * [Feature] Add kd examples (#305) * support kd for mbv2 and shufflenetv2 * WIP: fix ckpt path * WIP: fix kd r34-r18 * add metafile * fix metafile * delete * [Doc] add documents about pruning. (#313) * init * update user guide * update images * update * update How to prune your model * update how_to_use_config_tool_of_pruning.md * update doc * move location * update * update * update * add mutablechannels.md * add references Co-authored-by: liukai Co-authored-by: jacky * [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304) * add pkd * add pytest for pkd * fix cfg * WIP: support fcos3d * WIP: support fcos3d pkd * support mmdet3d * fix cfgs * change eps to 1e-6 and add some comments * fix docstring * fix cfg * add assert * add type hint * WIP: add readme and metafile * fix readme * update metafiles and readme * fix metafile * fix pipeline figure * [Refactor] Refactor Mutables and Mutators (#324) * refactor mutables * update load fix subnet * add DumpChosen Typehint * adapt UTs * fix lint * Add GroupMixin to ChannelMutator (temporarily) * fix type hints * add GroupMixin doc-string * modified by comments * fix type hits * update subnet format * fix channel group bugs and add UTs * fix doc string * fix comments * refactor diff module forward * fix error in channel mutator doc * fix comments Co-authored-by: liukai * [Fix] Update readme (#341) * update kl readme * update dsnas readme * fix url * Bump version to 1.0.0rc1 (#338) update version * [Feature] Add Autoformer algorithm (#315) * update candidates * update subnet_sampler_loop * update candidate * add readme * rename variable * rename variable * clean * update * add doc string * Revert "[Improvement] Support for candidate multiple dimensional search constraints." * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * [Feature] Autoformer architecture and dynamicOPs (#327) * add DynamicSequential * dynamiclayernorm * add dynamic_pathchembed * add DynamicMultiheadAttention and DynamicRelativePosition2D * add channel-level dynamicOP * add autoformer algo * clean notes * adapt channel_mutator * vit fly * fix import * mutable init * remove annotation * add DynamicInputResizer * add unittest for mutables * add OneShotMutableChannelUnit_VIT * clean code * reset unit for vit * remove attr * add autoformer backbone UT * add valuemutator UT * clean code * add autoformer algo UT * update classifier UT * fix test error * ignore * make lint * update * fix lint * mutable_attrs * fix test * fix error * remove DynamicInputResizer * fix test ci * remove InputResizer * rename variables * modify type * Continued improvements of ChannelUnit * fix lint * fix lint * remove OneShotMutableChannelUnit * adjust derived type * combination mixins * clean code * fix sample subnet * search loop fly * more annotations * avoid counter warning and modify batch_augment cfg by gy * restore * source_value_mutables restriction * simply arch_setting api * update * clean * fix ut * [Feature] Add performance predictor (#306) * add predictor with 4 handlers * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * update metric_predictor: 1. update MetricPredictor; 2. add predictor config for searching; 3. add predictor in evolution_search_loop. * add UT for predictor * add MLPHandler * patch optional.txt for predictors * patch test_evolution_search_loop * refactor apis of predictor and handlers * fix ut and remove predictor_cfg in predictor * adapt new mutable & mutator design * fix ut * remove unness assert after rebase * move predictor-build in __init__ & simplify estimator-build Co-authored-by: Yue Sun * [Feature] Add DCFF (#295) * add ChannelGroup (#250) * rebase new dev-1.x * modification for adding config_template * add docstring to channel_group.py * add docstring to mutable_channel_group.py * rm channel_group_cfg from Graph2ChannelGroups * change choice type of SequentialChannelGroup from float to int * add a warning about group-wise conv * restore __init__ of dynamic op * in_channel_mutable -> mutable_in_channel * rm abstractproperty * add a comment about VT * rm registry for ChannelGroup * MUTABLECHANNELGROUP -> ChannelGroupType * refine docstring of IndexDict * update docstring * update docstring * is_prunable -> is_mutable * update docstring * fix error in pre-commit * update unittest * add return type * unify init_xxx apit * add unitest about init of MutableChannelGroup * update according to reviews * sequential_channel_group -> sequential_mutable_channel_group Co-authored-by: liukai * Add BaseChannelMutator and refactor Autoslim (#289) * add BaseChannelMutator * add autoslim * tmp * make SequentialMutableChannelGroup accpeted both of num and ratio as choice. and supports divisior * update OneShotMutableChannelGroup * pass supernet training of autoslim * refine autoslim * fix bug in OneShotMutableChannelGroup * refactor make_divisible * fix spell error: channl -> channel * init_using_backward_tracer -> init_from_backward_tracer init_from_fx_tracer -> init_from_fx_tracer * refine SequentialMutableChannelGroup * let mutator support models with dynamicop * support define search space in model * tracer_cfg -> parse_cfg * refine * using -> from * update docstring * update docstring Co-authored-by: liukai * tmpsave * migrate ut * tmpsave2 * add loss collector * refactor slimmable and add l1-norm (#291) * refactor slimmable and add l1-norm * make l1-norm support convnd * update get_channel_groups * add l1-norm_resnet34_8xb32_in1k.py * add pretrained to resnet34-l1 * remove old channel mutator * BaseChannelMutator -> ChannelMutator * update according to reviews * add readme to l1-norm * MBV2_slimmable -> MBV2_slimmable_config Co-authored-by: liukai * update config * fix md & pytorch support <1.9.0 in batchnorm init * Clean old codes. (#296) * remove old dynamic ops * move dynamic ops * clean old mutable_channels * rm OneShotMutableChannel * rm MutableChannel * refine * refine * use SquentialMutableChannel to replace OneshotMutableChannel * refactor dynamicops folder * let SquentialMutableChannel support float Co-authored-by: liukai * fix ci * ci fix py3.6.x & add mmpose * ci fix py3.6.9 in utils/index_dict.py * fix mmpose * minimum_version_cpu=3.7 * fix ci 3.7.13 * fix pruning &meta ci * support python3.6.9 * fix py3.6 import caused by circular import patch in py3.7 * fix py3.6.9 * Add channel-flow (#301) * base_channel_mutator -> channel_mutator * init * update docstring * allow omitting redundant configs for channel * add register_mutable_channel_to_a_module to MutableChannelContainer * update according to reviews 1 * update according to reviews 2 * update according to reviews 3 * remove old docstring * fix error * using->from * update according to reviews * support self-define input channel number * update docstring * chanenl -> channel_elem Co-authored-by: liukai Co-authored-by: jacky * support >=3.7 * support py3.6.9 * Rename: ChannelGroup -> ChannelUnit (#302) * refine repr of MutableChannelGroup * rename folder name * ChannelGroup -> ChannelUnit * filename in units folder * channel_group -> channel_unit * groups -> units * group -> unit * update * get_mutable_channel_groups -> get_mutable_channel_units * fix bug * refine docstring * fix ci * fix bug in tracer Co-authored-by: liukai * update new channel config format * update pruning refactor * update merged pruning * update commit * fix dynamic_conv_mixin * update comments: readme&dynamic_conv_mixins.py * update readme * move kl softmax channel pooling to op by comments * fix comments: fix redundant & split README.md * dcff in ItePruneAlgorithm * partial dynamic params for fuseconv * add step_freq & prune_time check * update comments * update comments * update comments * fix ut * fix gpu ut & revise step_freq in ItePruneAlgorithm * update readme * revise ItePruneAlgorithm * fix docs * fix dynamic_conv attr * fix ci Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: zengyi.vendor Co-authored-by: jacky * [Fix] Fix optional requirements (#357) * fix optional requirements * fix dcff ut * fix import with get_placeholder * supplement the previous commit * [Fix] Fix configs of wrn models and ofd. (#361) * 1.revise the configs of wrn22, wrn24, and wrn40. 2.revise the data_preprocessor of ofd_backbone_resnet50_resnet18_8xb16_cifar10 * 1.Add README for vanilla-wrm. * 1.Revise readme of wrn Co-authored-by: zhangzhongyu * [Fix] Fix bug on mmrazor visualization, mismatch argument in define and use. (#356) fix bug on mmrazor visualization, mismatch argument in define and use. Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> * fix bug in benchmark_test (#364) fix bug in configs Co-authored-by: Your Name * [FIX] Fix wrn configs (#368) * fix wrn configs * fix wrn configs * update online wrn model weight * [Fix] fix bug on pkd config. Wrong import filename. (#373) * [CI] Update ci to torch1.13 (#380) update ci to torch1.13 * [Feature] Add BigNAS algorithm (#219) * add calibrate-bn-statistics * add test calibrate-bn-statistics * fix mixins * fix mixins * fix mixin tests * remove slimmable channel mutable and refactor dynamic op * refact dynamic batch norm * add progressive dynamic conv2d * add center crop dynamic conv2d * refactor dynamic directory * refactor dynamic sequential * rename length to depth in dynamic sequential * add test for derived mutable * refactor dynamic op * refactor api of dynamic op * add derive mutable mixin * addbignas algorithm * refactor bignas structure * add input resizer * add input resizer to bignas * move input resizer from algorithm into classifier * remove compnents * add attentive mobilenet * delete json file * nearly(less 0.2) align inference accuracy with gml * move mutate seperated in bignas mobilenet backbone * add zero_init_residual * add set_dropout * set dropout in bignas algorithm * fix registry * add subnet yaml and nearly align inference accuracy with gml * add rsb config for bignas * remove base in config * add gml bignas config * convert to iter based * bignas forward and backward fly * fix merge conflict * fix dynamicseq bug * fix bug and refactor bignas * arrange configs of bignas * fix typo * refactor attentive_mobilenet * fix channel mismatch due to registion of DerivedMutable * update bignas & fix se channel mismatch * add AutoAugmentV2 & remove unness configs * fix lint * recover channel assertion in channel unit * fix a group bug * fix comments * add docstring * add norm in dynamic_embed * fix search loop & other minor changes * fix se expansion * minor change * add ut for bignas & attentive_mobilenet * fix ut * update bignas readme * rm unness ut & supplement get_placeholder * fix lint * fix ut * add subnet deployment in downstream tasks. * minor change * update ofa backbone * minor fix * Continued improvements of searchable backbone * minor change * drop ratio in backbone * fix comments * fix ci test * fix test * add dynamic shortcut UT * modify strategy to fit bignas * fix test * fix bug in neck * fix error * fix error * fix yaml * save subnet ckpt * merge autoslim_val/test_loop into subnet_val_loop * move calibrate_bn_mixin to utils * fix bugs and add docstring * clean code * fix register bug * clean code * update Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny Co-authored-by: sunyue1 * [Bug] Fix ckpt (#372) fix ckpt * [Feature] Add tools to convert distill ckpt to student-only ckpt. (#381) * [Feature] Add tools to convert distill ckpt to student-only ckpt. * fix bug. * add --model-only to only save model. * Make changes accroding to PR review. * Enhance the Abilities of the Tracer for Pruning. (#371) * tmp * add new mmdet models * add docstring * pass test and pre-commit * rm razor tracer * update fx tracer, now it can automatically wrap methods and functions. * update tracer passed models * add warning for torch <1.12.0 fix bug for python3.6 update placeholder to support placeholder.XXX * fix bug * update docs * fix lint * fix parse_cfg in configs * restore mutablechannel * test ite prune algorithm when using dist * add get_model_from_path to MMModelLibrrary * add mm models to DefaultModelLibrary * add uts * fix bug * fix bug * add uts * add uts * add uts * add uts * fix bug * restore ite_prune_algorithm * update doc * PruneTracer -> ChannelAnalyzer * prune_tracer -> channel_analyzer * add test for fxtracer * fix bug * fix bug * PruneTracer -> ChannelAnalyzer refine * CustomFxTracer -> MMFxTracer * fix bug when test with torch<1.12 * update print log * fix lint * rm unuseful code Co-authored-by: liukai Co-authored-by: jacky Co-authored-by: Your Name Co-authored-by: liukai * fix bug in placer holder (#395) * fix bug in placer holder * remove redundent comment Co-authored-by: liukai * Add get_prune_config and a demo config_pruning (#389) * update tools and test * add demo * disable test doc * add switch for test tools and test_doc * fix bug * update doc * update tools name * mv get_channel_units Co-authored-by: liukai * [Improvement] Adapt OFA series with SearchableMobileNetV3 (#385) * fix mutable bug in AttentiveMobileNetV3 * remove unness code * update ATTENTIVE_SUBNET_A0-A6.yaml with optimized names * unify the sampling usage in sandwich_rule-based NAS * use alias to export subnet * update OFA configs * fix attr bug * fix comments * update convert_supernet2subnet.py * correct the way to dump DerivedMutable * fix convert index bug * update OFA configs & models * fix dynamic2static * generalize convert_ofa_ckpt.py * update input_resizer * update README.md * fix ut * update export_fix_subnet * update _dynamic_to_static * update fix_subnet UT & minor fix bugs * fix ut * add new autoaug compared to attentivenas * clean * fix act * fix act_cfg * update fix_subnet * fix lint * add docstring Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny * [Fix]Dcff Deploy Revision (#383) * dcff deploy revision * tempsave * update fix_subnet * update mutator load * export/load_fix_subnet revision for mutator * update fix_subnet with dev-1.x * update comments * update docs * update registry * [Fix] Fix commands in README to adapt branch 1.x (#400) * update commands in README for 1.x * fix commands Co-authored-by: gaoyang07 <1546308416@qq.com> * Set requires_grad to False if the teacher is not trainable (#398) * add choice and mask of units to checkpoint (#397) * add choice and mask of units to checkpoint * update * fix bug * remove device operation * fix bug * fix circle ci error * fix error in numpy for circle ci * fix bug in requirements * restore * add a note * a new solution * save mutable_channel.mask as float for dist training * refine * mv meta file test Co-authored-by: liukai Co-authored-by: jacky * [Bug]Fix fpn teacher distill (#388) fix fpn distill * [CodeCamp #122] Support KD algorithm MGD for detection. (#377) * [Feature] Support KD algorithm MGD for detection. * use connector to beauty mgd. * fix typo, add unitest. * fix mgd loss unitest. * fix mgd connector unitest. * add model pth and log file. * add mAP. * update l1 config (#405) * add l1 config * update l1 config Co-authored-by: jacky * [Feature] Add greedy search for AutoSlim (#336) * WIP: add greedysearch * fix greedy search and add bn_training_mode to autoslim * fix cfg files * fix autoslim configs * fix bugs when converting dynamic bn to static bn * change to test loop * refactor greedy search * rebase and fix greedysearch * fix lint * fix and delete useless codes * fix pytest * fix pytest and add bn_training_mode * fix lint * add reference to AutoSlimGreedySearchLoop's docstring * sort candidate_choices * fix save subnet * delete useless codes in channel container * change files' name: convert greedy_search_loop to autoslim_greedy_search_loop * [Fix] Fix metafile (#422) * fix ckpt path in metafile and readme * fix darts file path * fix docstring in ConfigurableDistiller * fix darts * fix error * add darts of mmrazor version * delete py36 Co-authored-by: liukai * update bignas cfg (#412) * check attentivenas training * update ckpt link * update supernet log Co-authored-by: aptsunny * Bump version to 1.0.0rc2 (#423) bump version to 1.0.0rc2 Co-authored-by: liukai * fix lint * fix ci * add tmp docstring for passed ci * add tmp docstring for passed ci * fix ci * add get_placeholder for quant * add skip for unittest * fix package placeholder bug * add version judgement in __init__ * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: Yue Sun Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com> Co-authored-by: zengyi.vendor Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com> Co-authored-by: zhangzhongyu Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com> Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> Co-authored-by: Your Name Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: sunyue1 Co-authored-by: liukai Co-authored-by: Ming-Hsuan-Tu Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com> Co-authored-by: Yue Sun --- .github/workflows/build.yml | 38 +++++ configs/pruning/mmpose/dcff/fix_subnet.json | 4 + ...pact_pointrend_resnet50_8xb2_cityscapes.py | 4 + mmrazor/engine/__init__.py | 9 +- mmrazor/engine/runner/__init__.py | 4 +- mmrazor/engine/runner/iteprune_val_loop.py | 1 - mmrazor/engine/runner/quantization_loops.py | 15 +- mmrazor/models/algorithms/nas/autoslim.py | 2 + .../algorithms/pruning/ite_prune_algorithm.py | 4 + .../quantization/mm_architecture.py | 9 +- mmrazor/models/fake_quants/base.py | 6 +- .../models/fake_quants/torch_fake_quants.py | 8 +- mmrazor/models/losses/__init__.py | 1 - .../one_shot_channel_mutator.py | 4 +- mmrazor/models/mutators/group_mixin.py | 68 ++++++++ .../models/mutators/value_mutator/__init__.py | 5 + .../value_mutator/dynamic_value_mutator.py | 14 ++ .../mutators/value_mutator/value_mutator.py | 73 +++++++++ mmrazor/models/observers/base.py | 6 +- mmrazor/models/observers/torch_observers.py | 8 +- .../models/quantizers/academic_quantizer.py | 26 ++- mmrazor/models/quantizers/base.py | 5 +- mmrazor/models/quantizers/native_quantizer.py | 150 ++++++++++++------ .../models/quantizers/openvino_quantizer.py | 17 +- .../models/quantizers/tensorrt_quantizer.py | 12 +- .../task_modules/tracer/fx/custom_tracer.py | 70 ++++---- .../task_modules/tracer/fx/graph_utils.py | 44 ++--- .../quantization/backend_config/academic.py | 34 ++-- .../common_operator_config_utils.py | 86 ++++++---- .../quantization/backend_config/mapping.py | 23 ++- .../quantization/backend_config/native.py | 132 +++++++-------- .../quantization/backend_config/openvino.py | 15 +- .../quantization/backend_config/tensorrt.py | 15 +- mmrazor/structures/quantization/qconfig.py | 7 +- tests/data/models.py | 3 - tests/test_data.py | 8 + .../test_mutators/test_value_mutator.py | 66 ++++++++ .../test_task_modules/test_custom_tracer.py | 35 ---- .../test_task_modules/test_graph_utils.py | 49 +++++- tests/test_registry/test_registry.py | 40 +++-- tests/test_structures/test_qconfig.py | 23 ++- 41 files changed, 838 insertions(+), 305 deletions(-) create mode 100644 mmrazor/models/mutators/value_mutator/__init__.py create mode 100644 mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py create mode 100644 mmrazor/models/mutators/value_mutator/value_mutator.py create mode 100644 tests/test_models/test_mutators/test_value_mutator.py delete mode 100644 tests/test_models/test_task_modules/test_custom_tracer.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 53a184a3d..e00ed24c8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,6 +31,44 @@ jobs: python-version: [3.7] torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] include: + - torch: 1.6.0 + torch_version: 1.6 + torchvision: 0.7.0 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + python-version: 3.8 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + python-version: 3.8 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + python-version: 3.8 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + python-version: 3.8 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + python-version: 3.8 - torch: 1.12.0 torch_version: 1.12 torchvision: 0.13.0 diff --git a/configs/pruning/mmpose/dcff/fix_subnet.json b/configs/pruning/mmpose/dcff/fix_subnet.json index dfdcea758..f7b40f41d 100644 --- a/configs/pruning/mmpose/dcff/fix_subnet.json +++ b/configs/pruning/mmpose/dcff/fix_subnet.json @@ -54,7 +54,11 @@ "min_value":1, "min_ratio":0.9 }, +<<<<<<< HEAD "choice":0.59375 +======= + "choice":0.59374 +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) }, "backbone.layer2.1.conv1_(0, 128)_128":{ "init_args":{ diff --git a/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py index e6c1eb031..a0d0d044a 100644 --- a/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py +++ b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py @@ -1,7 +1,11 @@ _base_ = ['dcff_pointrend_resnet50_8xb2_cityscapes.py'] # model settings +<<<<<<< HEAD _base_.model = dict( +======= +model_cfg = dict( +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) _scope_='mmrazor', type='sub_model', cfg=_base_.architecture, diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index da6cec34d..603aa3d77 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -4,15 +4,14 @@ from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, - GreedySamplerTrainLoop, SelfDistillValLoop, - SingleTeacherDistillValLoop, SlimmableValLoop, - SubnetValLoop) + GreedySamplerTrainLoop, PTQLoop, QATEpochBasedLoop, + SelfDistillValLoop, SingleTeacherDistillValLoop, + SlimmableValLoop, SubnetValLoop) __all__ = [ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'StopDistillHook', - 'DMCPSubnetHook' + 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', 'QATEpochBasedLoop' ] diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 647d8b410..2ca6c0dbb 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -13,6 +13,6 @@ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', - 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'PTQLoop', - 'QATEpochBasedLoop' + 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', + 'PTQLoop' ] diff --git a/mmrazor/engine/runner/iteprune_val_loop.py b/mmrazor/engine/runner/iteprune_val_loop.py index bbca5d53a..2a627f398 100644 --- a/mmrazor/engine/runner/iteprune_val_loop.py +++ b/mmrazor/engine/runner/iteprune_val_loop.py @@ -52,7 +52,6 @@ def _save_fix_subnet(self): file.write(fix_subnet) torch.save({'state_dict': static_model.state_dict()}, osp.join(self.runner.work_dir, weight_name)) - self.runner.logger.info( 'export finished and ' f'{subnet_name}, ' diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 2a0aa812f..e90715910 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -4,9 +4,18 @@ import torch from mmengine.evaluator import Evaluator from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop -from torch.ao.quantization import (disable_observer, enable_fake_quant, - enable_observer) -from torch.nn.intrinsic.qat import freeze_bn_stats + +try: + from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) + from torch.nn.intrinsic.qat import freeze_bn_stats +except ImportError: + from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') + enable_fake_quant = get_placeholder('torch>=1.13') + enable_observer = get_placeholder('torch>=1.13') + freeze_bn_stats = get_placeholder('torch>=1.13') + from torch.utils.data import DataLoader from mmrazor.registry import LOOPS diff --git a/mmrazor/models/algorithms/nas/autoslim.py b/mmrazor/models/algorithms/nas/autoslim.py index dc8d54c0e..77bb6cacc 100644 --- a/mmrazor/models/algorithms/nas/autoslim.py +++ b/mmrazor/models/algorithms/nas/autoslim.py @@ -75,6 +75,8 @@ def __init__(self, self._optim_wrapper_count_status_reinitialized = False self.norm_training = norm_training + self.bn_training_mode = bn_training_mode + def _build_mutator(self, mutator: VALID_MUTATOR_TYPE = None) -> ChannelMutator: """Build mutator.""" diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py index 937aaa156..f510acd76 100644 --- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -10,6 +10,7 @@ from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutators import ChannelMutator from mmrazor.registry import MODELS +from mmrazor.utils import ValidFixMutable from ..base import BaseAlgorithm LossResults = Dict[str, torch.Tensor] @@ -97,6 +98,8 @@ class ItePruneAlgorithm(BaseAlgorithm): mutator_cfg (Union[Dict, ChannelMutator], optional): The config of a mutator. Defaults to dict( type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')). + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. Defaults to None. data_preprocessor (Optional[Union[Dict, nn.Module]], optional): Defaults to None. target_pruning_ratio (dict, optional): The prune-target. The template @@ -118,6 +121,7 @@ def __init__(self, type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')), + fix_subnet: Optional[ValidFixMutable] = None, data_preprocessor: Optional[Union[Dict, nn.Module]] = None, target_pruning_ratio: Optional[Dict[str, float]] = None, step_freq=1, diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index c14aae08c..f5cf30f10 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -7,12 +7,17 @@ from mmengine.runner import load_checkpoint from mmengine.structures import BaseDataElement from torch import nn -from torch.ao.quantization import FakeQuantizeBase -from mmrazor.models.task_modules import build_graphmodule +from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.registry import MODEL_WRAPPERS, MODELS from ..base import BaseAlgorithm +try: + from torch.ao.quantization import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') + LossResults = Dict[str, torch.Tensor] TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] PredictResults = List[BaseDataElement] diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py index 1d4c6dfe0..45aed7421 100644 --- a/mmrazor/models/fake_quants/base.py +++ b/mmrazor/models/fake_quants/base.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch.ao.quantization import FakeQuantize +try: + from torch.ao.quantization import FakeQuantize +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantize = get_placeholder('torch>=1.13') BaseFakeQuantize = FakeQuantize diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py index ad1a0d966..b477929ad 100644 --- a/mmrazor/models/fake_quants/torch_fake_quants.py +++ b/mmrazor/models/fake_quants/torch_fake_quants.py @@ -2,10 +2,14 @@ import inspect from typing import List -import torch.ao.quantization.fake_quantize as torch_fake_quant_src - from mmrazor.registry import MODELS +try: + import torch.ao.quantization.fake_quantize as torch_fake_quant_src +except ImportError: + from mmrazor.utils import get_package_placeholder + torch_fake_quant_src = get_package_placeholder('torch>=1.13') + def register_torch_fake_quants() -> List[str]: """Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 3509acd5c..65e2108fd 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ab_loss import ABLoss -from .adaround_loss import AdaRoundLoss from .at_loss import ATLoss from .crd_loss import CRDLoss from .cross_entropy_loss import CrossEntropyLoss diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py index cc008b0b8..3aca98c95 100644 --- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -4,11 +4,13 @@ from mmrazor.models.mutables import OneShotMutableChannelUnit from mmrazor.registry import MODELS +from ..group_mixin import DynamicSampleMixin from .channel_mutator import ChannelMutator, ChannelUnitType @MODELS.register_module() -class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit]): +class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit], + DynamicSampleMixin): """OneShotChannelMutator based on ChannelMutator. It use OneShotMutableChannelUnit by default. diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py index 569f01ebc..3ecd44b74 100644 --- a/mmrazor/models/mutators/group_mixin.py +++ b/mmrazor/models/mutators/group_mixin.py @@ -8,6 +8,11 @@ from mmrazor.models.mutables.mutable_module import MutableModule from .base_mutator import MUTABLE_TYPE +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + class GroupMixin(): """A mixin for :class:`BaseMutator`, which can group mutables by @@ -259,3 +264,66 @@ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], f'When a mutable is set alias attribute :{alias_key},' f'the corresponding module name {mutable_name} should ' f'not be used in `custom_group` {custom_group}.') + + +class MutatorProtocol(Protocol): # pragma: no cover + + @property + def mutable_class_type(self) -> Type[BaseMutable]: + ... + + @property + def search_groups(self) -> Dict: + ... + + +class OneShotSampleMixin: + """Sample mixin for one-shot mutators.""" + + def sample_choices(self: MutatorProtocol) -> Dict: + """Sample choices for each group in search_groups.""" + random_choices = dict() + for group_id, modules in self.search_groups.items(): + random_choices[group_id] = modules[0].sample_choice() + + return random_choices + + def set_choices(self: MutatorProtocol, choices: Dict) -> None: + """Set choices for each group in search_groups.""" + for group_id, modules in self.search_groups.items(): + choice = choices[group_id] + for module in modules: + module.current_choice = choice + + +class DynamicSampleMixin(OneShotSampleMixin): + + def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict: + """Sample choices for each group in search_groups.""" + random_choices = dict() + for group_id, modules in self.search_groups.items(): + if kind == 'max': + random_choices[group_id] = modules[0].max_choice + elif kind == 'min': + random_choices[group_id] = modules[0].min_choice + else: + random_choices[group_id] = modules[0].sample_choice() + return random_choices + + @property + def max_choice(self: MutatorProtocol) -> Dict: + """Get max choices for each group in search_groups.""" + max_choice = dict() + for group_id, modules in self.search_groups.items(): + max_choice[group_id] = modules[0].max_choice + + return max_choice + + @property + def min_choice(self: MutatorProtocol) -> Dict: + """Get min choices for each group in search_groups.""" + min_choice = dict() + for group_id, modules in self.search_groups.items(): + min_choice[group_id] = modules[0].min_choice + + return min_choice diff --git a/mmrazor/models/mutators/value_mutator/__init__.py b/mmrazor/models/mutators/value_mutator/__init__.py new file mode 100644 index 000000000..a29577bb1 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dynamic_value_mutator import DynamicValueMutator +from .value_mutator import ValueMutator + +__all__ = ['ValueMutator', 'DynamicValueMutator'] diff --git a/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py new file mode 100644 index 000000000..d8d081343 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.models.mutables import OneShotMutableValue +from mmrazor.registry import MODELS +from ..group_mixin import DynamicSampleMixin +from .value_mutator import ValueMutator + + +@MODELS.register_module() +class DynamicValueMutator(ValueMutator, DynamicSampleMixin): + """Dynamic value mutator with type as `OneShotMutableValue`.""" + + @property + def mutable_class_type(self): + return OneShotMutableValue diff --git a/mmrazor/models/mutators/value_mutator/value_mutator.py b/mmrazor/models/mutators/value_mutator/value_mutator.py new file mode 100644 index 000000000..5127cbe37 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/value_mutator.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Type + +from torch.nn import Module + +from mmrazor.models.mutables import MutableValue +from mmrazor.registry import MODELS +from ..base_mutator import BaseMutator +from ..group_mixin import GroupMixin + + +@MODELS.register_module() +class ValueMutator(BaseMutator[MutableValue], GroupMixin): + """The base class for mutable based mutator. All subclass should implement + the following APIS: + + - ``mutable_class_type`` + Args: + custom_group (list[list[str]], optional): User-defined search groups. + All searchable modules that are not in ``custom_group`` will be + grouped separately. + """ + + def __init__(self, + custom_group: Optional[List[List[str]]] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(init_cfg) + + if custom_group is None: + custom_group = [] + self._custom_group = custom_group + self._search_groups: Optional[Dict[int, List[MutableValue]]] = None + + # TODO + # should be a class property + @property + def mutable_class_type(self) -> Type[MutableValue]: + """Corresponding mutable class type. + + Returns: + Type[MUTABLE_TYPE]: Mutable class type. + """ + return MutableValue + + def prepare_from_supernet(self, supernet: Module) -> None: + """Do some necessary preparations with supernet. + + Note: + For mutable based mutator, we need to build search group first. + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. + """ + self._search_groups = self.build_search_groups(supernet, + self.mutable_class_type, + self._custom_group) + + @property + def search_groups(self) -> Dict[int, List[MutableValue]]: + """Search group of supernet. + + Note: + For mutable based mutator, the search group is composed of + corresponding mutables. + Raises: + RuntimeError: Called before search group has been built. + Returns: + Dict[int, List[MUTABLE_TYPE]]: Search group. + """ + if self._search_groups is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access search group!') + return self._search_groups diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py index a68410eb0..ce226cb48 100644 --- a/mmrazor/models/observers/base.py +++ b/mmrazor/models/observers/base.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch.ao.quantization.observer import UniformQuantizationObserverBase +try: + from torch.ao.quantization.observer import UniformQuantizationObserverBase +except ImportError: + from mmrazor.utils import get_placeholder + UniformQuantizationObserverBase = get_placeholder('torch>=1.13') BaseObserver = UniformQuantizationObserverBase diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 8e0e81d58..5dc24609f 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -2,10 +2,14 @@ import inspect from typing import List -import torch.ao.quantization.observer as torch_observer_src - from mmrazor.registry import MODELS +try: + import torch.ao.quantization.observer as torch_observer_src +except ImportError: + from mmrazor.utils import get_package_placeholder + torch_observer_src = get_package_placeholder('torch>=1.13') + def register_torch_observers() -> List[str]: """Register observers in ``torch.ao.quantization.observer`` to the diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index 6a6500791..768f51c53 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -1,16 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, - PrepareCustomConfig) -from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quant_type import _quant_type_from_str -from torch.ao.quantization.quantize_fx import _fuse_fx from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHander from .base import BaseQuantizer +try: + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, + PrepareCustomConfig) + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quant_type import _quant_type_from_str + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + prepare = get_placeholder('torch>=1.13') + FuseCustomConfig = get_placeholder('torch>=1.13') + PrepareCustomConfig = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + _quant_type_from_str = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + GLOBAL_DICT_KEY = '_global_' OBJECT_TYPE_DICT_KEY = 'object_type' MODULE_NAME_REGEX_DICT_KEY = 'module_name_regex' @@ -23,6 +33,7 @@ @MODELS.register_module() class AcademicQuantizer(BaseQuantizer): + """tmp.""" def __init__(self, qconfig_mapping, @@ -37,6 +48,7 @@ def __init__(self, self.example_inputs = (torch.randn(1, 3, 224, 224), ) def prepare(self, model, graph_module): + """tmp.""" preserved_attributes = self.prepare_custom_config.preserved_attributes for attr_name in preserved_attributes: setattr(graph_module, attr_name, getattr(model, attr_name)) @@ -60,6 +72,7 @@ def prepare(self, model, graph_module): return prepared def gen_qconfig_mapping(self, qconfig_mapping): + """tmp.""" conf = QConfigMapping() if GLOBAL_DICT_KEY in qconfig_mapping: qconfig = QConfigHander(qconfig_mapping[GLOBAL_DICT_KEY]).convert() @@ -86,6 +99,7 @@ def gen_qconfig_mapping(self, qconfig_mapping): return conf def gen_prepare_custom_config(self, prepare_custom_config): + """tmp.""" conf = PrepareCustomConfig() if prepare_custom_config is None: return conf diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index d98fbd786..0f14917ac 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -8,6 +8,7 @@ class BaseQuantizer(BaseModule): + """tmp.""" def __init__(self, tracer): super().__init__() @@ -15,11 +16,11 @@ def __init__(self, tracer): @abstractmethod def prepare(self, model, graph_module): + """tmp.""" pass def swap_ff_with_fxff(self, model): - r""" Swap FloatFunctional with FXFloatFunctional - """ + """Swap FloatFunctional with FXFloatFunctional.""" modules_to_swap = [] for name, module in model.named_children(): if isinstance(module, torch.ao.nn.quantized.FloatFunctional): diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 84be1edfb..b3f2002e5 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -1,45 +1,62 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Tuple + import torch -from torch.ao.quantization import enable_fake_quant -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantize_fx import _fuse_fx -from torch.nn.intrinsic.qat import modules as qat_fused_modules -from torch.nn.qat import modules as qat_modules +try: + from torch.ao.quantization import enable_fake_quant + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.nn.intrinsic.qat import modules as qat_fused_modules + from torch.nn.qat import modules as qat_modules +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + enable_fake_quant = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + qat_fused_modules = get_package_placeholder('torch>=1.13') + qat_modules = get_package_placeholder('torch>=1.13') + +from mmrazor import digit_version from mmrazor.models.task_modules.tracer.fx import ( del_fakequant_after_function, del_fakequant_after_method, del_fakequant_after_module, del_fakequant_after_op, del_fakequant_before_function, del_fakequant_before_method, del_fakequant_before_module, del_fakequant_before_op) - from mmrazor.models.utils import str2class from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHander from .base import BaseQuantizer -SUPPORT_QAT_MODULES = ( - qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, - qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, - qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, - qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, - qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, - qat_modules.Conv3d, qat_modules.Linear) - -MERGE_BN_MAPPINGS = { - qat_fused_modules.ConvBn1d: qat_modules.Conv1d, - qat_fused_modules.ConvBn2d: qat_modules.Conv2d, - qat_fused_modules.ConvBn3d: qat_modules.Conv3d, - qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, - qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, - qat_fused_modules.LinearBn1d: qat_modules.Linear -} +if digit_version(torch.__version__) >= digit_version('1.13.0'): + SUPPORT_QAT_MODULES: Tuple = ( + qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, + qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, + qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, + qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, + qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, + qat_modules.Conv3d, qat_modules.Linear) + + MERGE_BN_MAPPINGS: Dict = { + qat_fused_modules.ConvBn1d: qat_modules.Conv1d, + qat_fused_modules.ConvBn2d: qat_modules.Conv2d, + qat_fused_modules.ConvBn3d: qat_modules.Conv3d, + qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, + qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, + qat_fused_modules.LinearBn1d: qat_modules.Linear + } +else: + SUPPORT_QAT_MODULES = () + MERGE_BN_MAPPINGS = {} @MODELS.register_module() class NativeQuantizer(BaseQuantizer): + """tmp.""" # backend: 'native' # support_w_modes = ['per_tensor', 'per_channel'] @@ -52,12 +69,12 @@ def __init__(self, extra_redundant_fakequants=dict( extra_module_prev_wo_fakequant=tuple(), extra_module_next_wo_fakequant=tuple(), - extra_function_prev_wo_fakequant = tuple(), - extra_function_next_wo_fakequant = tuple(), - extra_method_prev_wo_fakequant = tuple(), - extra_method_next_wo_fakequant = tuple(), - extra_op_prev_wo_fakequant = tuple(), - extra_op_next_wo_fakequant = tuple())): + extra_function_prev_wo_fakequant=tuple(), + extra_function_next_wo_fakequant=tuple(), + extra_method_prev_wo_fakequant=tuple(), + extra_method_next_wo_fakequant=tuple(), + extra_op_prev_wo_fakequant=tuple(), + extra_op_next_wo_fakequant=tuple())): super().__init__(tracer) self.qconfig = QConfigHander(global_qconfig) if self.qconfig.w_qscheme.is_per_channel: @@ -86,17 +103,21 @@ def __init__(self, @property def backend(self): + """tmp.""" return 'native' @property def support_w_modes(self): + """tmp.""" return ['per_tensor', 'per_channel'] @property def support_a_modes(self): + """tmp.""" return ['per_tensor'] def prepare(self, model, graph_module): + """tmp.""" graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, @@ -115,6 +136,7 @@ def prepare(self, model, graph_module): def post_process_weight_fakequant(self, observed_module, keep_fake_quant=False): + """tmp.""" def traverse(module): for name, child in module.named_children(): @@ -145,70 +167,104 @@ def traverse(module): traverse(observed_module) def prepare_for_mmdeploy(self, model, dummy_input, checkpoint): + """tmp.""" raise NotImplementedError def del_redundant_fakequant(self, prepared): - extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_prev_wo_fakequant', tuple()) + """tmp.""" + extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_module_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_module( - prepared, self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant, inplace=True) + prepared, + self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant, + inplace=True) - extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_next_wo_fakequant', tuple()) + extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_module_next_wo_fakequant', tuple()) prepared = del_fakequant_after_module( - prepared, self.module_next_wo_fakequant + extra_module_next_wo_fakequant, inplace=True) + prepared, + self.module_next_wo_fakequant + extra_module_next_wo_fakequant, + inplace=True) - extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_prev_wo_fakequant', tuple()) + extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_function_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_method( - prepared, self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, inplace=True) + prepared, + self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, + inplace=True) - extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_next_wo_fakequant', tuple()) + extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_function_next_wo_fakequant', tuple()) prepared = del_fakequant_after_method( - prepared, self.function_next_wo_fakequant + extra_function_next_wo_fakequant, inplace=True) + prepared, + self.function_next_wo_fakequant + extra_function_next_wo_fakequant, + inplace=True) - extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_prev_wo_fakequant', tuple()) + extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_method_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_function( - prepared, self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, inplace=True) + prepared, + self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, + inplace=True) - extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_next_wo_fakequant', tuple()) + extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_method_next_wo_fakequant', tuple()) prepared = del_fakequant_after_function( - prepared, self.method_next_wo_fakequant + extra_method_next_wo_fakequant, inplace=True) + prepared, + self.method_next_wo_fakequant + extra_method_next_wo_fakequant, + inplace=True) - extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_prev_wo_fakequant', tuple()) + extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_op_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_op( - prepared, self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant, inplace=True) + prepared, + self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant, + inplace=True) - extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_next_wo_fakequant', tuple()) + extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_op_next_wo_fakequant', tuple()) prepared = del_fakequant_after_op( - prepared, self.op_next_wo_fakequant + extra_op_next_wo_fakequant, inplace=True) + prepared, + self.op_next_wo_fakequant + extra_op_next_wo_fakequant, + inplace=True) return prepared @property def module_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def module_next_wo_fakequant(self): + """tmp.""" return tuple() @property def function_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def function_next_wo_fakequant(self): + """tmp.""" return tuple() @property def method_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def method_next_wo_fakequant(self): + """tmp.""" return tuple() @property def op_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def op_next_wo_fakequant(self): + """tmp.""" return tuple() diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 0b13b23f9..23abf40da 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -1,8 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple import torch -from torch.ao.quantization import disable_observer + +try: + from torch.ao.quantization import disable_observer +except ImportError: + from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') from mmrazor.models.task_modules.tracer.fx import build_graphmodule from mmrazor.registry import MODELS @@ -19,21 +23,24 @@ class OpenVINOQuantizer(NativeQuantizer): @property def backend(self): + """tmp.""" return 'openvino' @property def support_w_modes(self): + """tmp.""" return ['per_tensor', 'per_channel'] @property def support_a_modes(self): + """tmp.""" return ['per_tensor'] def prepare_for_mmdeploy(self, model, dummy_input=(1, 3, 224, 224), checkpoint=None): - + """tmp.""" self.swap_ff_with_fxff(model) graph = self.tracer.trace(model) graph_module = build_graphmodule(model, graph) @@ -52,16 +59,20 @@ def prepare_for_mmdeploy(self, @property def module_prev_wo_fakequant(self): + """tmp.""" return (torch.nn.ReLU6, torch.nn.Identity) @property def module_next_wo_fakequant(self): + """tmp.""" return (torch.nn.MaxPool2d, ) @property def method_next_wo_fakequant(self): + """tmp.""" return ('flatten', ) @property def op_prev_wo_fakequant(self): + """tmp.""" return ('output', ) diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py index 4d9868c4f..36e3f2be7 100644 --- a/mmrazor/models/quantizers/tensorrt_quantizer.py +++ b/mmrazor/models/quantizers/tensorrt_quantizer.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization import disable_observer + +try: + from torch.ao.quantization import disable_observer +except ImportError: + from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ build_graphmodule @@ -24,21 +29,24 @@ def __init__(self, @property def backend(self): + """tmp.""" return 'tensorrt' @property def support_w_modes(self): + """tmp.""" return ['per_tensor', 'per_channel'] @property def support_a_modes(self): + """tmp.""" return ['per_tensor'] def prepare_for_mmdeploy(self, model, dummy_input=(1, 3, 224, 224), checkpoint=None): - + """tmp.""" self.swap_ff_with_fxff(model) graph = self.tracer.trace(model) graph_module = build_graphmodule(model, graph) diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 0e118290e..2d33e9875 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -5,18 +5,32 @@ import torch import torch.nn as nn + +try: + from torch._C import ScriptObject # type: ignore[attr-defined] + from torch.ao.quantization.quantize_fx import QuantizationTracer + from torch.fx import Graph, GraphModule, Tracer + from torch.fx._symbolic_trace import (_autowrap_check, + _patch_wrapped_functions, _Patcher) + from torch.fx.proxy import Proxy +except ImportError: + from mmrazor.utils import get_placeholder + ScriptObject = get_placeholder('torch>=1.13') + QuantizationTracer = get_placeholder('torch>=1.13') + GraphModule = get_placeholder('torch>=1.13') + Tracer = get_placeholder('torch>=1.13') + Graph = get_placeholder('torch>=1.13') + _autowrap_check = get_placeholder('torch>=1.13') + _patch_wrapped_functions = get_placeholder('torch>=1.13') + _Patcher = get_placeholder('torch>=1.13') + Proxy = get_placeholder('torch>=1.13') + from mmengine.utils import import_modules_from_strings -from torch._C import ScriptObject # type: ignore[attr-defined] -from torch.ao.quantization.quantize_fx import QuantizationTracer -from torch.fx import GraphModule, Tracer -from torch.fx._symbolic_trace import (Graph, _autowrap_check, - _patch_wrapped_functions, _Patcher) -from torch.fx.proxy import Proxy from mmrazor.registry import TASK_UTILS -_orig_module_call: Callable = torch.nn.Module.__call__ -_orig_module_getattr: Callable = torch.nn.Module.__getattr__ +_orig_module_call: Callable = nn.Module.__call__ +_orig_module_getattr: Callable = nn.Module.__getattr__ class UntracedMethodRegistry: @@ -59,13 +73,12 @@ def method(*args, **kwargs): return wrapped_method -def custom_symbolic_trace( - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: +def custom_symbolic_trace(root: Union[nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None): """Modified `symbolic_trace` function. Args: - root (Union[torch.nn.Module, Callable]): Module or function to be + root (Union[nn.Module, Callable]): Module or function to be traced and converted into a Graph representation. concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized. @@ -75,12 +88,12 @@ def custom_symbolic_trace( """ tracer = CustomTracer() graph = tracer.trace(root, concrete_args) - name = root.__class__.__name__ if isinstance( - root, torch.nn.Module) else root.__name__ + name = root.__class__.__name__ if isinstance(root, + nn.Module) else root.__name__ return GraphModule(tracer.root, graph, name) -def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph): +def _prepare_module_dict(model: nn.Module, fx_graph): """If there is a class method that can not be traced by the symbolic tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in ``CustomTracer``. @@ -128,7 +141,7 @@ def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph): Args: model (nn.Module): The original model. - fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. + fx_graph (Graph): The fx Graph traced by fx tracer. """ def _get_attrs(target, attrs): @@ -157,9 +170,7 @@ def _get_attrs(target, attrs): return module_dict -def build_graphmodule(model: nn.Module, - fx_graph: torch.fx.Graph, - name: str = 'GraphModule'): +def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'): modules = dict(model.named_modules()) module_dict = _prepare_module_dict(model, fx_graph) modules.update(module_dict) @@ -228,7 +239,7 @@ def register_skipped_methods(self): method_registry = UntracedMethodRegistry(method) method_registry.__set_name__(imported_cls, method_str) - def call_method(self, m: torch.nn.Module, name, method, args, kwargs): + def call_method(self, m: nn.Module, name, method, args, kwargs): """Method that specifies the behavior of this ``Tracer`` when it encounters a call to an ``nn.Module`` instance. @@ -266,7 +277,7 @@ def call_method(self, m: torch.nn.Module, name, method, args, kwargs): return self.create_proxy('call_method', name, args, kwargs) def trace(self, root, concrete_args=None): - if isinstance(root, torch.nn.Module): + if isinstance(root, nn.Module): self.root = root fn = type(root).forward self.submodule_paths = { @@ -274,7 +285,7 @@ def trace(self, root, concrete_args=None): for name, mod in root.named_modules() } else: - self.root = torch.nn.Module() + self.root = nn.Module() fn = root tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) @@ -286,7 +297,7 @@ def trace(self, root, concrete_args=None): # used downstream in create_arg self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + def collect_tensor_attrs(m: nn.Module, prefix_atoms: List[str]): for k, v in m.__dict__.items(): if isinstance(v, (torch.Tensor, ScriptObject)): self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) @@ -298,8 +309,7 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): assert isinstance(fn, FunctionType) fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root(fn, - isinstance(root, torch.nn.Module), + fn, args = self.create_args_for_root(fn, isinstance(root, nn.Module), concrete_args) # Reduce number of get_attr calls @@ -328,15 +338,12 @@ def forward(*args, **kwargs): with _Patcher() as patcher: # allow duplicate patches to support the case of nested calls patcher.patch_method( - torch.nn.Module, + nn.Module, '__getattr__', module_getattr_wrapper, deduplicate=False) patcher.patch_method( - torch.nn.Module, - '__call__', - module_call_wrapper, - deduplicate=False) + nn.Module, '__call__', module_call_wrapper, deduplicate=False) for name, value in UntracedMethodRegistry.method_dict.items(): wrapped = value['wrapped'] @@ -363,8 +370,7 @@ def is_skipped_method(self, m): custom = isinstance(m, mods) return custom - def is_leaf_module(self, m: torch.nn.Module, - module_qualified_name: str) -> bool: + def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: # return super().is_leaf_module(m, module_qualified_name) leaf = super().is_leaf_module(m, module_qualified_name) return leaf diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py index fe8d620c2..5e3ddc2f4 100644 --- a/mmrazor/models/task_modules/tracer/fx/graph_utils.py +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -2,8 +2,13 @@ import copy from typing import Any, List, Tuple -import torch.fx -from torch.ao.quantization.fake_quantize import FakeQuantizeBase +import torch + +try: + from torch.ao.quantization.fake_quantize import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') def _get_attrs(target: torch.nn.Module, attr: str) -> Any: @@ -67,9 +72,9 @@ def recursive_find_erased_nodes(node, prepared_model): return nodes_to_erase -def del_fakequant_before_op(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_op(prepared_model, target_ops: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before nodes whose ``op`` attribute (node.op) is in `target_ops`. @@ -104,9 +109,9 @@ def del_fakequant_before_op(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_op(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_op(prepared_model, target_ops: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose ``op`` attribute (node.op) is in `target_ops`. @@ -145,9 +150,9 @@ def del_fakequant_after_op(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_method(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_method(prepared_model, method_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before nodes whose op attribute (node.op) is `call_method` and target attribute (node.target) is in `target_patterns`. @@ -182,9 +187,9 @@ def del_fakequant_before_method(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_method(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_method(prepared_model, method_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose op attribute (node.op) is `call_method` and target attribute (node.target) is in `target_patterns`. @@ -224,10 +229,9 @@ def del_fakequant_after_method(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_function( - prepared_model: torch.fx.GraphModule, - function_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: +def del_fakequant_before_function(prepared_model, + function_patterns: Tuple, + inplace: bool = True): """Delete useless fakequant before nodes whose op attribute (node.op) is `call_function` and target attribute (node.target) is in `target_patterns`. @@ -262,9 +266,9 @@ def del_fakequant_before_function( return prepared_model -def del_fakequant_after_function(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_function(prepared_model, function_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose op attribute (node.op) is `call_function` and target attribute (node.target) is in `target_patterns`. @@ -304,9 +308,9 @@ def del_fakequant_after_function(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_module(prepared_model, module_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before modules whose type are in `module_patterns`. @@ -340,9 +344,9 @@ def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_module(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_module(prepared_model, module_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after modules whose type are in `module_patterns`. diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py index 5983c3996..4348e7179 100644 --- a/mmrazor/structures/quantization/backend_config/academic.py +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -1,23 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +try: + from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_conv_configs, _get_linear_configs) -# =================== -# | DTYPE CONFIGS | -# =================== - -# weighted op int8 dtype config -# this is config for ops that has quantized weights, like linear, conv -weighted_op_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.qint8, - bias_dtype=torch.float, -) - # ===================== # | BACKEND CONFIGS | # ===================== @@ -25,6 +18,19 @@ def get_academic_backend_config() -> BackendConfig: """Return the `BackendConfig` for academic reseaching.""" + + # =================== + # | DTYPE CONFIGS | + # =================== + # weighted op int8 dtype config + # this is config for ops that has quantized weights, like linear, conv + weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [weighted_op_int8_dtype_config] diff --git a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py index 2a855e687..0a381d5d0 100644 --- a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py +++ b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py @@ -5,39 +5,71 @@ import torch import torch.nn as nn -import torch.nn.functional as F -import torch.nn.intrinsic as nni -import torch.nn.intrinsic.qat as nniqat -import torch.nn.qat as nnqat -import torch.nn.quantized._reference as nnqr -from torch.ao.quantization.backend_config import (BackendPatternConfig, - DTypeConfig, ObservationType) -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.fuser_method_mappings import ( - fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, - reverse2, reverse3, reverse_sequential_wrapper2) -from torch.ao.quantization.qconfig_mapping import _FIXED_QPARAMS_OP_TO_OBSERVER + +from mmrazor import digit_version + +try: + import torch.nn.functional as F + import torch.nn.intrinsic as nni + import torch.nn.intrinsic.qat as nniqat + import torch.nn.qat as nnqat + import torch.nn.quantized._reference as nnqr + from torch.ao.quantization.backend_config import (BackendPatternConfig, + DTypeConfig, + ObservationType) + from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize + from torch.ao.quantization.fuser_method_mappings import ( + fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, + reverse2, reverse3, reverse_sequential_wrapper2) + from torch.ao.quantization.qconfig_mapping import \ + _FIXED_QPARAMS_OP_TO_OBSERVER +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + F = get_package_placeholder('torch>=1.13') + nni = get_package_placeholder('torch>=1.13') + nniqat = get_package_placeholder('torch>=1.13') + nnqat = get_package_placeholder('torch>=1.13') + nnqr = get_package_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') + FixedQParamsFakeQuantize = get_placeholder('torch>=1.13') + fuse_conv_bn = get_placeholder('torch>=1.13') + fuse_conv_bn_relu = get_placeholder('torch>=1.13') + fuse_convtranspose_bn = get_placeholder('torch>=1.13') + fuse_linear_bn = get_placeholder('torch>=1.13') + reverse2 = get_placeholder('torch>=1.13') + reverse3 = get_placeholder('torch>=1.13') + reverse_sequential_wrapper2 = get_placeholder('torch>=1.13') + _FIXED_QPARAMS_OP_TO_OBSERVER = get_placeholder('torch>=1.13') _ConvMetadata = namedtuple('_ConvMetadata', [ 'root', 'transpose', 'bn', 'reference', 'transpose_reference', 'fused_conv_relu', 'fused_conv_bn', 'fused_conv_bn_relu', 'qat', 'relu_qat', 'bn_qat', 'bn_relu_qat', 'func' ]) -_Conv1dMetadata = _ConvMetadata(nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, - nnqr.Conv1d, nnqr.ConvTranspose1d, - nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, - nnqat.Conv1d, nniqat.ConvReLU1d, - nniqat.ConvBn1d, nniqat.ConvBnReLU1d, F.conv1d) -_Conv2dMetadata = _ConvMetadata(nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, - nnqr.Conv2d, nnqr.ConvTranspose2d, - nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, - nnqat.Conv2d, nniqat.ConvReLU2d, - nniqat.ConvBn2d, nniqat.ConvBnReLU2d, F.conv2d) -_Conv3dMetadata = _ConvMetadata(nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, - nnqr.Conv3d, nnqr.ConvTranspose3d, - nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, - nnqat.Conv3d, nniqat.ConvReLU3d, - nniqat.ConvBn3d, nniqat.ConvBnReLU3d, F.conv3d) + +if digit_version(torch.__version__) >= digit_version('1.13.0'): + _Conv1dMetadata = _ConvMetadata( + nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d, + nnqr.ConvTranspose1d, nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, + nnqat.Conv1d, nniqat.ConvReLU1d, nniqat.ConvBn1d, nniqat.ConvBnReLU1d, + F.conv1d) + _Conv2dMetadata = _ConvMetadata( + nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nnqr.Conv2d, + nnqr.ConvTranspose2d, nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, + nnqat.Conv2d, nniqat.ConvReLU2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d, + F.conv2d) + _Conv3dMetadata = _ConvMetadata( + nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, nnqr.Conv3d, + nnqr.ConvTranspose3d, nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, + nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d, + F.conv3d) +else: + toy_val = _ConvMetadata(*[i for i in range(13)]) + _Conv1dMetadata = toy_val + _Conv2dMetadata = toy_val + _Conv3dMetadata = toy_val def _get_binary_op_configs( diff --git a/mmrazor/structures/quantization/backend_config/mapping.py b/mmrazor/structures/quantization/backend_config/mapping.py index 4c87a73b9..b9cc5372b 100644 --- a/mmrazor/structures/quantization/backend_config/mapping.py +++ b/mmrazor/structures/quantization/backend_config/mapping.py @@ -1,12 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor import digit_version from .academic import get_academic_backend_config from .native import get_native_backend_config from .openvino import get_openvino_backend_config from .tensorrt import get_tensorrt_backend_config -BackendConfigs = { - 'academic': get_academic_backend_config(), - 'native': get_native_backend_config(), - 'tensorrt': get_tensorrt_backend_config(), - 'openvino': get_openvino_backend_config() -} +if digit_version(torch.__version__) >= digit_version('1.13.0'): + BackendConfigs = { + 'academic': get_academic_backend_config(), + 'native': get_native_backend_config(), + 'tensorrt': get_tensorrt_backend_config(), + 'openvino': get_openvino_backend_config() + } +else: + BackendConfigs = { + 'academic': None, + 'native': None, + 'tensorrt': None, + 'openvino': None + } diff --git a/mmrazor/structures/quantization/backend_config/native.py b/mmrazor/structures/quantization/backend_config/native.py index d771b6012..94c35d535 100644 --- a/mmrazor/structures/quantization/backend_config/native.py +++ b/mmrazor/structures/quantization/backend_config/native.py @@ -1,6 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +try: + from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') from .common_operator_config_utils import ( # noqa: F401,F403 _get_binary_op_configs, _get_bn_configs, _get_cat_config, @@ -8,68 +14,6 @@ _get_fixed_qparams_op_configs, _get_linear_configs, _get_ln_configs, _get_rnn_op_configs, _get_share_qparams_op_configs) -# =================== -# | DTYPE CONFIGS | -# =================== - -# weighted op int8 dtype config -# this is config for ops that has quantized weights, like linear, conv -weighted_op_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.qint8, - bias_dtype=torch.float, -) - -default_op_quint8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, -) - -default_dynamic_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.float, - weight_dtype=torch.qint8, - bias_dtype=torch.float, - # currently the dtype check is not yet enabled, so we provided the - # dtype_configs but it is not really used yet, - # we will enable it a bit later after we moved everything to - # backend_config_dict - is_dynamic=True, -) - -default_dynamic_float16_dtype_config = DTypeConfig( - input_dtype=torch.float16, - output_dtype=torch.float, - weight_dtype=torch.float16, - bias_dtype=torch.float, - # currently the dtype check is not yet enabled, so we provided the - # dtype_configs but it is not really used yet, we will enable it a bit - # later after we moved everything to backend_config_dict - is_dynamic=True, -) - -# Needed for LayerNorm and f.layer_norm, since currently the kernel only -# supports float weights -input_output_only_quint8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.float, - bias_dtype=torch.float, -) - -weight_only_quint8_dtype_config = DTypeConfig( - input_dtype=torch.float, - output_dtype=torch.float, - weight_dtype=torch.quint8, -) - -weight_only_quint4x2_dtype_config = DTypeConfig( - input_dtype=torch.float, - output_dtype=torch.float, - weight_dtype=torch.quint4x2, -) - # ===================== # | BACKEND CONFIGS | # ===================== @@ -80,6 +24,68 @@ def get_native_backend_config() -> BackendConfig: (fbgemm/qnnpack).""" # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK # BackendConfigs + + # =================== + # | DTYPE CONFIGS | + # =================== + # weighted op int8 dtype config + # this is config for ops that has quantized weights, like linear, conv + weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + + default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + ) + + default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, + # we will enable it a bit later after we moved everything to + # backend_config_dict + is_dynamic=True, + ) + + default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, we will enable it a bit + # later after we moved everything to backend_config_dict + is_dynamic=True, + ) + + # Needed for LayerNorm and f.layer_norm, since currently the kernel only + # supports float weights + input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, + ) + + weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, + ) + + weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, + ) + conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [ weighted_op_int8_dtype_config, diff --git a/mmrazor/structures/quantization/backend_config/openvino.py b/mmrazor/structures/quantization/backend_config/openvino.py index fd24eed17..d990d4ef9 100644 --- a/mmrazor/structures/quantization/backend_config/openvino.py +++ b/mmrazor/structures/quantization/backend_config/openvino.py @@ -1,8 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import (BackendConfig, - BackendPatternConfig, - DTypeConfig, ObservationType) + +try: + from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType) +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_binary_op_configs, _get_conv_configs, diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py index abb585c6a..53305f650 100644 --- a/mmrazor/structures/quantization/backend_config/tensorrt.py +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -1,8 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import (BackendConfig, - BackendPatternConfig, - DTypeConfig, ObservationType) + +try: + from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType) +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_binary_op_configs, _get_conv_configs, diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py index 3dca49730..e0fdf113d 100644 --- a/mmrazor/structures/quantization/qconfig.py +++ b/mmrazor/structures/quantization/qconfig.py @@ -3,7 +3,12 @@ import torch from mmengine.config import Config -from torch.ao.quantization import QConfig + +try: + from torch.ao.quantization import QConfig +except ImportError: + from mmrazor.utils import get_placeholder + QConfig = get_placeholder('torch>=1.13') from mmrazor.registry import MODELS diff --git a/tests/data/models.py b/tests/data/models.py index 33fb0c624..0347b9147 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -78,7 +78,6 @@ def untracable_method(self, x): x = x * -2 return x - @MODELS.register_module() class UntracableBackBone(nn.Module): @@ -123,7 +122,6 @@ def forward(self, x): x_last = self.conv2(x_attn) return self.head(x_last) - @MODELS.register_module() class LinearHeadForTest(Module): @@ -704,7 +702,6 @@ def current_choice(self): def current_choice(self, choice): super().current_choice(choice) - class DynamicLinearModel(nn.Module): """ x diff --git a/tests/test_data.py b/tests/test_data.py index df3e07f69..d56a2950b 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -6,8 +6,13 @@ from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary, MMDetModelLibrary, MMModelLibrary, +<<<<<<< HEAD MMPoseModelLibrary, MMSegModelLibrary, ModelGenerator, TorchModelLibrary) +======= + MMSegModelLibrary, ModelGenerator, + TorchModelLibrary) +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) from .data.models import SingleLineModel from .data.tracer_passed_models import (BackwardPassedModelManager, FxPassedModelManager) @@ -45,6 +50,7 @@ def test_mmseg(self): if not TEST_DATA: self.skipTest('not test data to save time.') library = MMSegModelLibrary() +<<<<<<< HEAD print(library.short_names()) self.assertTrue(library.is_default_includes_cover_all_models()) @@ -55,6 +61,8 @@ def test_mmpose(self): self.skipTest('not test data to save time.') library = MMPoseModelLibrary() print(library.short_names()) +======= +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) self.assertTrue(library.is_default_includes_cover_all_models()) def test_get_model_by_config(self): diff --git a/tests/test_models/test_mutators/test_value_mutator.py b/tests/test_models/test_mutators/test_value_mutator.py new file mode 100644 index 000000000..a76257a9e --- /dev/null +++ b/tests/test_models/test_mutators/test_value_mutator.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch + +from mmrazor.models.mutables import MutableValue +from mmrazor.models.mutators import DynamicValueMutator +from tests.data.models import DynamicAttention, DynamicMMBlock + + +class TestValueMutator(unittest.TestCase): + + def test_models_with_predefined_dynamic_op(self): + for Model in [ + DynamicAttention, + ]: + with self.subTest(model=Model): + model = Model() + value_mutator = DynamicValueMutator() + value_mutator.prepare_from_supernet(model) + value_choices = value_mutator.sample_choices() + value_mutator.set_choices(value_choices) + + mutable_value_space = [] + for mutable_value, module in model.named_modules(): + if isinstance(module, MutableValue): + mutable_value_space.append(mutable_value) + elif hasattr(module, 'source_mutables'): + for each_mutables in module.source_mutables: + if isinstance(each_mutables, MutableValue): + mutable_value_space.append(each_mutables) + assert len( + value_mutator.search_groups) == len(mutable_value_space) + + x = torch.rand([2, 3, 224, 224]) + y = model(x) + self.assertEqual(list(y.shape), [2, 624]) + + def test_models_with_multiple_value(self): + for Model in [ + DynamicMMBlock, + ]: + with self.subTest(model=Model): + model = Model() + value_mutator = DynamicValueMutator() + value_mutator.prepare_from_supernet(model) + value_choices = value_mutator.sample_choices() + value_mutator.set_choices(value_choices) + + # TODO check DynamicMMBlock + mutable_value_space = [] + for mutable_value, module in model.named_modules(): + if isinstance(module, MutableValue): + mutable_value_space.append(mutable_value) + elif hasattr(module, 'source_mutables'): + for each_mutables in module.source_mutables: + if isinstance(each_mutables, MutableValue): + mutable_value_space.append(each_mutables) + count = 0 + for values in value_mutator.search_groups.values(): + count += len(values) + assert count == len(mutable_value_space) + + x = torch.rand([2, 3, 224, 224]) + y = model(x) + self.assertEqual(list(y[-1].shape), [2, 1984, 1, 1]) diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py deleted file mode 100644 index 671922f69..000000000 --- a/tests/test_models/test_task_modules/test_custom_tracer.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -from mmrazor.models.task_modules import CustomTracer, UntracedMethodRegistry -from mmrazor.testing import ConvBNReLU - - -class testCustomTracer(TestCase): - - def test_init(self): - tracer = CustomTracer() - assert tracer.skipped_methods.__len__() == 0 - - def test_trace(self): - tracer = CustomTracer() - model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) - graph = tracer.trace(model) # noqa: F841 - - def test_auto_skip_call_module(self): - pass - - def test_auto_skip_call_method(self): - pass - - def test_configurable_skipped_methods(self): - pass - - -class testUntracedMethodRgistry(TestCase): - - def test_init(self): - self.assertEqual(len(UntracedMethodRegistry.method_dict), 0) - - def test_add_method(self): - pass diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py index 7250bee95..d8f53c03c 100644 --- a/tests/test_models/test_task_modules/test_graph_utils.py +++ b/tests/test_models/test_task_modules/test_graph_utils.py @@ -4,13 +4,21 @@ import torch import torch.nn as nn -from torch.ao.quantization import QConfigMapping -from torch.ao.quantization.fake_quantize import FakeQuantizeBase -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.quantize_fx import _fuse_fx -from mmrazor.models.task_modules import build_graphmodule -from mmrazor.models.task_modules.tracer import CustomTracer +try: + from torch.ao.quantization import QConfigMapping + from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + QConfigMapping = get_placeholder('torch>=1.13') + FakeQuantizeBase = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.task_modules.tracer import CustomTracer, build_graphmodule from mmrazor.models.task_modules.tracer.fx import ( del_fakequant_after_function, del_fakequant_after_method, del_fakequant_after_module, del_fakequant_after_op, @@ -106,6 +114,9 @@ def forward(self, x): class TestGraphUtils(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.tracer = CustomTracer() self.backend_config = BackendConfigs['native'] self.qconfig = QConfigHander(global_qconfig) @@ -114,6 +125,9 @@ def setUp(self): self.example_inputs = (torch.randn(1, 3, 224, 224), ) def swap_ff_with_fxff(self, model): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + modules_to_swap = [] for name, module in model.named_children(): if isinstance(module, torch.ao.nn.quantized.FloatFunctional): @@ -126,6 +140,9 @@ def swap_ff_with_fxff(self, model): model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() def test_del_fakequant_before_op(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -170,6 +187,9 @@ def test_del_fakequant_before_op(self): _get_attrs(prepared, args[0].target), FakeQuantizeBase) def test_del_fakequant_after_op(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -211,6 +231,8 @@ def test_del_fakequant_after_op(self): _get_attrs(prepared, node.next.target), FakeQuantizeBase) def test_del_fakequant_before_method(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') model_to_quantize = ToyModel() model_to_quantize.eval() @@ -259,6 +281,9 @@ def test_del_fakequant_before_method(self): _get_attrs(prepared, args[0].target), FakeQuantizeBase) def test_del_fakequant_after_method(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -303,6 +328,9 @@ def test_del_fakequant_after_method(self): _get_attrs(prepared, node.next.target), FakeQuantizeBase) def test_del_fakequant_before_function(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -356,6 +384,9 @@ def test_del_fakequant_before_function(self): _get_attrs(prepared, args[1].target), FakeQuantizeBase) def test_del_fakequant_after_function(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -400,6 +431,9 @@ def test_del_fakequant_after_function(self): _get_attrs(prepared, node.next.target), FakeQuantizeBase) def test_del_fakequant_before_module(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -452,6 +486,9 @@ def test_del_fakequant_before_module(self): _get_attrs(prepared, args[0].target), FakeQuantizeBase) def test_del_fakequant_after_module(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 009640684..c8340f352 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -12,6 +12,8 @@ from mmrazor.models.algorithms.base import BaseAlgorithm from mmrazor.models.mutables import OneShotMutableOP from mmrazor.registry import MODELS +from mmrazor.structures import load_fix_subnet +from mmrazor.utils import ValidFixMutable @MODELS.register_module() @@ -44,13 +46,15 @@ class MockAlgorithm(BaseAlgorithm): def __init__(self, architecture: Union[BaseModel, Dict], - _return_architecture_: Optional[bool] = None): + fix_subnet: Optional[ValidFixMutable] = None): super().__init__(architecture) - if _return_architecture_ is True: - self.return_model = self.architecture + if fix_subnet is not None: + # According to fix_subnet, delete the unchosen part of supernet + load_fix_subnet(self, fix_subnet, prefix='architecture.') + self.is_supernet = False else: - self.return_model = self + self.is_supernet = True class TestRegistry(TestCase): @@ -68,18 +72,34 @@ def test_build_razor_from_cfg(self): # model = MODELS.build(self.arch_cfg_path) # self.assertIsNotNone(model) - # test return architecture + # test fix subnet cfg = Config.fromfile( - 'tests/data/test_registry/registry_architecture_config.py') + 'tests/data/test_registry/registry_subnet_config.py') model = MODELS.build(cfg.model) - self.assertTrue(isinstance(model.return_model, MockModel)) - # test return model + # test return architecture cfg = Config.fromfile( 'tests/data/test_registry/registry_architecture_config.py') - cfg.model.pop('_return_architecture_') model = MODELS.build(cfg.model) - self.assertTrue(isinstance(model.return_model, MockAlgorithm)) + self.assertTrue(isinstance(model, BaseModel)) + + def test_build_subnet_prune_from_cfg(self): + mutator_cfg = fileio.load('tests/data/test_registry/subnet.json') + init_cfg = dict( + type='Pretrained', + checkpoint='tests/data/test_registry/subnet_weight.pth') + # test fix subnet + model_cfg = dict( + # use mmrazor's build_func + type='mmrazor.sub_model', + cfg=dict( + cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', + pretrained=False), + fix_subnet=mutator_cfg, + mode='mutator', + init_cfg=init_cfg) + model = MODELS.build(model_cfg) + self.assertTrue(isinstance(model, BaseModel)) def test_build_subnet_prune_from_cfg_by_mutator(self): mutator_cfg = fileio.load('tests/data/test_registry/subnet.json') diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py index 045b02c83..4730ab6cc 100644 --- a/tests/test_structures/test_qconfig.py +++ b/tests/test_structures/test_qconfig.py @@ -4,8 +4,14 @@ import torch from mmengine.config import Config -from torch.ao.quantization import QConfig +try: + from torch.ao.quantization import QConfig +except ImportError: + from mmrazor.utils import get_placeholder + QConfig = get_placeholder('torch>=1.13') + +from mmrazor import digit_version from mmrazor.models.fake_quants import register_torch_fake_quants from mmrazor.models.observers import register_torch_observers from mmrazor.structures import QConfigHander, QSchemeHander @@ -17,6 +23,9 @@ class TestQSchemeHander(TestCase): def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + # per_channel qscheme = QSchemeHander(is_symmetry=True, is_per_channel=True) assert qscheme.torch_qscheme is torch.per_channel_symmetric @@ -34,6 +43,9 @@ def test_init(self): assert qscheme.is_symmetric_range is True def test_to_observer_params(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + # qdtype = quint8 ret_params = QSchemeHander(qdtype='quint8').to_observer_params() assert ret_params['dtype'] == torch.quint8 @@ -78,6 +90,9 @@ def setUp(self): self.qconfig = Config(self.qconfig_dict) def test_check_qconfig(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + assert QConfigHander.check_qconfig(self.qconfig_dict) is True assert QConfigHander.check_qconfig(self.qconfig) is True qconfig_dict = copy.copy(self.qconfig_dict) @@ -86,6 +101,9 @@ def test_check_qconfig(self): assert QConfigHander.check_qconfig(qconfig_dict) is False def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + # test dict init qconfig = QConfigHander(self.qconfig_dict) assert hasattr(qconfig, 'w_qscheme') @@ -105,6 +123,9 @@ def test_init(self): assert qconfig.a_qscheme.is_per_channel is True def test_convert(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + qconfig = QConfigHander(self.qconfig) torch_qconfig = qconfig.convert() assert isinstance(torch_qconfig, QConfig)