From 454f39781d52d7018185c0c595f75c1d49055256 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Thu, 25 May 2023 16:50:09 +0800 Subject: [PATCH] [Feature] Merge dev-large into main (#543) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add sparse gpt (#499) init Co-authored-by: liukai * enhence sparsegpt (#505) * update * fix bug * fix bug * update opt * add memory efficient forward for opt * support to set device for pruning --------- Co-authored-by: liukai Co-authored-by: Your Name * Lk large (#510) * update * update --------- Co-authored-by: liukai * refine sparse gpt, support multiple gpus with fsdp (#520) * add mmrazor large * update readme * add fsdp for opt * update * update * rename * update args * support fsdp * refine * refine * refine * refine * fix out of memorry bug --------- Co-authored-by: liukai Co-authored-by: Your Name * refine sparse gpt (#526) * save cpu memory * update * update * update * update * refine * update * update --------- Co-authored-by: Your Name * merge main (#527) * fix bug for autoslim (#511) * fix bug for autoslim * delete resnet50 for dmcp --------- Co-authored-by: liukai * Add timm (#512) * add timm to optional.txt * fix deit paths * [Feature] Add MMRazor quantization (#513) * [FEATURE] add quant algo `Learned Step Size Quantization` (#346) * update * Fix a bug in make_divisible. (#333) fix bug in make_divisible Co-authored-by: liukai * [Fix] Fix counter mapping bug (#331) * fix counter mapping bug * move judgment into get_counter_type & update UT * [Docs]Add MMYOLO projects link (#334) * [Doc] fix typos in en/usr_guides (#299) * Update README.md * Update README_zh-CN.md Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> * [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320) * support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytest * updated * retina loss & predict & tesnor DONE * [Feature] Add deit-base (#332) * WIP: support deit * WIP: add deithead * WIP: fix checkpoint hook * fix data preprocessor * fix cfg * WIP: add readme * reset single_teacher_distill * add metafile * add model to model-index * fix configs and readme * [Feature]Feature map visualization (#293) * WIP: vis * WIP: add visualization * WIP: add visualization hook * WIP: support razor visualizer * WIP * WIP: wrap draw_featmap * support feature map visualization * add a demo image for visualization * fix typos * change eps to 1e-6 * add pytest for visualization * fix vis hook * fix arguments' name * fix img path * support draw inference results * add visualization doc * fix figure url * move files Co-authored-by: weihan cao * [Feature] Add kd examples (#305) * support kd for mbv2 and shufflenetv2 * WIP: fix ckpt path * WIP: fix kd r34-r18 * add metafile * fix metafile * delete * [Doc] add documents about pruning. (#313) * init * update user guide * update images * update * update How to prune your model * update how_to_use_config_tool_of_pruning.md * update doc * move location * update * update * update * add mutablechannels.md * add references Co-authored-by: liukai Co-authored-by: jacky * [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. 
(#304) * add pkd * add pytest for pkd * fix cfg * WIP: support fcos3d * WIP: support fcos3d pkd * support mmdet3d * fix cfgs * change eps to 1e-6 and add some comments * fix docstring * fix cfg * add assert * add type hint * WIP: add readme and metafile * fix readme * update metafiles and readme * fix metafile * fix pipeline figure * for RFC * Customed FX initialize * add UT init * [Refactor] Refactor Mutables and Mutators (#324) * refactor mutables * update load fix subnet * add DumpChosen Typehint * adapt UTs * fix lint * Add GroupMixin to ChannelMutator (temporarily) * fix type hints * add GroupMixin doc-string * modified by comments * fix type hits * update subnet format * fix channel group bugs and add UTs * fix doc string * fix comments * refactor diff module forward * fix error in channel mutator doc * fix comments Co-authored-by: liukai * [Fix] Update readme (#341) * update kl readme * update dsnas readme * fix url * Bump version to 1.0.0rc1 (#338) update version * init demo * add customer_tracer * add quantizer * add fake_quant, loop, config * remove CPatcher in custome_tracer * demo_try * init version * modified base.py * pre-rebase * wip of adaround series * adaround experiment * trasfer to s2 * update api * point at sub_reconstruction * pre-checkout * export onnx * add customtracer * fix lint * move custom tracer * fix import * TDO: UTs * Successfully RUN * update loop * update loop docstrings * update quantizer docstrings * update qscheme docstrings * update qobserver docstrings * update tracer docstrings * update UTs init * update UTs init * fix review comments * fix CI * fix UTs * update torch requirements Co-authored-by: huangpengsheng Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: humu789 * [Features]Quantize pipeline (#350) * init demo * add customer_tracer * add quantizer * add fake_quant, loop, config * remove CPatcher in custome_tracer * demo_try * init version * modified base.py * pre-rebase * wip of adaround series * adaround experiment * trasfer to s2 * update api * point at sub_reconstruction * pre-checkout * export onnx * add customtracer * fix lint * move custom tracer * fix import * update * updated * retina loss & predict & tesnor DONE * for RFC * Customed FX initialize * add UT init * TDO: UTs * Successfully RUN * update loop * update loop docstrings * update quantizer docstrings * update qscheme docstrings * update qobserver docstrings * update tracer docstrings * update UTs init * update UTs init * fix bugs * fix lsq * refactor quantize pipeline * fix quant * WIP: debug qat * fix lsq bugs * fix qat, docstring in progress * TDO: UTs * fix bugs * fix lsq * refactor quantize pipeline * fix quant * WIP: debug qat * fix lsq bugs * fix qat, docstring in progress * fixed DefaultQconfigs name * fix bugs * add comments and fix typos * delete useless codes * fix bugs and add comments * rename prepare_module_dict * update lsq config Co-authored-by: humu789 Co-authored-by: huangpengsheng Co-authored-by: FreakieHuang Co-authored-by: pppppM * [Feature] Add `prepare_for_mmdeploy` interface (#365) * remove useless code * fix build graph module import bug * refactor general quant * rename 
GeneralQuant to MMArchitectureQuant * fix some dtype bugs * add prepare_for_mmdeploy interface * update prepare for mmdeploy args * fix some comments Co-authored-by: humu789 * CodeCamp #132 add MinMaxFloorObserver (#376) * add minmaxfloor_observer.py * add MinMaxFloorObserver and normative docstring * add test for MinMaxFloorObserver * Quant go (#409) * add torch observer * add torch fakequant * refactor base quantizer * add QConfigHander and QSchemeHander & finish quantizer_refactor_beta * passed ptq_pipeline * tmp-commit * fix loop and algorithm * delete fakequant * refactor code structure * remove lsq * valid ptq pipeline * wip * fix del functions * fix * fix lint and pytest Co-authored-by: HIT-cwh <2892770585@qq.com> * [Refactor & Doc] Refactor graph_utils and add docstring and pytest (#420) * refactor graph_utils and add docstring and pytest * fix del fakequant * delete useless codes * Merge dev-1.x into quantize (#430) * Fix a bug in make_divisible. (#333) fix bug in make_divisible Co-authored-by: liukai * [Fix] Fix counter mapping bug (#331) * fix counter mapping bug * move judgment into get_counter_type & update UT * [Docs]Add MMYOLO projects link (#334) * [Doc] fix typos in en/usr_guides (#299) * Update README.md * Update README_zh-CN.md Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> * [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320) * support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytest * [Feature] Add deit-base (#332) * WIP: support deit * WIP: add deithead * WIP: fix checkpoint hook * fix data preprocessor * fix cfg * WIP: add readme * reset single_teacher_distill * add metafile * add model to model-index * fix configs and readme * [Feature]Feature map visualization (#293) * WIP: vis * WIP: add visualization * WIP: add visualization hook * WIP: support razor visualizer * WIP * WIP: wrap draw_featmap * support feature map visualization * add a demo image for visualization * fix typos * change eps to 1e-6 * add pytest for visualization * fix vis hook * fix arguments' name * fix img path * support draw inference results * add visualization doc * fix figure url * move files Co-authored-by: weihan cao * [Feature] Add kd examples (#305) * support kd for mbv2 and shufflenetv2 * WIP: fix ckpt path * WIP: fix kd r34-r18 * add metafile * fix metafile * delete * [Doc] add documents about pruning. (#313) * init * update user guide * update images * update * update How to prune your model * update how_to_use_config_tool_of_pruning.md * update doc * move location * update * update * update * add mutablechannels.md * add references Co-authored-by: liukai Co-authored-by: jacky * [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. 
(#304) * add pkd * add pytest for pkd * fix cfg * WIP: support fcos3d * WIP: support fcos3d pkd * support mmdet3d * fix cfgs * change eps to 1e-6 and add some comments * fix docstring * fix cfg * add assert * add type hint * WIP: add readme and metafile * fix readme * update metafiles and readme * fix metafile * fix pipeline figure * [Refactor] Refactor Mutables and Mutators (#324) * refactor mutables * update load fix subnet * add DumpChosen Typehint * adapt UTs * fix lint * Add GroupMixin to ChannelMutator (temporarily) * fix type hints * add GroupMixin doc-string * modified by comments * fix type hits * update subnet format * fix channel group bugs and add UTs * fix doc string * fix comments * refactor diff module forward * fix error in channel mutator doc * fix comments Co-authored-by: liukai * [Fix] Update readme (#341) * update kl readme * update dsnas readme * fix url * Bump version to 1.0.0rc1 (#338) update version * [Feature] Add Autoformer algorithm (#315) * update candidates * update subnet_sampler_loop * update candidate * add readme * rename variable * rename variable * clean * update * add doc string * Revert "[Improvement] Support for candidate multiple dimensional search constraints." * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * [Feature] Autoformer architecture and dynamicOPs (#327) * add DynamicSequential * dynamiclayernorm * add dynamic_pathchembed * add DynamicMultiheadAttention and DynamicRelativePosition2D * add channel-level dynamicOP * add autoformer algo * clean notes * adapt channel_mutator * vit fly * fix import * mutable init * remove annotation * add DynamicInputResizer * add unittest for mutables * add OneShotMutableChannelUnit_VIT * clean code * reset unit for vit * remove attr * add autoformer backbone UT * add valuemutator UT * clean code * add autoformer algo UT * update classifier UT * fix test error * ignore * make lint * update * fix lint * mutable_attrs * fix test * fix error * remove DynamicInputResizer * fix test ci * remove InputResizer * rename variables * modify type * Continued improvements of ChannelUnit * fix lint * fix lint * remove OneShotMutableChannelUnit * adjust derived type * combination mixins * clean code * fix sample subnet * search loop fly * more annotations * avoid counter warning and modify batch_augment cfg by gy * restore * source_value_mutables restriction * simply arch_setting api * update * clean * fix ut * [Feature] Add performance predictor (#306) * add predictor with 4 handlers * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * update metric_predictor: 1. update MetricPredictor; 2. add predictor config for searching; 3. add predictor in evolution_search_loop. 
* add UT for predictor * add MLPHandler * patch optional.txt for predictors * patch test_evolution_search_loop * refactor apis of predictor and handlers * fix ut and remove predictor_cfg in predictor * adapt new mutable & mutator design * fix ut * remove unness assert after rebase * move predictor-build in __init__ & simplify estimator-build Co-authored-by: Yue Sun * [Feature] Add DCFF (#295) * add ChannelGroup (#250) * rebase new dev-1.x * modification for adding config_template * add docstring to channel_group.py * add docstring to mutable_channel_group.py * rm channel_group_cfg from Graph2ChannelGroups * change choice type of SequentialChannelGroup from float to int * add a warning about group-wise conv * restore __init__ of dynamic op * in_channel_mutable -> mutable_in_channel * rm abstractproperty * add a comment about VT * rm registry for ChannelGroup * MUTABLECHANNELGROUP -> ChannelGroupType * refine docstring of IndexDict * update docstring * update docstring * is_prunable -> is_mutable * update docstring * fix error in pre-commit * update unittest * add return type * unify init_xxx apit * add unitest about init of MutableChannelGroup * update according to reviews * sequential_channel_group -> sequential_mutable_channel_group Co-authored-by: liukai * Add BaseChannelMutator and refactor Autoslim (#289) * add BaseChannelMutator * add autoslim * tmp * make SequentialMutableChannelGroup accpeted both of num and ratio as choice. and supports divisior * update OneShotMutableChannelGroup * pass supernet training of autoslim * refine autoslim * fix bug in OneShotMutableChannelGroup * refactor make_divisible * fix spell error: channl -> channel * init_using_backward_tracer -> init_from_backward_tracer init_from_fx_tracer -> init_from_fx_tracer * refine SequentialMutableChannelGroup * let mutator support models with dynamicop * support define search space in model * tracer_cfg -> parse_cfg * refine * using -> from * update docstring * update docstring Co-authored-by: liukai * tmpsave * migrate ut * tmpsave2 * add loss collector * refactor slimmable and add l1-norm (#291) * refactor slimmable and add l1-norm * make l1-norm support convnd * update get_channel_groups * add l1-norm_resnet34_8xb32_in1k.py * add pretrained to resnet34-l1 * remove old channel mutator * BaseChannelMutator -> ChannelMutator * update according to reviews * add readme to l1-norm * MBV2_slimmable -> MBV2_slimmable_config Co-authored-by: liukai * update config * fix md & pytorch support <1.9.0 in batchnorm init * Clean old codes. 
(#296) * remove old dynamic ops * move dynamic ops * clean old mutable_channels * rm OneShotMutableChannel * rm MutableChannel * refine * refine * use SquentialMutableChannel to replace OneshotMutableChannel * refactor dynamicops folder * let SquentialMutableChannel support float Co-authored-by: liukai * fix ci * ci fix py3.6.x & add mmpose * ci fix py3.6.9 in utils/index_dict.py * fix mmpose * minimum_version_cpu=3.7 * fix ci 3.7.13 * fix pruning &meta ci * support python3.6.9 * fix py3.6 import caused by circular import patch in py3.7 * fix py3.6.9 * Add channel-flow (#301) * base_channel_mutator -> channel_mutator * init * update docstring * allow omitting redundant configs for channel * add register_mutable_channel_to_a_module to MutableChannelContainer * update according to reviews 1 * update according to reviews 2 * update according to reviews 3 * remove old docstring * fix error * using->from * update according to reviews * support self-define input channel number * update docstring * chanenl -> channel_elem Co-authored-by: liukai Co-authored-by: jacky * support >=3.7 * support py3.6.9 * Rename: ChannelGroup -> ChannelUnit (#302) * refine repr of MutableChannelGroup * rename folder name * ChannelGroup -> ChannelUnit * filename in units folder * channel_group -> channel_unit * groups -> units * group -> unit * update * get_mutable_channel_groups -> get_mutable_channel_units * fix bug * refine docstring * fix ci * fix bug in tracer Co-authored-by: liukai * update new channel config format * update pruning refactor * update merged pruning * update commit * fix dynamic_conv_mixin * update comments: readme&dynamic_conv_mixins.py * update readme * move kl softmax channel pooling to op by comments * fix comments: fix redundant & split README.md * dcff in ItePruneAlgorithm * partial dynamic params for fuseconv * add step_freq & prune_time check * update comments * update comments * update comments * fix ut * fix gpu ut & revise step_freq in ItePruneAlgorithm * update readme * revise ItePruneAlgorithm * fix docs * fix dynamic_conv attr * fix ci Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: zengyi.vendor Co-authored-by: jacky * [Fix] Fix optional requirements (#357) * fix optional requirements * fix dcff ut * fix import with get_placeholder * supplement the previous commit * [Fix] Fix configs of wrn models and ofd. (#361) * 1.revise the configs of wrn22, wrn24, and wrn40. 2.revise the data_preprocessor of ofd_backbone_resnet50_resnet18_8xb16_cifar10 * 1.Add README for vanilla-wrm. * 1.Revise readme of wrn Co-authored-by: zhangzhongyu * [Fix] Fix bug on mmrazor visualization, mismatch argument in define and use. (#356) fix bug on mmrazor visualization, mismatch argument in define and use. Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> * fix bug in benchmark_test (#364) fix bug in configs Co-authored-by: Your Name * [FIX] Fix wrn configs (#368) * fix wrn configs * fix wrn configs * update online wrn model weight * [Fix] fix bug on pkd config. Wrong import filename. 
(#373) * [CI] Update ci to torch1.13 (#380) update ci to torch1.13 * [Feature] Add BigNAS algorithm (#219) * add calibrate-bn-statistics * add test calibrate-bn-statistics * fix mixins * fix mixins * fix mixin tests * remove slimmable channel mutable and refactor dynamic op * refact dynamic batch norm * add progressive dynamic conv2d * add center crop dynamic conv2d * refactor dynamic directory * refactor dynamic sequential * rename length to depth in dynamic sequential * add test for derived mutable * refactor dynamic op * refactor api of dynamic op * add derive mutable mixin * addbignas algorithm * refactor bignas structure * add input resizer * add input resizer to bignas * move input resizer from algorithm into classifier * remove compnents * add attentive mobilenet * delete json file * nearly(less 0.2) align inference accuracy with gml * move mutate seperated in bignas mobilenet backbone * add zero_init_residual * add set_dropout * set dropout in bignas algorithm * fix registry * add subnet yaml and nearly align inference accuracy with gml * add rsb config for bignas * remove base in config * add gml bignas config * convert to iter based * bignas forward and backward fly * fix merge conflict * fix dynamicseq bug * fix bug and refactor bignas * arrange configs of bignas * fix typo * refactor attentive_mobilenet * fix channel mismatch due to registion of DerivedMutable * update bignas & fix se channel mismatch * add AutoAugmentV2 & remove unness configs * fix lint * recover channel assertion in channel unit * fix a group bug * fix comments * add docstring * add norm in dynamic_embed * fix search loop & other minor changes * fix se expansion * minor change * add ut for bignas & attentive_mobilenet * fix ut * update bignas readme * rm unness ut & supplement get_placeholder * fix lint * fix ut * add subnet deployment in downstream tasks. * minor change * update ofa backbone * minor fix * Continued improvements of searchable backbone * minor change * drop ratio in backbone * fix comments * fix ci test * fix test * add dynamic shortcut UT * modify strategy to fit bignas * fix test * fix bug in neck * fix error * fix error * fix yaml * save subnet ckpt * merge autoslim_val/test_loop into subnet_val_loop * move calibrate_bn_mixin to utils * fix bugs and add docstring * clean code * fix register bug * clean code * update Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny Co-authored-by: sunyue1 * [Bug] Fix ckpt (#372) fix ckpt * [Feature] Add tools to convert distill ckpt to student-only ckpt. (#381) * [Feature] Add tools to convert distill ckpt to student-only ckpt. * fix bug. * add --model-only to only save model. * Make changes accroding to PR review. * Enhance the Abilities of the Tracer for Pruning. (#371) * tmp * add new mmdet models * add docstring * pass test and pre-commit * rm razor tracer * update fx tracer, now it can automatically wrap methods and functions. 
* update tracer passed models * add warning for torch <1.12.0 fix bug for python3.6 update placeholder to support placeholder.XXX * fix bug * update docs * fix lint * fix parse_cfg in configs * restore mutablechannel * test ite prune algorithm when using dist * add get_model_from_path to MMModelLibrrary * add mm models to DefaultModelLibrary * add uts * fix bug * fix bug * add uts * add uts * add uts * add uts * fix bug * restore ite_prune_algorithm * update doc * PruneTracer -> ChannelAnalyzer * prune_tracer -> channel_analyzer * add test for fxtracer * fix bug * fix bug * PruneTracer -> ChannelAnalyzer refine * CustomFxTracer -> MMFxTracer * fix bug when test with torch<1.12 * update print log * fix lint * rm unuseful code Co-authored-by: liukai Co-authored-by: jacky Co-authored-by: Your Name Co-authored-by: liukai * fix bug in placer holder (#395) * fix bug in placer holder * remove redundent comment Co-authored-by: liukai * Add get_prune_config and a demo config_pruning (#389) * update tools and test * add demo * disable test doc * add switch for test tools and test_doc * fix bug * update doc * update tools name * mv get_channel_units Co-authored-by: liukai * [Improvement] Adapt OFA series with SearchableMobileNetV3 (#385) * fix mutable bug in AttentiveMobileNetV3 * remove unness code * update ATTENTIVE_SUBNET_A0-A6.yaml with optimized names * unify the sampling usage in sandwich_rule-based NAS * use alias to export subnet * update OFA configs * fix attr bug * fix comments * update convert_supernet2subnet.py * correct the way to dump DerivedMutable * fix convert index bug * update OFA configs & models * fix dynamic2static * generalize convert_ofa_ckpt.py * update input_resizer * update README.md * fix ut * update export_fix_subnet * update _dynamic_to_static * update fix_subnet UT & minor fix bugs * fix ut * add new autoaug compared to attentivenas * clean * fix act * fix act_cfg * update fix_subnet * fix lint * add docstring Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny * [Fix]Dcff Deploy Revision (#383) * dcff deploy revision * tempsave * update fix_subnet * update mutator load * export/load_fix_subnet revision for mutator * update fix_subnet with dev-1.x * update comments * update docs * update registry * [Fix] Fix commands in README to adapt branch 1.x (#400) * update commands in README for 1.x * fix commands Co-authored-by: gaoyang07 <1546308416@qq.com> * Set requires_grad to False if the teacher is not trainable (#398) * add choice and mask of units to checkpoint (#397) * add choice and mask of units to checkpoint * update * fix bug * remove device operation * fix bug * fix circle ci error * fix error in numpy for circle ci * fix bug in requirements * restore * add a note * a new solution * save mutable_channel.mask as float for dist training * refine * mv meta file test Co-authored-by: liukai Co-authored-by: jacky * [Bug]Fix fpn teacher distill (#388) fix fpn distill * [CodeCamp #122] Support KD algorithm MGD for detection. (#377) * [Feature] Support KD algorithm MGD for detection. * use connector to beauty mgd. * fix typo, add unitest. * fix mgd loss unitest. * fix mgd connector unitest. * add model pth and log file. * add mAP. 
* update l1 config (#405) * add l1 config * update l1 config Co-authored-by: jacky * [Feature] Add greedy search for AutoSlim (#336) * WIP: add greedysearch * fix greedy search and add bn_training_mode to autoslim * fix cfg files * fix autoslim configs * fix bugs when converting dynamic bn to static bn * change to test loop * refactor greedy search * rebase and fix greedysearch * fix lint * fix and delete useless codes * fix pytest * fix pytest and add bn_training_mode * fix lint * add reference to AutoSlimGreedySearchLoop's docstring * sort candidate_choices * fix save subnet * delete useless codes in channel container * change files' name: convert greedy_search_loop to autoslim_greedy_search_loop * [Fix] Fix metafile (#422) * fix ckpt path in metafile and readme * fix darts file path * fix docstring in ConfigurableDistiller * fix darts * fix error * add darts of mmrazor version * delete py36 Co-authored-by: liukai * update bignas cfg (#412) * check attentivenas training * update ckpt link * update supernet log Co-authored-by: aptsunny * Bump version to 1.0.0rc2 (#423) bump version to 1.0.0rc2 Co-authored-by: liukai * fix lint * fix ci * add tmp docstring for passed ci * add tmp docstring for passed ci * fix ci * add get_placeholder for quant * add skip for unittest * fix package placeholder bug * add version judgement in __init__ * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: Yue Sun Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com> Co-authored-by: zengyi.vendor Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com> Co-authored-by: zhangzhongyu Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com> Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> Co-authored-by: Your Name Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: sunyue1 Co-authored-by: liukai Co-authored-by: Ming-Hsuan-Tu Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com> Co-authored-by: Yue Sun * [Docs] Add docstring and unittest about backendconfig & observer & fakequant (#428) * add ut about backendconfig * add ut about observers and fakequants in torch * fix torch1.13 ci * [Docs] Add docstring for `MMArchitectureQuant` & `NativeQuantizer` (#425) * add docstring on mm_architecture& native_quantizer * add naive openvino r18 qat config & dist_ptq.sh * Added a more accurate description * unitest&doc * checkpoint url * unitest * passed_pre_commit * unitest on native_quantizer& fix bugs * remove dist_ptq * add get_placeholder&skipTest * complete arg descriptions * fix import bugs * fix pre-commit * add get_placeholder * add typehint and doctring * update docstring&typehint * update docstring * pre-commit * fix some problems * fix bug * [Docs] Add docstring and unitest about custom 
tracer (#427) * rename QConfigHandler and QSchemeHandler * add docstring about custom tracer * add ut about custom tracer * fix torch1.13 ci * fix lint * fix ci * fix ci * [Docs & Refactor] Add docstring and UT of other quantizers (#439) * add quantizer docstring and refactor the interface of AcademicQuantizer * add AcademicQuantizer unittest * add TensorRTQuantizer and OpenVINOQuantizer unittest & refactor prepare interface * adapt torch113 ci * fix import * fix lint * update some docstring * fix ci * [Feature&Doc]Modify ptq pipeline and support lsq (#435) * modify ptq pipeline and support lsq * use placeholder * fix lsq && quantloop * add lsq pytest * add quant loop pytest * test lsq observer * fix bug under pt13 * fix reset_min_max_vals * fix bugs under pt13 * fix configs * add get_qconfig_mapping * delete is_qat, add doc and fix pytest * delete useless codes in custom_tracer * skip pytest under pt13 * add todo: check freezebn * fix pytest bugs * fix pytest * fix pytest * fix pytest * [Docs] Add customize_quantization_tutorial (#440) * [Docs] Add quantization user guide (#441) * add quantization user guide * fix layout * fix layout * update README * [Bug] Fix del redundant fakequant (#447) fix del redundant fakequant * [Feature] Add onnx exporters (#475) * fix del redundant fakequant * add onnx exporters * fix onnx exporters and add docstring * fix comments * delete useless codes * fix export_onnx in native quantizer --------- Co-authored-by: pppppM * [Feature]Rewrite the origin model during prepare (#488) * add rewriter * add deploy_cfg arg * modify post_process_for_mmdeploy * fix bugs * add det config * [Feature] Using rewriter in mmrazor when building qmodels. (#490) * add rewriter * add deploy_cfg arg * modify post_process_for_mmdeploy * fix bugs * add det config * replace deepcopy * pop detectors' forward * [Feature] Quantization global optimization (#491) * add trtquantizer * unify all fakequant before deploy * move to aide * add yolox config * pre-rebase * add unittest * add a arg of post_process_for_deploy * test trt yolox deploy * opt quantizer interface * fix rebase * add trt r50 config * update trt setting * del redundant code * fix lint * fix ut of quantizers * del redundant file * fix lint * fix some comments * Fix code syntax in UT (#470) Co-authored-by: 王盟 * passed lint and pytest * try to fix ci * [Bug] Try to fix CI (#502) fix lint * [Feature] Support lsq (#501) * support deploy_cfg=None * replace fakequant before load ckpt * add _load_from_state_dict to lsq fakequant * fix pre-commit * test lsq load state dict * change github ci: ubuntu 18.04 to ubuntu 20.04 * get_deploy_model order change back * sync before save ckpt * delete strict=False * test context rewriter * fix pre commit config * try to fix ci * [Bug] Try to fix CI (#502) fix lint --------- Co-authored-by: humu789 Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com> * [Feature] Add exporter pytest (#504) * add exporter pytest * fix bugs * delete useless codes * handle onnx * delete useless codes * [Bug] Fix ci converage setting (#508) fix ci converage * [Bug] Fix codecov (#509) * remove codecov in requirements * try to fix ci * del adaround loss * [BUG] Fix quantization loop (#507) * fix quantization loop * fix quant loop * fix quant loop * fix qat configs * [Bug] Fix ci converage setting (#508) fix ci converage * [Bug] Fix codecov (#509) * remove codecov in requirements * try to fix ci * del adaround loss * add freeze_bn_begin to lsq * delete useless codes --------- Co-authored-by: humu789 
<88702197+humu789@users.noreply.github.com> * add test ptq * opt ptq pipeline * refactor quant configs * update config path * add summary analyse tool * fix benchmark_test:detnas_frcnn_shufflenet_subnet_coco_1x.py * update quantization README.md * update quantization metafile, readme, config path * update quantization docs * update git main link in workflow * update benchmark_summary_analyse.py * del dmcp results * [Bug] fix a rebase error (#514) fix a rebase error * [Bug] Fix CI (#515) * fix ci * mmcv2.0 need torch1.8+ * Update CI config and Passed (#516) * test ci * update test.yml based on mmcv2.0.0 * [Docs] Fix cwd test accuary (#517) * test ci * update test.yml based on mmcv2.0.0 * update cwd_logits_pspnet result --------- Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: huangpengsheng Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: FreakieHuang Co-authored-by: pppppM Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com> Co-authored-by: HIT-cwh <2892770585@qq.com> Co-authored-by: Yue Sun Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com> Co-authored-by: zengyi.vendor Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com> Co-authored-by: zhangzhongyu Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com> Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> Co-authored-by: Your Name Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: sunyue1 Co-authored-by: liukai Co-authored-by: Ming-Hsuan-Tu Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com> Co-authored-by: Yue Sun Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com> Co-authored-by: wm901115nwpu Co-authored-by: 王盟 * [Docs&Feature] Prepare for checkouting default branch and releasing new version (#518) * prepare for checkout default branch * update README.md and model zoo * update installation.md and update dev-1.x links * update README_zh-CN * add changelog * update ci config * update some links in quantization readme * update quantization user guide * update calibrate_dataloader * add interface pop_rewriter_function_record * Bump version to 1.0.0 (#521) * update release time * bump version to 1.0.0 * [CI] Fix merge stage test (#523) fix merge_stage_test in ci --------- Co-authored-by: liukai Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com> Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: huangpengsheng Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: FreakieHuang Co-authored-by: pppppM Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com> Co-authored-by: HIT-cwh <2892770585@qq.com> 
Co-authored-by: Yue Sun Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com> Co-authored-by: zengyi.vendor Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com> Co-authored-by: zhangzhongyu Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com> Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> Co-authored-by: Your Name Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: sunyue1 Co-authored-by: Ming-Hsuan-Tu Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com> Co-authored-by: Yue Sun Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com> Co-authored-by: wm901115nwpu Co-authored-by: 王盟 * move folders and update readme (#528) * move folders * update readme --------- Co-authored-by: liukai * [Bug] Fix torch2 error (#536) fix torch2 error * [Feature] Add GPTQ and uniform interfaces (#538) * add gptq implementation * pre-checkout * passed resnet example * passed llama example * aglin gptq acc * add activation quantization * uniform interfaces * add gptq readme * update mmrazor_large redame * add gptq opt example * fix sparse_gpt example for opt * fix import Protocol from py37 * fix error function name * fix bug in test * fix bug * fix bug * limit sparsegpt test with torch>=1.12 * add docstring for gptq and sparse_gpt * pre-commit * align acc & add save load ckpt & add ut * fix ut * fix ut * fix ut * fix ut & add torch2.0 for ci * del torch2.0 for ci * fix ut --------- Co-authored-by: FIRST_NAME LAST_NAME --------- Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: Your Name Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: huangpengsheng Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: FreakieHuang Co-authored-by: pppppM Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com> Co-authored-by: HIT-cwh <2892770585@qq.com> Co-authored-by: Yue Sun Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com> Co-authored-by: zengyi.vendor Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com> Co-authored-by: zhangzhongyu Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com> Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: sunyue1 Co-authored-by: Ming-Hsuan-Tu Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com> Co-authored-by: Yue Sun Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com> Co-authored-by: wm901115nwpu Co-authored-by: 王盟 Co-authored-by: FIRST_NAME LAST_NAME --- .pre-commit-config.yaml | 1 + README.md | 2 + mmrazor/implementations/pruning/__init__.py | 4 +- .../pruning/sparse_gpt/__init__.py | 9 + .../pruning/sparse_gpt/compressor.py | 106 ++++ .../implementations/pruning/sparse_gpt/ops.py | 278 +++++++++ .../pruning/sparse_gpt/sparse24_utils.py | 10 + 
.../pruning/sparse_gpt/utils.py | 140 +++++ .../quantization/gptq/__init__.py | 14 + .../quantization/gptq/compressor.py | 146 +++++ .../quantization/gptq/custom_autotune.py | 254 ++++++++ .../implementations/quantization/gptq/gptq.py | 318 ++++++++++ .../implementations/quantization/gptq/ops.py | 566 ++++++++++++++++++ .../quantization/gptq/quantizer.py | 144 +++++ .../quantization/gptq/utils.py | 56 ++ .../common_operator_config_utils.py | 4 +- .../quantization/backend_config/mapping.py | 4 +- mmrazor/utils/log_tools.py | 12 +- projects/mmrazor_large/README.md | 42 ++ projects/mmrazor_large/algorithms/GPTQ.md | 56 ++ .../mmrazor_large/algorithms/SparseGPT.md | 55 ++ .../mmrazor_large/examples/ResNet/README.md | 25 + .../examples/ResNet/resnet18_gptq.py | 187 ++++++ .../examples/ResNet/resnet18_sparse_gpt.py | 137 +++++ .../examples/language_models/LLaMA/README.md | 55 ++ .../language_models/LLaMA/datautils.py | 152 +++++ .../language_models/LLaMA/llama_gptq.py | 162 +++++ .../language_models/LLaMA/llama_sparse_gpt.py | 106 ++++ .../LLaMA/llama_sparse_gpt_fsdp.py | 198 ++++++ .../examples/language_models/LLaMA/utils.py | 173 ++++++ .../examples/language_models/OPT/README.md | 55 ++ .../examples/language_models/OPT/datautils.py | 152 +++++ .../examples/language_models/OPT/opt_gptq.py | 157 +++++ .../language_models/OPT/opt_sparse_gpt.py | 105 ++++ .../OPT/opt_sparse_gpt_fsdp.py | 198 ++++++ .../examples/language_models/OPT/utils.py | 171 ++++++ requirements/tests.txt | 1 + .../test_pruning/test_sparse_gpt/test_op.py | 75 +++ .../test_gptq/test_op_gptq.py | 80 +++ 39 files changed, 4404 insertions(+), 6 deletions(-) create mode 100644 mmrazor/implementations/pruning/sparse_gpt/__init__.py create mode 100644 mmrazor/implementations/pruning/sparse_gpt/compressor.py create mode 100644 mmrazor/implementations/pruning/sparse_gpt/ops.py create mode 100644 mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py create mode 100644 mmrazor/implementations/pruning/sparse_gpt/utils.py create mode 100644 mmrazor/implementations/quantization/gptq/__init__.py create mode 100644 mmrazor/implementations/quantization/gptq/compressor.py create mode 100644 mmrazor/implementations/quantization/gptq/custom_autotune.py create mode 100644 mmrazor/implementations/quantization/gptq/gptq.py create mode 100644 mmrazor/implementations/quantization/gptq/ops.py create mode 100644 mmrazor/implementations/quantization/gptq/quantizer.py create mode 100644 mmrazor/implementations/quantization/gptq/utils.py create mode 100644 projects/mmrazor_large/README.md create mode 100644 projects/mmrazor_large/algorithms/GPTQ.md create mode 100644 projects/mmrazor_large/algorithms/SparseGPT.md create mode 100644 projects/mmrazor_large/examples/ResNet/README.md create mode 100644 projects/mmrazor_large/examples/ResNet/resnet18_gptq.py create mode 100644 projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py create mode 100644 projects/mmrazor_large/examples/language_models/LLaMA/README.md create mode 100755 projects/mmrazor_large/examples/language_models/LLaMA/datautils.py create mode 100644 projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py create mode 100644 projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py create mode 100644 projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt_fsdp.py create mode 100644 projects/mmrazor_large/examples/language_models/LLaMA/utils.py create mode 100644 projects/mmrazor_large/examples/language_models/OPT/README.md create mode 100755 
projects/mmrazor_large/examples/language_models/OPT/datautils.py create mode 100644 projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py create mode 100644 projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py create mode 100644 projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt_fsdp.py create mode 100644 projects/mmrazor_large/examples/language_models/OPT/utils.py create mode 100644 tests/test_impl/test_pruning/test_sparse_gpt/test_op.py create mode 100644 tests/test_impl/test_quantization/test_gptq/test_op_gptq.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cd73ef928..0454da4ce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -69,4 +69,5 @@ repos: | ^docs | ^configs | ^.*/configs* + | ^projects ) diff --git a/README.md b/README.md index 4dbb364d5..ad92732cf 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,8 @@ English | [简体中文](README_zh-CN.md) +**:star: MMRazor for Large Models** is Available Now! Please refer to [MMRazorLarge](projects/mmrazor_large/README.md) + ## Introduction MMRazor is a model compression toolkit for model slimming and AutoML, which includes 4 mainstream technologies: diff --git a/mmrazor/implementations/pruning/__init__.py b/mmrazor/implementations/pruning/__init__.py index e28ae7dc2..d536adf1f 100644 --- a/mmrazor/implementations/pruning/__init__.py +++ b/mmrazor/implementations/pruning/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import group_fisher +from . import group_fisher, sparse_gpt -__all__ = ['group_fisher'] +__all__ = ['group_fisher', 'sparse_gpt'] diff --git a/mmrazor/implementations/pruning/sparse_gpt/__init__.py b/mmrazor/implementations/pruning/sparse_gpt/__init__.py new file mode 100644 index 000000000..8caae7fd0 --- /dev/null +++ b/mmrazor/implementations/pruning/sparse_gpt/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .compressor import SparseGptCompressor +from .ops import SparseGptLinear, SparseGptMixIn +from .utils import replace_with_dynamic_ops + +__all__ = [ + 'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops', + 'SparseGptCompressor' +] diff --git a/mmrazor/implementations/pruning/sparse_gpt/compressor.py b/mmrazor/implementations/pruning/sparse_gpt/compressor.py new file mode 100644 index 000000000..f5ef42ec6 --- /dev/null +++ b/mmrazor/implementations/pruning/sparse_gpt/compressor.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import torch.nn as nn + +from mmrazor.utils import print_log +from .ops import SparseGptConv2d, SparseGptLinear, SparseGptMixIn +from .utils import replace_with_dynamic_ops + + +def to_static_model(model: nn.Module): + """Replace dynamicops with torch modules.""" + from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet, + load_fix_subnet) + fix_subnet = export_fix_subnet(model)[0] + load_fix_subnet(model, fix_subnet) + return model + + +class SparseGptCompressor(): + """The compressor with SparseGPT.""" + + def __init__(self) -> None: + self.model: nn.Module = None + + def prepare(self, + model: nn.Module, + prune_conv=True, + prune_linear=True) -> None: + """Prepare for compressing model.""" + self.model = model + prune_modules: dict = {} + if prune_conv: + prune_modules[nn.Conv2d] = SparseGptConv2d + if prune_linear: + prune_modules[nn.Linear] = SparseGptLinear + replace_with_dynamic_ops(model, prune_modules) + + @classmethod + def to_static_model(cls, model): + """Convert replaced op with the original torch model.""" + return to_static_model(model) + + # hessian + + def register_hessian_hooks(self): + """Register updating hessian hooks for specified ops.""" + for module in self.sparse_ops: + module.register_hessian_hook() + + def remove_hessian_hooks(self): + """Remove updating hessian hooks for specified ops.""" + for module in self.sparse_ops: + module.remove_hessian_hook() + + def init_hessian(self, device=None): + """Init hessian.""" + for op in self.sparse_ops: + op.init_hessian(device=device) + + # prune + def prune(self, + sparsity, + prunen=0, + prunem=0, + blocksize=128, + percdamp=.01, + device=torch.device('cuda')): + """Apply the compression algorithm to the model.""" + for name, module in self.named_sparse_ops: + try: + original_device = next(module.parameters()).device + module: SparseGptMixIn = module.to(device) + error = module.prune( + sparsity=sparsity, + prunen=prunen, + prunem=prunem, + blocksize=blocksize, + percdamp=percdamp, + ) + print_log(f'prune {name} success \t error = {error}') + module.to(original_device) + torch.cuda.empty_cache() + except Exception as e: + print_log(f'prune {name} failed as {e}') + + def prune_24(self, device=torch.device('cuda:0')): + """Apply the compression algorithm to the model with the specified + setting.""" + self.prune(0.5, prunen=2, prunem=4, device=device) + + # ops + + @property + def sparse_ops(self): + """The ops to be applied the algorithm.""" + assert self.model is not None + for module in self.model.modules(): + if isinstance(module, SparseGptMixIn): + yield module + + @property + def named_sparse_ops(self): + """The named ops to be applied the algorithm.""" + for name, module in self.model.named_modules(): + if isinstance(module, SparseGptMixIn): + yield name, module diff --git a/mmrazor/implementations/pruning/sparse_gpt/ops.py b/mmrazor/implementations/pruning/sparse_gpt/ops.py new file mode 100644 index 000000000..0f11b176f --- /dev/null +++ b/mmrazor/implementations/pruning/sparse_gpt/ops.py @@ -0,0 +1,278 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
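For reference, a minimal sketch of how the SparseGptCompressor introduced in compressor.py above appears intended to be driven: prepare the model, collect Hessian statistics over a few calibration batches, prune, then convert back to a static model. The torchvision ResNet-18 and the random calibration batches below are placeholders for illustration only, not part of this patch.

import torch
import torchvision

from mmrazor.implementations.pruning.sparse_gpt import SparseGptCompressor

# Placeholder model and calibration data; any model and dataloader would do.
model = torchvision.models.resnet18()

compressor = SparseGptCompressor()
compressor.prepare(model)  # replace nn.Conv2d / nn.Linear with SparseGpt dynamic ops

compressor.init_hessian()  # allocate per-layer Hessian buffers
compressor.register_hessian_hooks()  # accumulate statistics during forward passes
with torch.no_grad():
    for _ in range(4):
        model(torch.rand(8, 3, 224, 224))
compressor.remove_hessian_hooks()

# 2:4 structured sparsity (sparsity=0.5, prunen=2, prunem=4);
# prune_24 moves each layer to 'cuda:0' by default while pruning.
compressor.prune_24()

model = compressor.to_static_model(model)  # fold dynamic ops back to plain torch modules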
+import sys + +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d, + DynamicLinear) +from .utils import ModuleProtocol, torch_setting + + +class SparseGptMixIn(ModuleProtocol): + """The core algorithm implementation for SparseGpt.""" + + def _sparse_gpt_mix_in_init(self): + """Init mixin.""" + self.sparse_gpt_handles = [] + self.rows = self.weight_matrix.shape[0] + self.columns = self.weight_matrix.shape[1] + + self._hessian: torch.Tensor = None + self.hessian_batch = 0 + + # weight and input adaptive + + @property + def weight_matrix(self): + """Return weight with shape (out in)""" + return self.weight.flatten(1) # out in + + @weight_matrix.setter + def weight_matrix(self, value: torch.Tensor): + """Set weight.""" + with torch.no_grad(): + value = value.reshape(self.weight.shape).to(self.weight.device).to( + self.weight.dtype) + self.weight.data.copy_(value) + + def format_input(self, input: torch.Tensor): + """Return input with shape (B N C)""" + if len(input.shape) == 2: # N C + input = input.unsqueeze(0) # 1 N C + return input + + # compute hessian + + @property + def hessian(self): + """hessian always return float.""" + if dist.is_initialized(): + if dist.get_rank() == 0: + assert self._hessian is not None, 'hessian is not initialized.' + hessian = self._hessian.to(self.weight_matrix.device) + else: + hessian = torch.zeros( + self.columns, + self.columns, + device=self.weight_matrix.device) + dist.broadcast(hessian, 0) + return hessian + else: + return self._hessian + + @hessian.setter + def hessian(self, value: torch.Tensor): + """Set hessian.""" + with torch.no_grad(): + if dist.is_initialized(): + if dist.get_rank() == 0: + assert self._hessian is not None, 'hessian is not initialized.' 
# noqa + self._hessian.data.copy_( + value.data.to(self._hessian.device)) + else: + self._hessian = None + else: + self._hessian.data.copy_(value.data.to(self._hessian.device)) + + @torch.no_grad() + def update_hessian(self, input: torch.Tensor): + """Update hessian.""" + input = self.format_input(input).float() + H_save = self.hessian + H_save = H_save.to(input.device) + + assert len(input.shape) == 3 + B = input.shape[0] # B N C + input = input.transpose(0, -1).flatten(1) # C D + + H = input @ input.T * 2 # C C + + if dist.is_initialized(): + dist.all_reduce(H) + B *= dist.get_world_size() + H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B) + self.hessian = H_save + self.hessian_batch = self.hessian_batch + B + + def register_hessian_hook(self): + """Register updating hessian hook.""" + + @torch.no_grad() + def forward_pre_hook(module: Protocol, input: tuple): + assert len(input) == 1 + self.update_hessian(input[0]) + + handle = self.register_forward_pre_hook(forward_pre_hook) + self.sparse_gpt_handles.append(handle) + + def remove_hessian_hook(self): + """Remove updating hessian hook.""" + for h in self.sparse_gpt_handles: + h.remove() + + def init_hessian(self, device=None): + """Init hessian.""" + if dist.is_initialized(): + if dist.get_rank() == 0: + self._hessian = torch.zeros([self.columns, self.columns], + device=device, + dtype=torch.float) + else: + self._hessian = None + else: + self._hessian = torch.zeros([self.columns, self.columns], + device=device, + dtype=torch.float) + + # prune + + @torch.no_grad() + def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01): + """The implementation for SparseGPT.""" + with torch_setting(dtype=torch.float): + # Converted from https://github.com/ist-daslab/sparsegpt + + assert self.hessian is not None + W: torch.Tensor = self.weight_matrix.float() # out in + + H = self.hessian.float().to(W.device) + + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros(self.rows, device=W.device) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=W.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + mask = None + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if prunen == 0: + if mask is not None: + mask1 = mask[:, i1:i2] + else: + tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1)))**2 + thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * + sparsity)] + mask1 = tmp <= thresh + else: + mask1 = torch.zeros_like(W1) == 1 + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if prunen != 0 and i % prunem == 0: + tmp = W1[:, i:(i + prunem)]**2 / (torch.diag(Hinv1)[i:( + i + prunem)].reshape((1, -1)))**2 + mask1.scatter_( + 1, i + + torch.topk(tmp, prunen, dim=1, largest=False)[1], + True) + + q = w.clone() + q[mask1[:, i]] = 0 + + Q1[:, i] = q + Losses1[:, i] = (w - q)**2 / d**2 + + err1 = (w - q) / d + W1[:, + i:] -= err1.unsqueeze(1).matmul(Hinv1[i, + i:].unsqueeze(0)) + Err1[:, i] = err1 + + W[:, i1:i2] = Q1 + Losses += torch.sum(Losses1, 1) / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + if W.device.type == 'cuda': + torch.cuda.synchronize() + from .sparse24_utils import is_weight_sparse_24 + if prunen == 2 and prunem 
== 4: + assert is_weight_sparse_24( + W, -1), f'Weight dose not satisfy 24 with shape {W.shape}' + error = torch.sum(Losses) + + if torch.isnan(error).any(): + raise Exception('get nan error') + else: + self.weight_matrix = W.data + + return error.item() + + +# SparseGpt Ops for Linear and Conv2d + + +class SparseGptLinear(DynamicLinear, SparseGptMixIn): + """Custom Linear for SparseGpt.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._sparse_gpt_mix_in_init() + + @classmethod + def convert_from(cls, module: nn.Linear) -> 'DynamicConv2d': + """Convert to cls from torch's module.""" + if module.out_features < module.in_features: + return module + new_module = super().convert_from(module) + new_module.load_state_dict(module.state_dict(), strict=False) + + dtype = next(module.parameters()).dtype + new_module = new_module.to(dtype) + + return new_module + + +class SparseGptConv2d(DynamicConv2d, SparseGptMixIn): + """Custom Conv2d for SparseGpt.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._sparse_gpt_mix_in_init() + + @classmethod + def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d': + """Convert to cls from torch's module.""" + new_module = super().convert_from(module) + new_module.load_state_dict(module.state_dict(), strict=False) + + dtype = next(module.parameters()).dtype + new_module = new_module.to(dtype) + + return new_module + + def format_input(self, input: torch.Tensor): + """Format input shape.""" + # input B C H W + input = F.unfold( + input, self.kernel_size, padding=self.padding, + stride=self.stride) # B C D + return input.transpose(-1, -2) diff --git a/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py b/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py new file mode 100644 index 000000000..1d646dee1 --- /dev/null +++ b/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +@torch.no_grad() +def is_weight_sparse_24(weight: torch.Tensor, dim=-1): + """"Check if the weight is sparse 24.""" + weight = weight.transpose(-1, dim).reshape(-1, 4) # N 4 + is_zero = (weight == 0).sum(-1) # N + return (is_zero >= 2).all() diff --git a/mmrazor/implementations/pruning/sparse_gpt/utils.py b/mmrazor/implementations/pruning/sparse_gpt/utils.py new file mode 100644 index 000000000..df82784c1 --- /dev/null +++ b/mmrazor/implementations/pruning/sparse_gpt/utils.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
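The assertion at the end of prune() relies on is_weight_sparse_24 from sparse24_utils.py above to verify the 2:4 pattern: within every contiguous group of four weights along the given dimension, at least two entries must be zero. A small self-contained check of that convention (the hand-built tensors are illustrative only):

import torch

from mmrazor.implementations.pruning.sparse_gpt.sparse24_utils import \
    is_weight_sparse_24

# Groups of four along the last dim: (1, 0, 0, 3) and (0, 0, 5, 6) each hold >= 2 zeros.
w_ok = torch.tensor([[1., 0., 0., 3., 0., 0., 5., 6.]])
# First group (1, 2, 0, 3) holds only one zero, so the 2:4 check fails.
w_bad = torch.tensor([[1., 2., 0., 3., 0., 0., 5., 6.]])

assert is_weight_sparse_24(w_ok, dim=-1)
assert not is_weight_sparse_24(w_bad, dim=-1)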
+import sys +from typing import Dict, Type + +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + +import torch +import torch.nn as nn + +from mmrazor.models.architectures.dynamic_ops import DynamicMixin +from mmrazor.utils import print_log + + +class ModuleProtocol(Protocol): + """Custom module protocol for algorithm mixin.""" + weight: torch.Tensor + + def forward(self, x): + """The abstract method.""" + pass + + def register_forward_hook(self, hook): + """The abstract method.""" + pass + + def register_backward_hook(self, hook): + """The abstract method.""" + pass + + def register_forward_pre_hook(self, hook): + """The abstract method.""" + pass + + def register_buffer(self, name, tensor): + """The abstract method.""" + pass + + +def replace_with_dynamic_ops(model: nn.Module, + dynamicop_map: Dict[Type[nn.Module], + Type[DynamicMixin]]): + """Replace torch modules with dynamic-ops.""" + + def replace_op(model: nn.Module, name: str, module: nn.Module): + names = name.split('.') + for sub_name in names[:-1]: + model = getattr(model, sub_name) + + setattr(model, names[-1], module) + + for name, module in model.named_modules(): + if type(module) in dynamicop_map: + new_module = dynamicop_map[type(module)].convert_from(module) + replace_op(model, name, new_module) + + +def register_efficient_forward_hook(module: nn.Module, + device=torch.device('cuda:0')): + """Register efficient forward hook.""" + + def forward_pre_hook(module: nn.Module, input): + module.to(device) + + def forward_hook(module: nn.Module, input, output): + module.to('cpu') + torch.cuda.empty_cache() + + h1 = module.register_forward_pre_hook(forward_pre_hook) + h2 = module.register_forward_hook(forward_hook) + return [h1, h2] + + +def enable_efficient_forward(model: nn.Module, + device=torch.device('cuda:0'), + wrap_modules=[]): + """Enable efficient forward.""" + handles = [] + blocks = [] + for name, module in model.named_children(): + if type(module) in wrap_modules or len(module._parameters) != 0 or len( + module._buffers) != 0: + handles_ = register_efficient_forward_hook(module, device) + blocks_ = [name] + else: + handles_, blocks_ = enable_efficient_forward( + module, device, wrap_modules) + handles += handles_ + blocks += blocks_ + return handles, blocks + + +class memory_efficient_forward: + """The class for Memory efficient forward.""" + + def __init__(self, + model: nn.Module, + enabled=True, + device=torch.device('cuda:0'), + wrap_modules=[]) -> None: + self.model = model + self.device = device + self.wrap_modules = wrap_modules + self.enabled = enabled + self.handlers: list = [] + + if not enabled: + model.to(device) + + def __enter__(self, ): + """Enter.""" + if self.enabled: + handles, blocks = enable_efficient_forward(self.model, self.device, + self.wrap_modules) + print_log(f'enable memory efficient forward for {blocks}') + self.handlers = handles + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Exit.""" + for h in self.handlers: + h.remove() + + +class torch_setting(): + """Set the default torch dtype setting.""" + + def __init__(self, dtype=None) -> None: + self.original_dtype = torch.get_default_dtype() + self.dtype = dtype + + def __enter__(self): + """Enter.""" + if self.dtype is not None: + torch.set_default_dtype(self.dtype) + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Exit.""" + torch.set_default_dtype(self.original_dtype) diff --git a/mmrazor/implementations/quantization/gptq/__init__.py 
b/mmrazor/implementations/quantization/gptq/__init__.py new file mode 100644 index 000000000..4981c8014 --- /dev/null +++ b/mmrazor/implementations/quantization/gptq/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .compressor import GPTQCompressor +from .gptq import GPTQMixIn +from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear +from .quantizer import Quantizer + +__all__ = [ + 'GPTQCompressor', + 'GPTQMixIn', + 'GPTQConv2d', + 'GPTQLinear', + 'TritonGPTQLinear', + 'Quantizer', +] diff --git a/mmrazor/implementations/quantization/gptq/compressor.py b/mmrazor/implementations/quantization/gptq/compressor.py new file mode 100644 index 000000000..4a5aadd80 --- /dev/null +++ b/mmrazor/implementations/quantization/gptq/compressor.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, Type + +import torch +import torch.nn as nn + +from mmrazor.utils import print_log +from .ops import GPTQConv2d, GPTQLinear, GPTQMixIn, TritonGPTQLinear +from .quantizer import Quantizer + + +def replace_with_dynamic_ops(model: nn.Module, + dynamicop_map: Dict[Type[nn.Module], Type[Any]], + skipped_layers=[], + a_qconfig=None, + **kwargs): + """Replace torch modules with dynamic-ops.""" + + def replace_op(model: nn.Module, name: str, module: nn.Module): + names = name.split('.') + for sub_name in names[:-1]: + model = getattr(model, sub_name) + + setattr(model, names[-1], module) + + for name, module in model.named_modules(): + if type(module) in dynamicop_map and name not in skipped_layers: + if isinstance(module, nn.Linear): + if a_qconfig: + a_fakequant = Quantizer() + a_fakequant.configure(**a_qconfig) + kwargs.update({'a_fakequant': a_fakequant}) + new_module = dynamicop_map[type(module)].convert_from( + module, **kwargs) + else: + new_module = dynamicop_map[type(module)].convert_from(module) + replace_op(model, name, new_module) + + +def to_static_model(model: nn.Module): + """Replace dynamicops with torch modules.""" + from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet, + load_fix_subnet) + fix_subnet = export_fix_subnet(model)[0] + load_fix_subnet(model, fix_subnet) + return model + + +class GPTQCompressor(): + """The compressor with GPTQ.""" + + def __init__(self) -> None: + self.model: nn.Module = None + + def prepare(self, + model: nn.Module, + quant_conv=True, + quant_linear=True, + use_triton_ops=True, + skipped_layers=[], + a_qconfig=None, + **kwargs) -> None: + """Prepare for compressing model.""" + self.model = model + quant_modules: dict = {} + if quant_conv: + quant_modules[nn.Conv2d] = GPTQConv2d + if quant_linear: + gptq_linear = TritonGPTQLinear if use_triton_ops else GPTQLinear + quant_modules[nn.Linear] = gptq_linear + replace_with_dynamic_ops(model, quant_modules, skipped_layers, + a_qconfig, **kwargs) + + @classmethod + def to_static_model(cls, model): + """Convert replaced op with the original torch model.""" + return to_static_model(model) + + # hessian + + def register_hessian_hooks(self): + """Register updating hessian hooks for specified ops.""" + for module in self.quant_ops: + module.register_hessian_hook() + + def remove_hessian_hooks(self): + """Remove updating hessian hooks for specified ops.""" + for module in self.quant_ops: + module.remove_hessian_hook() + + def init_hessian(self, device=None): + """Init hessian.""" + for op in self.quant_ops: + op.init_hessian(device=device) + + # quant + def quant(self, + blocksize=128, + percdamp=0.01, + groupsize=-1, + actorder=False, + 
device=torch.device('cuda:0'), + **qconfig): + """Apply the compression algorithm to the model.""" + for name, module in self.named_quant_ops: + try: + original_device = next(module.parameters()).device + module: GPTQMixIn = module.to(device) + quantizer = Quantizer() + quantizer.configure(**qconfig) + # print_log(f'quant {name}...') + error = module.quant( + quantizer=quantizer, + blocksize=blocksize, + percdamp=percdamp, + groupsize=groupsize, + actorder=actorder) + print_log(f'quant {name} success \t error = {error}') + module.to(original_device) + module.free() + except Exception as e: + print_log(f'quant {name} failed as {e}') + + def quant_with_default_qconfig(self, groupsize=128, device='cpu'): + """Apply the compression algorithm to the model with the specified + setting.""" + qconfig = dict(bits=4, perchannel=True, sym=False) + self.quant( + groupsize=groupsize, actorder=True, device=device, **qconfig) + + # ops + + @property + def quant_ops(self): + """The ops to be applied the algorithm.""" + assert self.model is not None + for module in self.model.modules(): + if isinstance(module, GPTQMixIn): + yield module + + @property + def named_quant_ops(self): + """The named ops to be applied the algorithm.""" + for name, module in self.model.named_modules(): + if isinstance(module, GPTQMixIn): + yield name, module diff --git a/mmrazor/implementations/quantization/gptq/custom_autotune.py b/mmrazor/implementations/quantization/gptq/custom_autotune.py new file mode 100644 index 000000000..1bc0d7d5f --- /dev/null +++ b/mmrazor/implementations/quantization/gptq/custom_autotune.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# https://github.com/fpgaminer/GPTQ-triton +"""Mostly the same as the autotuner in Triton, but with a few changes like +using 40 runs instead of 100.""" + +import builtins +import math +import time +from typing import Dict + +try: + import triton +except ImportError: + from mmrazor.utils import get_package_placeholder + triton = get_package_placeholder('triton >= 2.0.0') + + +class Autotuner(triton.KernelInterface): + """Autotuner.""" + + def __init__(self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False): + '''prune_configs_by: a dict of functions that are used to prune + configs, fields: + 'perf_model': performance model used to predicate running time + with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune + num_stages. It take configs:List[Config] as its input, and + returns pruned configs. 
+ 'nearest_power_of_two'(optional): whether to round key arguments + to the nearest power of two when caching tuning results.''' + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache: Dict = {} + # hook to reset all required tensor to zeros before relaunching + # a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by[ + 'perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + """Check for conflicts, i.e. meta-parameters both provided as kwargs + and by the autotuner.""" + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current) + + try: + # In testings using only 40 reps seems to be close enough and it + # appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any + # speedup so I'll leave the default + return triton.testing.do_bench( + kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) + except triton.compiler.OutOfResources: + return (float('inf'), float('inf'), float('inf')) + + def run(self, *args, **kwargs): + """Run.""" + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to + # the nearest power of two + # In my testing this gives decent results, and greatly reduces + # the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2**int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs) + + def prune_configs(self, kwargs): + """Prune configs.""" + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: 
+ top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted( + est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + """Warm up.""" + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, + key, + prune_configs_by=None, + reset_to_zero=None, + nearest_power_of_two=False): + """Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated + # anytime the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run + multiple time.This means that whatever value the kernel updates will + be updated multiple times.To avoid this undesired behavior, you can + use the `reset_to_zero` argument, which reset the value of the + provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger + the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune + configs, fields: + 'perf_model': performance model used to predicate running time with + different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune + (eg, num_stages). It take configs:List[Config] as its input, and + returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset + to zero before evaluating any configs. 
+ :type reset_to_zero: list[str] + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, + prune_configs_by, nearest_power_of_two) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """The main purpose of this function is to shrink BLOCK_SIZE_* when the + corresponding dimension is smaller.""" + m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) + n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) + k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) + block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) + block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) + group_size_m = config.kwargs['GROUP_SIZE_M'] + + if (block_size_m, block_size_n, block_size_k, group_size_m, + config.num_stages, config.num_warps) in used: + continue + + used.add((block_size_m, block_size_n, block_size_k, group_size_m, + config.num_stages, config.num_warps)) + yield triton.Config( + { + 'BLOCK_SIZE_M': block_size_m, + 'BLOCK_SIZE_N': block_size_n, + 'BLOCK_SIZE_K': block_size_k, + 'GROUP_SIZE_M': group_size_m + }, + num_stages=config.num_stages, + num_warps=config.num_warps) diff --git a/mmrazor/implementations/quantization/gptq/gptq.py b/mmrazor/implementations/quantization/gptq/gptq.py new file mode 100644 index 000000000..84cfd3a4b --- /dev/null +++ b/mmrazor/implementations/quantization/gptq/gptq.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys + +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + +import numpy as np +import torch +import torch.distributed as dist + +from mmrazor.implementations.pruning.sparse_gpt.utils import torch_setting + + +class ModuleProtocol(Protocol): + """Custom module protocol for algorithm mixin.""" + weight: torch.Tensor + + def forward(self, x): + """The abstract method.""" + pass + + def register_forward_hook(self, hook): + """The abstract method.""" + pass + + def register_backward_hook(self, hook): + """The abstract method.""" + pass + + def register_forward_pre_hook(self, hook): + """The abstract method.""" + pass + + def register_buffer(self, name, tensor): + """The abstract method.""" + pass + + +class GPTQMixIn(ModuleProtocol): + """The core algorithm implementation for GPTQ.""" + + def _gptq_mix_in_init(self): + """Init mixin.""" + self.gptq_handles = [] + self.rows = self.weight_matrix.shape[0] + self.columns = self.weight_matrix.shape[1] + + self._hessian: torch.Tensor = None + self.hessian_batch = 0 + + # weight and input adaptive + + @property + def weight_matrix(self): + """Return weight with shape (out in)""" + return self.weight.flatten(1) # out in + + @weight_matrix.setter + def weight_matrix(self, value: torch.Tensor): + """Set weight.""" + with torch.no_grad(): + value = value.reshape(self.weight.shape).to(self.weight.device).to( + self.weight.dtype) + self.weight.data.copy_(value) + + def format_input(self, input: torch.Tensor): + """Return input with shape (B N C)""" + if len(input.shape) == 2: # N C + input = input.unsqueeze(0) # 1 N C + return input + + # compute hessian + + @property + def hessian(self): + """hessian always return float.""" + if dist.is_initialized(): + if dist.get_rank() == 0: + assert self._hessian is not None, 'hessian is not initialized.' 
+ hessian = self._hessian.to(self.weight_matrix.device) + else: + hessian = torch.zeros( + self.columns, + self.columns, + device=self.weight_matrix.device) + dist.broadcast(hessian, 0) + return hessian + else: + return self._hessian + + @hessian.setter + def hessian(self, value: torch.Tensor): + """Set hessian.""" + with torch.no_grad(): + if dist.is_initialized(): + if dist.get_rank() == 0: + assert self._hessian is not None, 'hessian is not initialized.' # noqa + self._hessian.data.copy_( + value.data.to(self._hessian.device)) + else: + self._hessian = None + else: + self._hessian.data.copy_(value.data.to(self._hessian.device)) + + @torch.no_grad() + def update_hessian(self, input: torch.Tensor): + """Update hessian.""" + input = self.format_input(input).float() + H_save = self.hessian + H_save = H_save.to(input.device) + + assert len(input.shape) == 3 + B = input.shape[0] # B N C + input = input.transpose(0, -1).flatten(1) # C D + + H = input @ input.T * 2 # C C + + if dist.is_initialized(): + dist.all_reduce(H) + B *= dist.get_world_size() + H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B) + self.hessian = H_save + self.hessian_batch = self.hessian_batch + B + + def register_hessian_hook(self): + """Register updating hessian hook.""" + + @torch.no_grad() + def forward_pre_hook(module: Protocol, input: tuple): + assert len(input) == 1 + self.update_hessian(input[0]) + + handle = self.register_forward_pre_hook(forward_pre_hook) + self.gptq_handles.append(handle) + + def remove_hessian_hook(self): + """Remove updating hessian hook.""" + for h in self.gptq_handles: + h.remove() + + def init_hessian(self, device=None): + """Init hessian.""" + if dist.is_initialized(): + if dist.get_rank() == 0: + self._hessian = torch.zeros([self.columns, self.columns], + device=device, + dtype=torch.float) + else: + self._hessian = None + else: + self._hessian = torch.zeros([self.columns, self.columns], + device=device, + dtype=torch.float) + + def pack(self, scales, zeros, g_idx=None): + """Pack and update qparams with groupsize_idx.""" + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if self.bias is not None: + self.bias.half() + + intweight = [] + for idx in range(self.in_features): + intweight.append( + torch.round( + (self.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / + self.scales[self.g_idx[idx]]).to(torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.cpu().numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), + dtype=np.uint32) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError('Only 2,4,8 bits are supported.') + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight).to(self.weight.device) + + zeros -= 1 + zeros = zeros.cpu().numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), + dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise 
NotImplementedError('Only 2,4,8 bits are supported.') + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros).to(self.weight.device) + + @torch.no_grad() + def quant(self, + quantizer, + blocksize=128, + percdamp=0.01, + groupsize=-1, + actorder=False): + """The implementation for GPTQ.""" + with torch_setting(dtype=torch.float): + assert self.hessian is not None + W: torch.Tensor = self.weight_matrix.float() # out in + + if not quantizer.ready(): + quantizer.find_params(W, weight=True) + + H = self.hessian.float().to(W.device) + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=W.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + quantizer.find_params( + W[:, (i1 + i):(i1 + i + groupsize)], + weight=True) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(quantizer.scale) + zero.append(quantizer.zero) + now_idx += 1 + + q = quantizer.quantize(w.unsqueeze(1)).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q)**2 / d**2 + + err1 = (w - q) / d + W1[:, + i:] -= err1.unsqueeze(1).matmul(Hinv1[i, + i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + error = torch.sum(Losses).item() + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if scale == []: + scale.append(quantizer.scale) + zero.append(quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + self.weight_matrix = Q.data.to(self.weight_matrix.dtype) + if self.is_custom_kernel: + self.pack(scale, zero, g_idx) + del self.weight + return error + + def free(self): + """Free some cache and memory.""" + self._hessian = None + torch.cuda.empty_cache() diff --git a/mmrazor/implementations/quantization/gptq/ops.py b/mmrazor/implementations/quantization/gptq/ops.py new file mode 100644 index 000000000..b8c139412 --- /dev/null +++ b/mmrazor/implementations/quantization/gptq/ops.py @@ -0,0 +1,566 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d, + DynamicLinear) +# from mmrazor.implementations.pruning.sparse_gpt.utils import torch_setting +from .gptq import GPTQMixIn + +try: + import triton + import triton.language as tl + + from . 
import custom_autotune + + # code based https://github.com/fpgaminer/GPTQ-triton + @custom_autotune.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=2, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, + num_stages=2, + num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': + custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, + ) + @triton.jit + def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, + N, K, bits, maxq, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, stride_scales, + stride_zeros, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 + # times + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit + # word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused + # in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * + stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load( + zeros_ptrs + + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load( + a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit + # values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[ + None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @custom_autotune.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 256, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, + num_stages=2, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=2, + num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True) + @triton.jit + def transpose_matmul_248_kernel( + a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, + maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, + stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, N) float16 + B is of shape (K//8, N) int32 + C is of shape (M, K) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_k + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_k = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = tl.arange(0, BLOCK_SIZE_N) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 + # times + b_ptrs = b_ptr + ( + (offs_bk[:, None] // infearure_per_bits) * stride_bk + + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_bk + g_idx = tl.load(g_ptrs) + + # shifter is used to extract the N bits of each element in the 32-bit + # word from B + scales_ptrs = scales_ptr + offs_n[ + None, :] + g_idx[:, None] * stride_scales + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits + ) + g_idx[:, None] * stride_zeros + + shifter = (offs_bk % infearure_per_bits) * bits + zeros_shifter = (offs_n % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + for n in range(0, num_pid_n): + # Fetch scales and zeros; these are per-outfeature and thus reused + # in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load( + a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit + # values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + b = tl.trans(b) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_N + b_ptrs += BLOCK_SIZE_N + scales_ptrs += BLOCK_SIZE_N + zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[ + None, :] + c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) + tl.store(c_ptrs, accumulator, mask=c_mask) +except: # noqa: E722 + print('triton not installed.') + + +def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + """matmul248 function with matmul_248_kernel.""" + with torch.cuda.device(input.device): + output = torch.empty((input.shape[0], qweight.shape[1]), + device=input.device, + dtype=torch.float16) + grid = lambda META: ( # noqa: E731 + triton.cdiv( # noqa: E731 + input.shape[0], META['BLOCK_SIZE_M']) * triton. 
# noqa: E731 + cdiv( # noqa: E731 + qweight.shape[1], META['BLOCK_SIZE_N']), ) # noqa: E731 + matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, + input.shape[0], qweight.shape[1], + input.shape[1], bits, maxq, input.stride(0), + input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), + output.stride(1), scales.stride(0), + qzeros.stride(0)) + return output + + +def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + """transpose_matmul248 function with transpose_matmul_248_kernel.""" + with torch.cuda.device(input.device): + output_dim = (qweight.shape[0] * 32) // bits + output = torch.empty((input.shape[0], output_dim), + device=input.device, + dtype=torch.float16) + grid = lambda META: ( # noqa: E731 + triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) # noqa: E731 + * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) # noqa: E731 + transpose_matmul_248_kernel[grid](input, qweight, output, scales, + qzeros, g_idx, input.shape[0], + qweight.shape[1], output_dim, + bits, maxq, input.stride(0), + input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), + output.stride(1), scales.stride(0), + qzeros.stride(0)) + return output + + +class QuantLinearFunction(torch.autograd.Function): + """Custom QuantLinearFunction.""" + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + """Custom forward.""" + output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) + ctx.save_for_backward(qweight, scales, qzeros, g_idx) + ctx.bits, ctx.maxq = bits, maxq + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + """Custom backward.""" + qweight, scales, qzeros, g_idx = ctx.saved_tensors + bits, maxq = ctx.bits, ctx.maxq + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = transpose_matmul248(grad_output, qweight, scales, + qzeros, g_idx, bits, maxq) + return grad_input, None, None, None, None, None, None + + +class TritonGPTQLinear(nn.Module, GPTQMixIn): + """Custom Linear for GPTQ with custom triton kernel.""" + + def __init__(self, bits, groupsize, weight, in_features, out_features, + bias): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError('Only 2,4,8 bits are supported.') + self.weight = weight + self.bias = bias + + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else in_features + + self.register_buffer( + 'qweight', + torch.zeros((in_features // 32 * self.bits, out_features), + dtype=torch.int32)) + self.register_buffer( + 'qzeros', + torch.zeros((math.ceil( + in_features / self.groupsize), out_features // 32 * self.bits), + dtype=torch.int32)) + self.register_buffer( + 'scales', + torch.zeros( + (math.ceil(in_features / self.groupsize), out_features), + dtype=torch.float16)) + self.register_buffer( + 'g_idx', + torch.tensor([i // self.groupsize for i in range(in_features)], + dtype=torch.int32)) + + self._gptq_mix_in_init() + + @property + def is_custom_kernel(self): + """Whether use custom kernel.""" + return True + + @classmethod + def convert_from(cls, module: nn.Linear, bits, groupsize): + """Convert to cls from torch's module.""" + new_module = cls( + bits, + groupsize, + weight=module.weight, + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias) + + return new_module + + def forward(self, x): + """Custom 
forward.""" + if torch.all(self.qweight == 0): + out = F.linear(x, self.weight, self.bias) + else: + # import pdb;pdb.set_trace() + out_shape = x.shape[:-1] + (self.out_features, ) + out = QuantLinearFunction.apply( + x.reshape(-1, x.shape[-1]), self.qweight, self.scales, + self.qzeros, self.g_idx, self.bits, self.maxq) + out = out + self.bias if self.bias is not None else out + out = out.reshape(out_shape) + # import pdb;pdb.set_trace() + return out + + +class GPTQLinear(DynamicLinear, GPTQMixIn): + """Custom Linear for GPTQ without custom triton kernel.""" + + def __init__(self, a_fakequant=None, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._gptq_mix_in_init() + self.a_fakequant = a_fakequant + self.fix_qparams = False + + @property + def is_custom_kernel(self): + """Whether use custom kernel.""" + return False + + @classmethod + def convert_from(cls, + module: nn.Linear, + a_fakequant=None) -> 'DynamicLinear': + """Convert to cls from torch's module.""" + new_module = cls( + a_fakequant=a_fakequant, + in_features=module.in_features, + out_features=module.out_features, + bias=True if module.bias is not None else False) + new_module.load_state_dict(module.state_dict(), strict=False) + + dtype = next(module.parameters()).dtype + new_module = new_module.to(dtype) + + return new_module + + def forward(self, input: Tensor) -> Tensor: + """Custom forward.""" + if self.a_fakequant: + dtype = self.weight.dtype + if not self.fix_qparams: + self.a_fakequant.find_params(input) + input = self.a_fakequant.quantize(input).to(dtype) + return super().forward(input) + + +class GPTQConv2d(DynamicConv2d, GPTQMixIn): + """Custom Conv2d for GPTQ without custom triton kernel.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._gptq_mix_in_init() + + @property + def is_custom_kernel(self): + """Whether use custom kernel.""" + return False + + @classmethod + def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d': + """Convert to cls from torch's module.""" + new_module = super().convert_from(module) + new_module.load_state_dict(module.state_dict(), strict=False) + + dtype = next(module.parameters()).dtype + new_module = new_module.to(dtype) + + return new_module + + def format_input(self, input: torch.Tensor): + """Format input shape.""" + # input B C H W + input = F.unfold( + input, self.kernel_size, padding=self.padding, + stride=self.stride) # B C D + return input.transpose(-1, -2) diff --git a/mmrazor/implementations/quantization/gptq/quantizer.py b/mmrazor/implementations/quantization/gptq/quantizer.py new file mode 100644 index 000000000..0db2fb998 --- /dev/null +++ b/mmrazor/implementations/quantization/gptq/quantizer.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
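For readers tracing `pack()` and the Triton kernels: quantized values are stored 32 // bits per int32 word, written with a shift of `bits * position` and recovered with the same shift plus a mask of `maxq = 2**bits - 1`. A small illustration of that round trip (plain Python, not part of the patch):

```python
bits = 4
maxq = (1 << bits) - 1            # 0b1111
values = list(range(32 // bits))  # eight 4-bit values fill one packed word

packed = 0
for j, v in enumerate(values):
    packed |= v << (bits * j)     # same shift as pack(): bits * (j - i)

unpacked = [(packed >> (bits * j)) & maxq for j in range(len(values))]
assert unpacked == values         # shift + mask, as in matmul_248_kernel
```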
+import torch +import torch.nn as nn + + +class Quantizer(nn.Module): + """Quantizer for some basic quantization functions.""" + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure(self, + bits, + perchannel=False, + sym=True, + mse=False, + norm=2.4, + grid=100, + maxshrink=.8, + trits=False): + """Configure qconfig.""" + + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + self.scale = torch.zeros_like(self.scale) + + def _quantize(self, x, scale, zero, maxq): + """Fakequant.""" + if maxq < 0: + return (x > scale / 2).float() * scale + (x < + zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + def find_params(self, x, weight=False): + """Observe the specified data and calculate the qparams.""" + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / + scale1) if not self.sym else self.zero + q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), + self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + """Fakequant.""" + if self.ready(): + return self._quantize(x, self.scale, self.zero, self.maxq) + + return x + + def enabled(self): + """Whether is enabled.""" + return self.maxq > 0 + + def ready(self): + """Whether is ready.""" + return torch.all(self.scale != 0) diff --git 
a/mmrazor/implementations/quantization/gptq/utils.py b/mmrazor/implementations/quantization/gptq/utils.py new file mode 100644 index 000000000..a27b3ff8d --- /dev/null +++ b/mmrazor/implementations/quantization/gptq/utils.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +# copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py # noqa: E501 +def torch_snr_error(y_pred: torch.Tensor, + y_real: torch.Tensor, + reduction: str = 'mean') -> torch.Tensor: + """Compute SNR between y_pred(tensor) and y_real(tensor) + + SNR can be calculted as following equation: + + SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 + + if x and y are matrixs, SNR error over matrix should be the mean value of + SNR error over all elements. + + SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) + Args: + y_pred (torch.Tensor): _description_ + y_real (torch.Tensor): _description_ + reduction (str, optional): _description_. Defaults to 'mean'. + Raises: + ValueError: _description_ + ValueError: _description_ + Returns: + torch.Tensor: _description_ + """ + y_pred = y_pred.type(torch.float32) + y_real = y_real.type(torch.float32) + + if y_pred.shape != y_real.shape: + raise ValueError( + f'Can not compute snr loss for tensors with different shape. ' + f'({y_pred.shape} and {y_real.shape})') + reduction = str(reduction).lower() + + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(0) + y_real = y_real.unsqueeze(0) + + y_pred = y_pred.flatten(start_dim=1) + y_real = y_real.flatten(start_dim=1) + + noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) + signal_power = torch.pow(y_real, 2).sum(dim=-1) + snr = (noise_power) / (signal_power + 1e-7) + + if reduction == 'mean': + return torch.mean(snr) + elif reduction == 'sum': + return torch.sum(snr) + elif reduction == 'none': + return snr + else: + raise ValueError('Unsupported reduction method.') diff --git a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py index 0a381d5d0..bea018975 100644 --- a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py +++ b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py @@ -49,7 +49,9 @@ 'relu_qat', 'bn_qat', 'bn_relu_qat', 'func' ]) -if digit_version(torch.__version__) >= digit_version('1.13.0'): +if digit_version( + torch.__version__) >= digit_version('1.13.0') and digit_version( + torch.__version__) <= digit_version('1.13.1'): _Conv1dMetadata = _ConvMetadata( nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d, nnqr.ConvTranspose1d, nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, diff --git a/mmrazor/structures/quantization/backend_config/mapping.py b/mmrazor/structures/quantization/backend_config/mapping.py index b9cc5372b..0a02ac1b7 100644 --- a/mmrazor/structures/quantization/backend_config/mapping.py +++ b/mmrazor/structures/quantization/backend_config/mapping.py @@ -7,7 +7,9 @@ from .openvino import get_openvino_backend_config from .tensorrt import get_tensorrt_backend_config -if digit_version(torch.__version__) >= digit_version('1.13.0'): +if digit_version( + torch.__version__) >= digit_version('1.13.0') and digit_version( + torch.__version__) <= digit_version('1.13.1'): BackendConfigs = { 'academic': get_academic_backend_config(), 'native': get_native_backend_config(), diff --git a/mmrazor/utils/log_tools.py b/mmrazor/utils/log_tools.py index 787dc1927..935349a03 100644 --- 
a/mmrazor/utils/log_tools.py +++ b/mmrazor/utils/log_tools.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging +import torch.distributed as dist from mmengine import MMLogger from mmengine import print_log as engine_print_log @@ -17,8 +18,15 @@ def get_level(level='info'): return level -def print_log(msg, logger='current', level='info'): - engine_print_log(msg, logger, get_level(level)) +def print_log(msg, logger='current', level='info', only_rank0=True): + + if only_rank0 and dist.is_initialized(): + if dist.get_rank() == 0: + engine_print_log(msg, logger, get_level(level)) + else: + pass + else: + engine_print_log(msg, logger, get_level(level)) def set_log_level(level='debug'): diff --git a/projects/mmrazor_large/README.md b/projects/mmrazor_large/README.md new file mode 100644 index 000000000..378b9102b --- /dev/null +++ b/projects/mmrazor_large/README.md @@ -0,0 +1,42 @@ +
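As a usage note for the `print_log` change above: the new `only_rank0` argument makes rank 0 the only process that emits a message when `torch.distributed` is initialized. A short sketch of the intended call pattern (assuming a process group has already been set up):

```python
import torch.distributed as dist

from mmrazor.utils import print_log

# With the default only_rank0=True the message is emitted once per job,
# not once per process, when dist.is_initialized() is True.
print_log('quantization of the current layer finished', level='info')

# Pass only_rank0=False to log from every rank, e.g. while debugging.
if dist.is_initialized():
    print_log(f'rank {dist.get_rank()} done', only_rank0=False)
```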
+ +
+ +# MMRazor for Large Models + +## Introduction + +MMRazor is dedicated to the development of general-purpose model compression tools. Now, MMRazor not only supports conventional CV model compression but also extends to support large models. This project will provide examples of MMRazor's compression for various large models, including LLaMA, stable diffusion, and more. + +Code structure overview about large models. + +``` +mmrazor +├── implementations # core algorithm components + ├── pruning + └── quantization +projects +└── mmrazor_large + ├── algorithms # algorithms usage introduction + └── examples # examples for various models about algorithms + ├── language_models + │ ├── LLaMA + │ └── OPT + └── ResNet +``` + +## Model-Algorithm Example Matrix + +| | ResNet | OPT | LLama | Stable diffusion | +| ------------------------------------ | ----------------------------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------- | ---------------- | +| [SparseGPT](algorithms/SparseGPT.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | | +| [GPTQ](algorithms/GPTQ.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | | + +## PaperList + +We provide a paperlist for researchers in the field of model compression for large models. If you want to add your paper to this list, please submit a PR. + +| Paper | Title | Type | MMRazor | +| --------- | --------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------- | +| SparseGPT | [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774) | Pruning | [:white_check_mark:](algorithms/SparseGPT.md) | +| GPTQ | [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323) | Quantization | [:white_check_mark:](algorithms/GPTQ.md) | diff --git a/projects/mmrazor_large/algorithms/GPTQ.md b/projects/mmrazor_large/algorithms/GPTQ.md new file mode 100644 index 000000000..b013a73a2 --- /dev/null +++ b/projects/mmrazor_large/algorithms/GPTQ.md @@ -0,0 +1,56 @@ +# GPTQ + +> [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323) + + + +## Abstract + +Generative Pre-trained Transformer models, known as GPT or OPT, set themselves apart through breakthrough performance across complex language modelling tasks, but also by their extremely high computational and storage costs. Specifically, due to their massive size, even inference for large, highly-accurate GPT models may require multiple performant GPUs, which limits the usability of such models. While there is emerging work on relieving this pressure via model compression, the applicability and performance of existing compression techniques is limited by the scale and complexity of GPT models. In this paper, we address this challenge, and propose GPTQ, a new one-shot weight quantization method based on approximate second-order information, that is both highlyaccurate and highly-efficient. 
Specifically, GPTQ can quantize GPT models with 175 billion parameters in approximately four GPU hours, reducing the bitwidth down to 3 or 4 bits per weight, with negligible accuracy degradation relative to the uncompressed baseline. Our method more than doubles the compression gains relative to previously-proposed one-shot quantization methods, preserving accuracy, allowing us for the first time to execute an 175 billion-parameter model inside a single GPU for generative inference. Moreover, we also show that our method can still provide reasonable accuracy in the extreme quantization regime, in which weights are quantized to 2-bit or even ternary quantization levels. We show experimentally that these improvements can be leveraged for end-to-end inference speedups over FP16, of around 3.25x when using high-end GPUs (NVIDIA A100) and 4.5x when using more cost-effective ones (NVIDIA A6000). The implementation is available at https://github.com/IST-DASLab/gptq. + +## Usage + +GPTQ is easy to use in mmrazor. You can use it like this: + +```python +from mmrazor.implementations.quantization import gptq + +# initial model, dataloaders +model +train_loader, test_loader + +## init gptq compressor and prepare for quantization +compressor = gptq.GPTQCompressor() +compressor.prepare(model) + +## get hessian matrix +compressor.init_hessian() +compressor.register_hessian_hooks() +infer(model, test_loader, num_samples=num_samples) +compressor.remove_hessian_hooks() + +## quant +compressor.quant_with_default_qconfig() + +## to a normal torch model +model = compressor.to_static_model(model) + +``` + +## Full Examples + +- [ResNet](../examples/ResNet/README.md) +- [LLaMA](../examples/language_models/LLaMA/README.md) + +## Cite + +```latex + @misc{ + Frantar_Ashkboos_Hoefler_Alistarh_2022, + title={GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers}, + author={Frantar, Elias and Ashkboos, Saleh and Hoefler, Torsten and Alistarh, Dan}, + year={2022}, + month={Oct}, + language={en-US} +} +``` diff --git a/projects/mmrazor_large/algorithms/SparseGPT.md b/projects/mmrazor_large/algorithms/SparseGPT.md new file mode 100644 index 000000000..479235baa --- /dev/null +++ b/projects/mmrazor_large/algorithms/SparseGPT.md @@ -0,0 +1,55 @@ +# SparseGPT + +> [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774) + + + +## Abstract + +We show for the first time that large-scale generative pretrained transformer (GPT) family models can be pruned to at least 50% sparsity in one-shot, without any retraining, at minimal loss of accuracy. This is achieved via a new pruning method called SparseGPT, specifically designed to work efficiently and accurately on massive GPT-family models. We can execute SparseGPT on the largest available open-source models, OPT-175B and BLOOM-176B, in under 4.5 hours, and can reach 60% unstructured sparsity with negligible increase in perplexity: remarkably, more than 100 billion weights from these models can be ignored at inference time. SparseGPT generalizes to semi-structured (2:4 and 4:8) patterns, and is compatible with weight quantization approaches. + +## Usage + +SparseGPT is easy to use in mmrazor. 
You can use it like this: + +```python +from mmrazor.implementations.pruning import sparse_gpt + +# initial model, dataloaders +model +train_loader, test_loader + +## init sparse gpt compressor and prepare for pruning +compressor = sparse_gpt.SparseGptCompressor() +compressor.prepare(model) + +## get hessian matrix +compressor.init_hessian() +compressor.register_hessian_hooks() +infer(model, test_loader, num_samples=num_samples) +compressor.remove_hessian_hooks() + +## prune +compressor.prune_24() + +## to a normal torch model +model = compressor.to_static_model(model) + +``` + +## Full Examples + +- [ResNet](../examples/ResNet/README.md) +- [OPT](../examples/language_models/OPT/README.md) +- [LLaMA](../examples/language_models/LLaMA/README.md) + +## Cite + +```latex +@article{frantar2023massive, + title={Massive Language Models Can Be Accurately Pruned in One-Shot}, + author={Frantar, Elias and Alistarh, Dan}, + journal={arXiv preprint arXiv:2301.00774}, + year={2023} +} +``` diff --git a/projects/mmrazor_large/examples/ResNet/README.md b/projects/mmrazor_large/examples/ResNet/README.md new file mode 100644 index 000000000..aa4eb374c --- /dev/null +++ b/projects/mmrazor_large/examples/ResNet/README.md @@ -0,0 +1,25 @@ +# Examples for ResNet + +## SparseGPT + +For more details about SparseGPT, please refer to [SparseGPT](../../algorithms/SparseGPT.md) + +### Usage + +```shell +python projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py --data {imagenet_path} --batchsize 128 --num_samples 512 +``` + +**Note**: this imagenet folder follows torch format. + +## GPTQ + +For more details about GPTQ, please refer to [GPTQ](../../algorithms/GPTQ.md) + +### Usage + +```shell +python projects/mmrazor_large/examples/ResNet/resnet18_gptq.py --data {imagenet_path} --batchsize 128 --num_samples 512 +``` + +**Note**: this imagenet folder follows torch format. diff --git a/projects/mmrazor_large/examples/ResNet/resnet18_gptq.py b/projects/mmrazor_large/examples/ResNet/resnet18_gptq.py new file mode 100644 index 000000000..9aa6877a6 --- /dev/null +++ b/projects/mmrazor_large/examples/ResNet/resnet18_gptq.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
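One caveat before the GPTQ ResNet script that follows: it declares `--fp16` with `type=bool`, and `argparse` feeds the raw command-line string through `bool()`, so any non-empty value, including `False`, enables fp16; only omitting the flag keeps the default. A tiny illustration, where `--use-fp16` is a hypothetical stricter alternative and not part of the patch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', type=bool, default=False)  # as in the script below
parser.add_argument('--use-fp16', action='store_true')   # hypothetical alternative

args = parser.parse_args(['--fp16', 'False'])
print(args.fp16)      # True, because bool('False') is truthy
print(args.use_fp16)  # False, the flag was not passed
```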
+# model settings +import os.path as osp + +import torch +import torch.nn as nn +import torchvision +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from mmrazor.implementations.quantization.gptq import (GPTQCompressor, + GPTQLinear) +from mmrazor.utils import print_log + + +def enable_observer_linear(model): + print_log('Enable updating qparams for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, GPTQLinear): + module.fix_qparams = False + + +def disable_observer_linear(model): + print_log('Disable updating qparams for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, GPTQLinear): + module.fix_qparams = True + + +def get_dataloaders(batch_size, n_workers, path=''): + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + train_dataset = datasets.ImageFolder( + osp.join(path, 'train'), + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ]), + ) + + test_dataset = datasets.ImageFolder( + osp.join(path, 'val'), + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]), + ) + + dataloader_train = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=n_workers, + pin_memory=True, + ) + dataloader_test = torch.utils.data.DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=n_workers, + pin_memory=True, + ) + return dataloader_train, dataloader_test + + +@torch.no_grad() +def eval(model: nn.Module, + dataloader_test: DataLoader, + device=torch.device('cuda:0'), + is_half=True): + + total = 0 + correct = 0 + + model.eval() + with torch.no_grad(): + for x, y in dataloader_test: + x: torch.Tensor # type: ignore + y: torch.Tensor # type: ignore + x = x.to(device) + y = y.to(device) + if is_half: + x = x.half() + y = y.half() + outputs = model(x) + _, predicted = outputs.max(1) + correct += (y == predicted).long().sum() + total += y.numel() + acc = correct / total + return acc + + +@torch.no_grad() +def infer(model: nn.Module, + dataloader: torch.utils.data.DataLoader, + num_samples=256, + device=torch.device('cuda:0'), + is_half=True): + model.eval() + with torch.no_grad(): + accumulate_batch = 0 + for x, _ in dataloader: + x = x.to(device) + if is_half: + x = x.half() + model(x) + B = x.shape[0] + accumulate_batch += B + if accumulate_batch > num_samples: + break + + +if __name__ == '__main__': + import argparse + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument( + '--data', + type=str, + default='data/imagenet_torch', + help='path to imagenet in torch folder format') + arg_parser.add_argument( + '--num_samples', + type=int, + default=512, + help='number of samples to estimate hessian matrix') + arg_parser.add_argument( + '--batch_size', + type=int, + default=128, + help='batch size for evaluation and inference') + arg_parser.add_argument( + '--fp16', + type=bool, + default=False, + help='whether to use fp16 for evaluation and inference') + args = arg_parser.parse_args() + + data_path = args.data + num_samples = args.num_samples + batch_size = args.batch_size + + model = torchvision.models.resnet18(pretrained=True) + if args.fp16: + model = model.half() + train_loader, test_loader = get_dataloaders(batch_size, 4, data_path) + + compressor = GPTQCompressor() + + # # use_triton_ops is 
True + # compressor.prepare(model, + # quant_conv=True, + # quant_linear=True, + # use_triton_ops=False, + # skipped_layers=['conv1'], + # bits=4, + # groupsize=128) + + # # quantize activation for linear + # a_qconfig = dict(bits=4, perchannel=True, sym=False) + compressor.prepare( + model, + quant_conv=True, + quant_linear=True, + use_triton_ops=False, + skipped_layers=['conv1'], + # a_qconfig=a_qconfig + ) + + model.cuda() + + enable_observer_linear(model) + compressor.init_hessian() + compressor.register_hessian_hooks() + infer(model, test_loader, num_samples=num_samples, is_half=args.fp16) + compressor.remove_hessian_hooks() + compressor.quant_with_default_qconfig() + + print('start evaluation') + disable_observer_linear(model) + model = model.cuda() + acc = eval(model, test_loader, is_half=args.fp16) + print('accuracy:', acc.item()) diff --git a/projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py b/projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py new file mode 100644 index 000000000..0e6658a6f --- /dev/null +++ b/projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# model settings +import os.path as osp + +import torch +import torch.nn as nn +import torchvision +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from mmrazor.implementations.pruning import sparse_gpt + + +def get_dataloaders(batch_size, n_workers, path=''): + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + train_dataset = datasets.ImageFolder( + osp.join(path, 'train'), + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ]), + ) + + test_dataset = datasets.ImageFolder( + osp.join(path, 'val'), + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]), + ) + + dataloader_train = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=n_workers, + pin_memory=True, + ) + dataloader_test = torch.utils.data.DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=n_workers, + pin_memory=True, + ) + return dataloader_train, dataloader_test + + +@torch.no_grad() +def eval(model: nn.Module, + dataloader_test: DataLoader, + device=torch.device('cuda:0')): + + total = 0 + correct = 0 + + model.eval() + with torch.no_grad(): + for x, y in dataloader_test: + x: torch.Tensor # type: ignore + y: torch.Tensor # type: ignore + x = x.to(device) + outputs = model(x) + _, predicted = outputs.max(1) + y = y.to(device) + correct += (y == predicted).long().sum() + total += y.numel() + acc = correct / total + return acc + + +@torch.no_grad() +def infer(model: nn.Module, + dataloader: torch.utils.data.DataLoader, + num_samples=256, + device=torch.device('cuda:0')): + model.eval() + with torch.no_grad(): + accumulate_batch = 0 + for x, _ in dataloader: + x = x.to(device) + model(x) + B = x.shape[0] + accumulate_batch += B + if accumulate_batch > num_samples: + break + + +if __name__ == '__main__': + import argparse + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument( + '--data', + type=str, + default='data/imagenet_torch', + help='path to imagenet in torch folder format') + arg_parser.add_argument( + '--num_samples', + type=int, + default=512, + help='number of samples to estimate hessian 
matrix') + arg_parser.add_argument( + '--batch_size', + type=int, + default=128, + help='batch size for evaluation and inference') + args = arg_parser.parse_args() + + data_path = args.data + num_samples = args.num_samples + batch_size = args.batch_size + + model = torchvision.models.resnet18(pretrained=True) + train_loader, test_loader = get_dataloaders(batch_size, 4, data_path) + + compressor = sparse_gpt.SparseGptCompressor() + compressor.prepare(model) + + model.cuda() + + compressor.init_hessian() + compressor.register_hessian_hooks() + infer(model, test_loader, num_samples=num_samples) + compressor.remove_hessian_hooks() + compressor.prune_24() + model = compressor.to_static_model(model) + + print('start evaluation') + model = model.cuda() + acc = eval(model, test_loader) + print('accuracy:', acc.item()) diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/README.md b/projects/mmrazor_large/examples/language_models/LLaMA/README.md new file mode 100644 index 000000000..7d9862de8 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/LLaMA/README.md @@ -0,0 +1,55 @@ +# Examples for LLaMA + +## SparseGPT + +For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md) + +### Usage + +```shell +# example for decapoda-research/llama-7b-hf +python projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py decapoda-research/llama-7b-hf c4 + +# help +usage: llama_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4} + +positional arguments: + model Llama model to load + {wikitext2,ptb,c4} Where to extract calibration data from. + +optional arguments: + -h, --help show this help message and exit + --seed SEED Seed for sampling the calibration data. + --nsamples NSAMPLES Number of calibration data samples. + --batch_size BATCH_SIZE + Batchsize for calibration and evaluation. + --save SAVE Path to saved model. + -m M Whether to enable memory efficient forward +``` + +## GPTQ + +For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md) + +### Usage + +```shell +# example for decapoda-research/llama-7b-hf +python projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py decapoda-research/llama-7b-hf c4 + +# help +usage: llama_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4} + +positional arguments: + model Llama model to load + {wikitext2,ptb,c4} Where to extract calibration data from. + +optional arguments: + -h, --help show this help message and exit + --seed SEED Seed for sampling the calibration data. + --nsamples NSAMPLES Number of calibration data samples. + --batch_size BATCH_SIZE + Batchsize for calibration and evaluation. + --save SAVE Path to saved model. 
+ -m M Whether to enable memory efficient forward +``` diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/datautils.py b/projects/mmrazor_large/examples/language_models/LLaMA/datautils.py new file mode 100755 index 000000000..04697d560 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/LLaMA/datautils.py @@ -0,0 +1,152 @@ +import numpy as np +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import DistributedSampler + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_wikitext2(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt') + testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, + split='train') + valdata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): + if 'wikitext2' in name: + 
return get_wikitext2(nsamples, seed, seqlen, model) + if 'ptb' in name: + return get_ptb(nsamples, seed, seqlen, model) + if 'c4' in name: + return get_c4(nsamples, seed, seqlen, model) + + +def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048): + # tokens: 1 N + N = tokens.shape[1] + num_drop = N % batch_seq_len + if num_drop != 0: + tokens = tokens[:, :-num_drop] + tokens = tokens.reshape([-1, batch_seq_len]) # B N + return tokens + + +class LanguageDataset(TorchDataset): + + def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None: + super().__init__() + # seq: 1, N + self.seq_len = seq_len + + self.seq = fold_tokens(seq) # B N + + def __len__(self) -> int: + return self.seq.shape[0] + + def __getitem__(self, index): + return self.seq[index] + + +def build_language_loader(testloader, world_size, rank, model, batch_size=128): + val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen) + distributed_sampler = DistributedSampler( + val_dataset, num_replicas=world_size, rank=rank, shuffle=False) + batch_size = min(len(val_dataset) // world_size, batch_size) + val_dataloader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + drop_last=True, + sampler=distributed_sampler) + return val_dataloader diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py b/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py new file mode 100644 index 000000000..0eae9b4f0 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py @@ -0,0 +1,162 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from datautils import get_loaders +from transformers.models.llama import LlamaForCausalLM +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from utils import opt_eval, opt_infer + +from mmrazor.implementations.pruning.sparse_gpt.utils import \ + memory_efficient_forward +from mmrazor.implementations.quantization.gptq import (GPTQLinear, + TritonGPTQLinear) +from mmrazor.utils import print_log + + +def enable_observer_linear(model): + print_log('Enable updating qparams for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, GPTQLinear): + module.fix_qparams = False + + +def disable_observer_linear(model): + print_log('Disable updating qparams for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, GPTQLinear): + module.fix_qparams = True + + +def del_redundant_attr(model): + print_log('Del redundant weight for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, TritonGPTQLinear): + del module.weight + + +def get_model(model): + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained( + model, + torch_dtype='auto', + ) + model.seqlen = 2048 + return model + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument('model', type=str, help='Llama model to load') + parser.add_argument( + '--dataset', + type=str, + choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Seed for sampling the calibration data.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples.') + parser.add_argument( + '--batch_size', + 
type=int, + default=16, + help='Batchsize for calibration and evaluation.') + parser.add_argument( + '--save', type=str, default='', help='Path to saved model.') + parser.add_argument( + '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.') + parser.add_argument( + '--dev', type=str, default='cuda:0', help='Use which device.') + parser.add_argument( + '-m', + type=bool, + default=False, + help='Whether to enable memory efficient forward') + + args = parser.parse_args() + + DEV = args.dev + + model = get_model(args.model) + model.to(DEV) + model.eval() + print_log('load model over') + + from mmrazor.implementations.quantization import gptq + compressor = gptq.GPTQCompressor() + # use_triton_ops is True + compressor.prepare( + model.model.layers, + quant_conv=True, + use_triton_ops=True, + quant_linear=True, + bits=4, + groupsize=128) + + # # quantize activation for linear + # # a_qconfig = dict(bits=4, perchannel=False, sym=False) + # compressor.prepare( + # model.model.layers, + # quant_conv=True, + # quant_linear=True, + # use_triton_ops=False, + # # a_qconfig=a_qconfig + # ) + + if args.quant_ckpt: + del_redundant_attr(model) + model.load_state_dict(torch.load(args.quant_ckpt)) + else: + dataloader, testloader = get_loaders( + args.dataset, + seed=args.seed, + model=args.model, + seqlen=model.seqlen) + print_log('load data for infer over') + + compressor.init_hessian() + enable_observer_linear(model) + with memory_efficient_forward( + model, + wrap_modules=[LlamaDecoderLayer], + enabled=args.m, + device=DEV): + compressor.register_hessian_hooks() + opt_infer( + model, + testloader, + DEV, + batch_size=args.batch_size, + num_samples=args.nsamples) + compressor.remove_hessian_hooks() + compressor.quant_with_default_qconfig(device=DEV) + + disable_observer_linear(model) + with memory_efficient_forward( + model, wrap_modules=[LlamaDecoderLayer], enabled=args.m, + device=DEV): + + # for dataset in ['wikitext2', 'ptb', 'c4']: + for dataset in ['wikitext2']: + dataloader, testloader = get_loaders( + dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) + print_log(dataset) + opt_eval(model, testloader, DEV, batch_size=args.batch_size) + + if args.save and not args.quant_ckpt: + print_log(f'save model in {args.save}') + torch.save(model.state_dict(), args.save) diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py new file mode 100644 index 000000000..972feff2f --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
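+# SparseGPT example for LLaMA: load a HuggingFace LLaMA checkpoint, collect
+# Hessian statistics on a small calibration set, apply 2:4 structured pruning
+# to the decoder layers, and report perplexity on wikitext2/ptb/c4.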
+import torch +from datautils import get_loaders +from transformers.models.llama import LlamaForCausalLM +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from utils import opt_eval, opt_infer + +from mmrazor.implementations.pruning.sparse_gpt.utils import \ + memory_efficient_forward +from mmrazor.utils import print_log + + +def get_model(model): + import torch + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained( + model, + torch_dtype='auto', + ) + model.seqlen = 2048 + return model + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument('model', type=str, help='Llama model to load') + parser.add_argument( + 'dataset', + type=str, + choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Seed for sampling the calibration data.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples.') + parser.add_argument( + '--batch_size', + type=int, + default=16, + help='Batchsize for calibration and evaluation.') + parser.add_argument( + '--save', type=str, default='', help='Path to saved model.') + parser.add_argument( + '-m', + type=bool, + default=False, + help='Whether to enable memory efficient forward') + + args = parser.parse_args() + + torch.set_default_dtype(torch.half) + DEV = torch.device('cuda:0') + + model = get_model(args.model) + model.eval() + print_log('load model over') + + dataloader, testloader = get_loaders( + args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) + print_log('load data for infer over') + + from mmrazor.implementations.pruning import sparse_gpt + compressor = sparse_gpt.SparseGptCompressor() + compressor.prepare(model.model.layers) + + compressor.init_hessian() + with memory_efficient_forward( + model, wrap_modules=[LlamaDecoderLayer], enabled=args.m): + compressor.register_hessian_hooks() + opt_infer( + model, + testloader, + DEV, + batch_size=args.batch_size, + num_samples=args.nsamples) + compressor.remove_hessian_hooks() + compressor.prune_24() + + model = compressor.to_static_model(model) + if args.save: + print_log(f'save model in {args.save}') + model.save_pretrained(args.save) + + with memory_efficient_forward( + model, wrap_modules=[LlamaDecoderLayer], enabled=args.m): + + for dataset in ['wikitext2', 'ptb', 'c4']: + dataloader, testloader = get_loaders( + dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) + print_log(dataset) + opt_eval(model, testloader, DEV, batch_size=args.batch_size) diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt_fsdp.py b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt_fsdp.py new file mode 100644 index 000000000..14d40172b --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt_fsdp.py @@ -0,0 +1,198 @@ +import functools +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from datautils import build_language_loader, get_loaders +from llama_sparse_gpt import get_model +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload 
+from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp + +from mmrazor.implementations.pruning import sparse_gpt +from mmrazor.utils import print_log + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12356' + + dist.init_process_group('nccl', rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + print_log(f'init {rank}/{world_size}', only_rank0=False) + + +def init_fn_wrapper(model: nn.Module, model_copy: nn.Module): + + def find_module_in_model_copy(module: nn.Module): + name2module = dict(model.named_modules()) + module2name = dict([(v, k) for k, v in name2module.items()]) + + name = module2name[module] + return dict(model_copy.named_modules())[name] + + def _materialize_meta_module(module: nn.Module, ): + + def meta_to_empty(p: torch.Tensor): + if p.device == torch.device('meta'): + return p.new_empty(p.shape, device='cpu') + else: + return p + + module._apply(meta_to_empty) + if dist.get_rank() == 0: + assert model_copy is not None + module_copy = find_module_in_model_copy(module) + + name2p = dict(module_copy.named_parameters(remove_duplicate=False)) + for n, p in module.named_parameters(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass + name2p = dict(module_copy.named_buffers(remove_duplicate=False)) + for n, p in module.named_buffers(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass + + return _materialize_meta_module + + +def main(rank, world_size=8, args=None): + setup(rank, world_size) + + model_name = args.model + batch_size = args.batch_size + + def build(): + model = get_model(model_name) + + # init compressor + compressor = sparse_gpt.SparseGptCompressor() + compressor.prepare(model.model.layers) + return model, compressor + + with init_on_meta(enable=True): + model, compressor = build() + + if rank == 0: + model_copy, _ = build() # init on cpu + else: + model_copy = None + + # init fsdp + size_based_auto_wrap_policy_x = functools.partial( + size_based_auto_wrap_policy, min_num_params=int(1e8)) + + model = FSDP( + model, + auto_wrap_policy=size_based_auto_wrap_policy_x, + cpu_offload=CPUOffload(True), + sharding_strategy=ShardingStrategy.FULL_SHARD, + device_id=rank, + param_init_fn=init_fn_wrapper(model, model_copy), + sync_module_states=True) + print_log(model) + + # init hessian + + compressor.init_hessian(device='cuda') + compressor.register_hessian_hooks() + + _, testloader = get_loaders( + args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) + testloader = build_language_loader( + testloader, world_size, rank, model, batch_size=batch_size) + opt_infer_fsdp(model, testloader) + + compressor.remove_hessian_hooks() + + # prune + name2module = dict(model.named_modules()) + module2name = {} + module2name = dict([(v, k) for k, v in name2module.items()]) + + with torch.no_grad(): + for fsdp in FSDP.fsdp_modules(model): + fsdp._reset_lazy_init() + with FSDP.summon_full_params(fsdp, recurse=False): + fsdp_name = module2name[fsdp] + for name, op in fsdp.named_modules(): + if name.count('_fsdp_wrapped_module') <= 1: + if isinstance(op, sparse_gpt.SparseGptMixIn): + try: + op.prune(0.5, prunen=2, prunem=4) + print_log( + f'prune {fsdp_name}.{name} successfully.', # noqa + only_rank0=True) + except Exception as e: + print_log( + f'prune 
{fsdp_name}.{name} failed, as {e}', # noqa + only_rank0=True) + fsdp._reset_lazy_init() + + # save + if args.save: + print_log(f'save model in {args.save}') + model._reset_lazy_init() + with FSDP.summon_full_params(model, rank0_only=True, writeback=False): + if dist.get_rank() == 0: + model.save_pretrained(args.save) + + # val + torch.cuda.empty_cache() + model._reset_lazy_init() + for dataset in ['wikitext2', 'ptb', 'c4']: + _, testloader = get_loaders( + dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) + testloader = build_language_loader( + testloader, world_size, rank, model, batch_size=batch_size) + print_log(dataset) + opt_eval_fsdp(model, testloader, torch.device('cuda')) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, help='OPT model to load; pass `facebook/opt-X`.') + parser.add_argument( + 'dataset', + type=str, + choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Seed for sampling the calibration data.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples.') + parser.add_argument( + '--batch_size', + type=int, + default=64, + help='Batchsize for calibration and evaluation.') + + parser.add_argument( + '--save', type=str, default='', help='Path to saved model.') + parser.add_argument( + '--world_size', type=int, default=1, help='Number of GPUs to use.') + args = parser.parse_args() + + WORLD_SIZE = args.world_size + mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/utils.py b/projects/mmrazor_large/examples/language_models/LLaMA/utils.py new file mode 100644 index 000000000..1f8ceb87d --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/LLaMA/utils.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
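+# Shared helpers for the LLaMA examples: perplexity evaluation (opt_eval) and
+# calibration inference (opt_infer), their FSDP-aware variants, and the
+# init_on_meta context manager for building large models on the meta device.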
+# Example for opt is converted from https://github.com/ist-daslab/sparsegpt +import torch +import torch.nn as nn +from torch import distributed as dist +from torch.utils.data import DataLoader +from transformers import OPTForCausalLM + +from mmrazor.utils import print_log + + +def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048): + # tokens: 1 N + N = tokens.shape[1] + num_drop = N % batch_seq_len + if num_drop != 0: + tokens = tokens[:, :-num_drop] + tokens = tokens.reshape([-1, batch_seq_len]) # B N + return tokens + + +@torch.no_grad() +def opt_eval(model: OPTForCausalLM, + testenc, + dev=torch.device('cuda:0'), + batch_size=16): + print_log('Evaluating ...') + + seqlen = model.seqlen + + testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N + testenc = fold_tokens(testenc, seqlen) # B N + + use_cache = model.config.use_cache + model.config.use_cache = False + nlls = [] + + for i, batch in enumerate(torch.split(testenc, batch_size)): + B = batch.shape[0] + + batch = batch.to(dev) + out: torch.Tensor = model(batch)[0] # 1 + + shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C + shift_labels = batch[:, 1:].flatten() # (B N) + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + neg_log_likelihood = loss.float() * seqlen * B + nlls.append(neg_log_likelihood) + + print_log(f'{(i+1)*batch_size} / {len(testenc)}') + + ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel())) + print_log(f'Perplexity: {ppl.item():3f}') + model.config.use_cache = use_cache + + +@torch.no_grad() +def opt_infer( + model: OPTForCausalLM, + testenc, + dev, + batch_size=16, + num_samples=128, +): + print_log('Infer ...') + + seqlen = model.seqlen + + testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N + testenc = fold_tokens(testenc, seqlen) # B N + + use_cache = model.config.use_cache + model.config.use_cache = False + + for i, batch in enumerate(torch.split(testenc, batch_size)): + batch = batch.to(dev) + _ = model(batch)[0] # 1 + print_log(f'{(i+1)*batch_size} / {num_samples}') + + if (i + 1) * batch_size >= num_samples: + break + model.config.use_cache = use_cache + + +class init_on_meta: + + def __init__(self, enable=True) -> None: + self.enable = enable + self.default_device = torch.ones([]).device + + def __enter__(self): + if self.enable: + torch.set_default_device('meta') + + def __exit__(self, exc_type, exc_value, traceback): + if self.enable: + torch.set_default_device(self.default_device) + + +@torch.no_grad() +def opt_eval_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), +): + print_log('Evaluating ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + loss_sum = torch.zeros([1], device=dev) + total_seq_len = torch.zeros([1], device=dev, dtype=torch.long) + + for i, batch in enumerate(dataloader): + B, seq_len = batch.shape[:2] + + batch = batch.to(dev) + out: torch.Tensor = model(batch)[0] # 1 + + shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C + shift_labels = batch[:, 1:].flatten() # (B N) + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + neg_log_likelihood = loss.float() * seq_len * B + total_seq_len += seq_len * B + loss_sum += neg_log_likelihood + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + + if dist.is_initialized(): + dist.all_reduce(loss_sum) + 
dist.all_reduce(total_seq_len) + + ppl = torch.exp(loss_sum / total_seq_len) + print_log(f'Perplexity: {ppl.item():3f}') + model.config.use_cache = use_cache + + +@torch.no_grad() +def opt_infer_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), + num_samples=128, +): + print_log('Infering ...') + + model.config.use_cache = False + + for i, batch in enumerate(dataloader): + B = batch.shape[0] + + batch = batch.to(dev) + model(batch)[0] # 1 + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + if infered_batch >= num_samples: + break diff --git a/projects/mmrazor_large/examples/language_models/OPT/README.md b/projects/mmrazor_large/examples/language_models/OPT/README.md new file mode 100644 index 000000000..a5d1c8030 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/OPT/README.md @@ -0,0 +1,55 @@ +# Examples for OPT + +## SparseGPT + +For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md) + +### Usage + +```shell +# example for facebook/opt-125m +python projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py facebook/opt-125m c4 + +# help +usage: opt_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4} + +positional arguments: + model OPT model to load; pass `facebook/opt-X`. + {wikitext2,ptb,c4} Where to extract calibration data from. + +optional arguments: + -h, --help show this help message and exit + --seed SEED Seed for sampling the calibration data. + --nsamples NSAMPLES Number of calibration data samples. + --batch_size BATCH_SIZE + Batchsize for calibration and evaluation. + --save SAVE Path to saved model. + -m M Whether to enable memory efficient forward +``` + +## GPTQ + +For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md) + +### Usage + +```shell +# example for facebook/opt-125m +python projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py facebook/opt-125m c4 + +# help +usage: opt_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4} + +positional arguments: + model OPT model to load; pass `facebook/opt-X`. + {wikitext2,ptb,c4} Where to extract calibration data from. + +optional arguments: + -h, --help show this help message and exit + --seed SEED Seed for sampling the calibration data. + --nsamples NSAMPLES Number of calibration data samples. + --batch_size BATCH_SIZE + Batchsize for calibration and evaluation. + --save SAVE Path to saved model. 
+ -m M Whether to enable memory efficient forward +``` diff --git a/projects/mmrazor_large/examples/language_models/OPT/datautils.py b/projects/mmrazor_large/examples/language_models/OPT/datautils.py new file mode 100755 index 000000000..04697d560 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/OPT/datautils.py @@ -0,0 +1,152 @@ +import numpy as np +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import DistributedSampler + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_wikitext2(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt') + testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, + split='train') + valdata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): + if 'wikitext2' in name: + return 
get_wikitext2(nsamples, seed, seqlen, model) + if 'ptb' in name: + return get_ptb(nsamples, seed, seqlen, model) + if 'c4' in name: + return get_c4(nsamples, seed, seqlen, model) + + +def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048): + # tokens: 1 N + N = tokens.shape[1] + num_drop = N % batch_seq_len + if num_drop != 0: + tokens = tokens[:, :-num_drop] + tokens = tokens.reshape([-1, batch_seq_len]) # B N + return tokens + + +class LanguageDataset(TorchDataset): + + def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None: + super().__init__() + # seq: 1, N + self.seq_len = seq_len + + self.seq = fold_tokens(seq) # B N + + def __len__(self) -> int: + return self.seq.shape[0] + + def __getitem__(self, index): + return self.seq[index] + + +def build_language_loader(testloader, world_size, rank, model, batch_size=128): + val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen) + distributed_sampler = DistributedSampler( + val_dataset, num_replicas=world_size, rank=rank, shuffle=False) + batch_size = min(len(val_dataset) // world_size, batch_size) + val_dataloader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + drop_last=True, + sampler=distributed_sampler) + return val_dataloader diff --git a/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py b/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py new file mode 100644 index 000000000..5cd48e563 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Example for opt is converted from https://github.com/ist-daslab/sparsegpt +import torch +from datautils import get_loaders +from transformers import OPTForCausalLM +from transformers.models.opt.modeling_opt import OPTDecoderLayer +from utils import opt_eval, opt_infer + +from mmrazor.implementations.pruning.sparse_gpt.utils import \ + memory_efficient_forward +from mmrazor.implementations.quantization.gptq import (GPTQLinear, + TritonGPTQLinear) +from mmrazor.utils import print_log + + +def enable_observer_linear(model): + print_log('Enable updating qparams for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, GPTQLinear): + module.fix_qparams = False + + +def disable_observer_linear(model): + print_log('Disable updating qparams for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, GPTQLinear): + module.fix_qparams = True + + +def del_redundant_attr(model): + print_log('Del redundant weight for GPTQLinear!') + for _, module in model.named_modules(): + if isinstance(module, TritonGPTQLinear): + del module.weight + + +def get_model(model): + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') + model.seqlen = model.config.max_position_embeddings + return model + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument('model', type=str, help='Llama model to load') + parser.add_argument( + '--dataset', + type=str, + choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Seed for sampling the calibration data.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples.') + 
parser.add_argument( + '--batch_size', + type=int, + default=16, + help='Batchsize for calibration and evaluation.') + parser.add_argument( + '--save', type=str, default='', help='Path to saved model.') + parser.add_argument( + '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.') + parser.add_argument( + '--dev', type=str, default='cuda:0', help='Use which device.') + parser.add_argument( + '-m', + type=bool, + default=False, + help='Whether to enable memory efficient forward') + + args = parser.parse_args() + + DEV = args.dev + + model = get_model(args.model) + model.to(DEV) + model.eval() + print_log('load model over') + + from mmrazor.implementations.quantization import gptq + compressor = gptq.GPTQCompressor() + # use_triton_ops is True + compressor.prepare( + model.model.layers, + quant_conv=True, + use_triton_ops=True, + quant_linear=True, + bits=4, + groupsize=128) + + # # # quantize activation for linear + # # a_qconfig = dict(bits=4, perchannel=False, sym=False) + # compressor.prepare( + # model.model.decoder, + # quant_conv=True, + # quant_linear=True, + # use_triton_ops=False, + # # a_qconfig=a_qconfig + # ) + + if args.quant_ckpt: + del_redundant_attr(model) + model.load_state_dict(torch.load(args.quant_ckpt)) + else: + dataloader, testloader = get_loaders( + args.dataset, + seed=args.seed, + model=args.model, + seqlen=model.seqlen) + print_log('load data for infer over') + + compressor.init_hessian() + enable_observer_linear(model) + with memory_efficient_forward( + model, wrap_modules=[OPTDecoderLayer], enabled=args.m, + device=DEV): + compressor.register_hessian_hooks() + opt_infer( + model, + testloader, + DEV, + batch_size=args.batch_size, + num_samples=args.nsamples) + compressor.remove_hessian_hooks() + compressor.quant_with_default_qconfig(device=DEV) + + disable_observer_linear(model) + with memory_efficient_forward( + model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV): + + # for dataset in ['wikitext2', 'ptb', 'c4']: + for dataset in ['wikitext2']: + dataloader, testloader = get_loaders( + dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) + print_log(dataset) + opt_eval(model, testloader, DEV, batch_size=args.batch_size) + + if args.save and not args.quant_ckpt: + print_log(f'save model in {args.save}') + torch.save(model.state_dict(), args.save) diff --git a/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py new file mode 100644 index 000000000..29d0947d3 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
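+# SparseGPT example for OPT: load `facebook/opt-X`, collect Hessian statistics
+# on a calibration split, apply 2:4 structured pruning to the decoder, and
+# report perplexity on wikitext2/ptb/c4.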
+# Example for opt is converted from https://github.com/ist-daslab/sparsegpt +import torch +from datautils import get_loaders +from transformers import OPTForCausalLM +from transformers.models.opt.modeling_opt import OPTDecoderLayer +from utils import opt_eval, opt_infer + +from mmrazor.implementations.pruning.sparse_gpt.utils import \ + memory_efficient_forward +from mmrazor.utils import print_log + + +def get_model(model): + import torch + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') + model.seqlen = model.config.max_position_embeddings + return model + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, help='OPT model to load; pass `facebook/opt-X`.') + parser.add_argument( + 'dataset', + type=str, + choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Seed for sampling the calibration data.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples.') + parser.add_argument( + '--batch_size', + type=int, + default=64, + help='Batchsize for calibration and evaluation.') + parser.add_argument( + '--save', type=str, default='', help='Path to saved model.') + parser.add_argument( + '-m', + type=bool, + default=False, + help='Whether to enable memory efficient forward') + + args = parser.parse_args() + + DEV = torch.device('cuda:0') + + model = get_model(args.model) + model.eval() + print_log('load model over') + + dataloader, testloader = get_loaders( + args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) + print_log('load data for infer over') + + from mmrazor.implementations.pruning import sparse_gpt + compressor = sparse_gpt.SparseGptCompressor() + compressor.prepare(model.model.decoder) + + compressor.init_hessian() + with memory_efficient_forward( + model, wrap_modules=[OPTDecoderLayer], enabled=args.m): + + compressor.register_hessian_hooks() + opt_infer( + model, + testloader, + DEV, + batch_size=args.batch_size, + num_samples=args.nsamples) + compressor.remove_hessian_hooks() + compressor.prune_24() + + model = compressor.to_static_model(model) + if args.save: + print_log(f'save model in {args.save}') + model.save_pretrained(args.save) + + with memory_efficient_forward( + model, wrap_modules=[OPTDecoderLayer], enabled=args.m): + + for dataset in ['wikitext2', 'ptb', 'c4']: + dataloader, testloader = get_loaders( + dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) + print_log(dataset) + opt_eval(model, testloader, DEV, batch_size=args.batch_size) diff --git a/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt_fsdp.py b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt_fsdp.py new file mode 100644 index 000000000..e357be01a --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt_fsdp.py @@ -0,0 +1,198 @@ +import functools +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from datautils import build_language_loader, get_loaders +from opt_sparse_gpt import get_model +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ShardingStrategy +from 
torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp + +from mmrazor.implementations.pruning import sparse_gpt +from mmrazor.utils import print_log + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12356' + + dist.init_process_group('nccl', rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + print_log(f'init {rank}/{world_size}', only_rank0=False) + + +def init_fn_wrapper(model: nn.Module, model_copy: nn.Module): + + def find_module_in_model_copy(module: nn.Module): + name2module = dict(model.named_modules()) + module2name = dict([(v, k) for k, v in name2module.items()]) + + name = module2name[module] + return dict(model_copy.named_modules())[name] + + def _materialize_meta_module(module: nn.Module, ): + + def meta_to_empty(p: torch.Tensor): + if p.device == torch.device('meta'): + return p.new_empty(p.shape, device='cpu') + else: + return p + + module._apply(meta_to_empty) + if dist.get_rank() == 0: + assert model_copy is not None + module_copy = find_module_in_model_copy(module) + + name2p = dict(module_copy.named_parameters(remove_duplicate=False)) + for n, p in module.named_parameters(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass + name2p = dict(module_copy.named_buffers(remove_duplicate=False)) + for n, p in module.named_buffers(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass + + return _materialize_meta_module + + +def main(rank, world_size=8, args=None): + setup(rank, world_size) + + model_name = args.model + batch_size = args.batch_size + + def build(): + model = get_model(model_name) + + # init mutator + mutator = sparse_gpt.SparseGptCompressor() + mutator.prepare(model.model.decoder) + return model, mutator + + with init_on_meta(enable=True): + model, mutator = build() + + if rank == 0: + model_copy, _ = build() # init on cpu + else: + model_copy = None + + # init fsdp + size_based_auto_wrap_policy_x = functools.partial( + size_based_auto_wrap_policy, min_num_params=int(1e8)) + + model = FSDP( + model, + auto_wrap_policy=size_based_auto_wrap_policy_x, + cpu_offload=CPUOffload(True), + sharding_strategy=ShardingStrategy.FULL_SHARD, + device_id=rank, + param_init_fn=init_fn_wrapper(model, model_copy), + sync_module_states=True) + print_log(model) + + # init hessian + + mutator.init_hessian(device='cuda') + mutator.register_hessian_hooks(model) + + _, testloader = get_loaders( + args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) + testloader = build_language_loader( + testloader, world_size, rank, model, batch_size=batch_size) + opt_infer_fsdp(model, testloader) + + mutator.remove_hessian_hooks() + + # prune + name2module = dict(model.named_modules()) + module2name = {} + module2name = dict([(v, k) for k, v in name2module.items()]) + + with torch.no_grad(): + for fsdp in FSDP.fsdp_modules(model): + fsdp._reset_lazy_init() + with FSDP.summon_full_params(fsdp, recurse=False): + fsdp_name = module2name[fsdp] + for name, op in fsdp.named_modules(): + if name.count('_fsdp_wrapped_module') <= 1: + if isinstance(op, sparse_gpt.SparseGptMixIn): + try: + op.prune(0.5, prunen=2, prunem=4) + print_log( + f'prune {fsdp_name}.{name} successfully.', # noqa + only_rank0=True) + except 
Exception as e: + print_log( + f'prune {fsdp_name}.{name} failed, as {e}', # noqa + only_rank0=True) + fsdp._reset_lazy_init() + + # save + if args.save: + print_log(f'save model in {args.save}') + model._reset_lazy_init() + with FSDP.summon_full_params(model, rank0_only=True, writeback=False): + if dist.get_rank() == 0: + model.save_pretrained(args.save) + + # val + torch.cuda.empty_cache() + model._reset_lazy_init() + for dataset in ['wikitext2', 'ptb', 'c4']: + _, testloader = get_loaders( + dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) + testloader = build_language_loader( + testloader, world_size, rank, model, batch_size=batch_size) + print_log(dataset) + opt_eval_fsdp(model, testloader, torch.device('cuda')) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, help='OPT model to load; pass `facebook/opt-X`.') + parser.add_argument( + 'dataset', + type=str, + choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Seed for sampling the calibration data.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples.') + parser.add_argument( + '--batch_size', + type=int, + default=64, + help='Batchsize for calibration and evaluation.') + + parser.add_argument( + '--save', type=str, default='', help='Path to saved model.') + parser.add_argument( + '--world_size', type=int, default=1, help='Number of GPUs to use.') + args = parser.parse_args() + + WORLD_SIZE = args.world_size + mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/projects/mmrazor_large/examples/language_models/OPT/utils.py b/projects/mmrazor_large/examples/language_models/OPT/utils.py new file mode 100644 index 000000000..a728a2268 --- /dev/null +++ b/projects/mmrazor_large/examples/language_models/OPT/utils.py @@ -0,0 +1,171 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
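+# Shared helpers for the OPT examples: perplexity evaluation (opt_eval) and
+# calibration inference (opt_infer), their FSDP-aware variants, and the
+# init_on_meta context manager for building large models on the meta device.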
+# Example for opt is converted from https://github.com/ist-daslab/sparsegpt +import torch +import torch.nn as nn +from torch import distributed as dist +from torch.utils.data import DataLoader +from transformers import OPTForCausalLM + +from mmrazor.utils import print_log + + +def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048): + # tokens: 1 N + N = tokens.shape[1] + num_drop = N % batch_seq_len + if num_drop != 0: + tokens = tokens[:, :-num_drop] + tokens = tokens.reshape([-1, batch_seq_len]) # B N + return tokens + + +@torch.no_grad() +def opt_eval(model: OPTForCausalLM, + testenc, + dev=torch.device('cuda:0'), + batch_size=16): + print_log('Evaluating ...') + + seqlen = model.seqlen + + testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N + testenc = fold_tokens(testenc, seqlen) # B N + + use_cache = model.config.use_cache + model.config.use_cache = False + nlls = [] + + for i, batch in enumerate(torch.split(testenc, batch_size)): + B = batch.shape[0] + + batch = batch.to(dev) + out: torch.Tensor = model(batch)[0] # 1 + + shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C + shift_labels = batch[:, 1:].flatten() # (B N) + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + neg_log_likelihood = loss.float() * seqlen * B + nlls.append(neg_log_likelihood) + + print_log(f'{(i+1)*batch_size} / {len(testenc)}') + + ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel())) + print_log(f'Perplexity: {ppl.item():3f}') + model.config.use_cache = use_cache + + +@torch.no_grad() +def opt_infer( + model: OPTForCausalLM, + testenc, + dev, + batch_size=16, + num_samples=128, +): + print_log('Infer ...') + + seqlen = model.seqlen + + testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N + testenc = fold_tokens(testenc, seqlen) # B N + + model.config.use_cache = False + + for i, batch in enumerate(torch.split(testenc, batch_size)): + batch = batch.to(dev) + _ = model(batch)[0] # 1 + print_log(f'{(i+1)*batch_size} / {num_samples}') + + if (i + 1) * batch_size >= num_samples: + break + + +class init_on_meta: + + def __init__(self, enable=True) -> None: + self.enable = enable + self.default_device = torch.ones([]).device + + def __enter__(self): + if self.enable: + torch.set_default_device('meta') + + def __exit__(self, exc_type, exc_value, traceback): + if self.enable: + torch.set_default_device(self.default_device) + + +@torch.no_grad() +def opt_eval_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), +): + print_log('Evaluating ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + loss_sum = torch.zeros([1], device=dev) + total_seq_len = torch.zeros([1], device=dev, dtype=torch.long) + + for i, batch in enumerate(dataloader): + B, seq_len = batch.shape[:2] + + batch = batch.to(dev) + out: torch.Tensor = model(batch)[0] # 1 + + shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C + shift_labels = batch[:, 1:].flatten() # (B N) + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + neg_log_likelihood = loss.float() * seq_len * B + total_seq_len += seq_len * B + loss_sum += neg_log_likelihood + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + + if dist.is_initialized(): + dist.all_reduce(loss_sum) + dist.all_reduce(total_seq_len) + + ppl = torch.exp(loss_sum / total_seq_len) + 
print_log(f'Perplexity: {ppl.item():3f}') + model.config.use_cache = use_cache + + +@torch.no_grad() +def opt_infer_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), + num_samples=128, +): + print_log('Infering ...') + + model.config.use_cache = False + + for i, batch in enumerate(dataloader): + B = batch.shape[0] + + batch = batch.to(dev) + model(batch)[0] # 1 + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + if infered_batch >= num_samples: + break diff --git a/requirements/tests.txt b/requirements/tests.txt index 5980dc303..b025f5a67 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -7,5 +7,6 @@ nbformat numpy < 1.24.0 # A temporary solution for tests with mmdet. onnx pytest +triton==2.0.0 xdoctest >= 0.10.0 yapf diff --git a/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py new file mode 100644 index 000000000..636d64f67 --- /dev/null +++ b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch +import torch.nn as nn + +from mmrazor import digit_version +from mmrazor.implementations.pruning import sparse_gpt + + +class TestSparseGptOps(unittest.TestCase): + + @torch.no_grad() + def test_op(self): + if digit_version(torch.__version__) < digit_version('1.12.0'): + self.skipTest('torch<1.12.0') + + def get_loss(linear, linear1, data): + y = linear(data) + y1 = linear1(data) + return (y - y1).square().sum() + + def infer(model, dataset): + for x in dataset: + model(x) + + for device in ['cpu']: + device = torch.device(device) + + # prepare + + linear = nn.Linear(12, 20, bias=False).to(device) + sparse_linear = sparse_gpt.SparseGptLinear( + 12, 20, bias=False).to(device) + sparse_linear.load_state_dict(linear.state_dict(), strict=False) + + random_data = torch.rand([10, 5, 12]).to( + device) # [loader_batch,batch,feature] + data_0 = random_data[0] + + self.assertTrue(get_loss(linear, sparse_linear, data_0) == 0) + + # prune + + sparse_linear.init_hessian() + sparse_linear.register_hessian_hook() + infer(sparse_linear, random_data) + sparse_linear.remove_hessian_hook() + + sparse_linear.prune(0.5) + + # compare + + print('norm:', linear(data_0).norm(2)) + print('distance:', get_loss(linear, sparse_linear, data_0)) + + @torch.no_grad() + def test_model(self): + if digit_version(torch.__version__) < digit_version('1.12.0'): + self.skipTest('torch<1.12.0') + import torchvision + model = torchvision.models.resnet18() + + mutator = sparse_gpt.SparseGptCompressor() + mutator.prepare(model) + + x = torch.rand(10, 3, 224, 224) + mutator.init_hessian() + mutator.register_hessian_hooks() + model(x) + mutator.remove_hessian_hooks() + mutator.prune_24() + + model = mutator.to_static_model(model) + assert type(model.conv1) is nn.Conv2d diff --git a/tests/test_impl/test_quantization/test_gptq/test_op_gptq.py b/tests/test_impl/test_quantization/test_gptq/test_op_gptq.py new file mode 100644 index 000000000..4928e0f17 --- /dev/null +++ b/tests/test_impl/test_quantization/test_gptq/test_op_gptq.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
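+# Unit tests for the GPTQ ops: a GPTQLinear initialised from an nn.Linear
+# should match it exactly before quantization, and GPTQCompressor should be
+# able to quantize a torchvision ResNet-18 and convert it back to a static model.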
+import unittest + +import torch +import torch.nn as nn + +from mmrazor import digit_version +from mmrazor.implementations.quantization import gptq + + +class TestGPTQOps(unittest.TestCase): + + @torch.no_grad() + def test_op(self): + if digit_version(torch.__version__) < digit_version( + '1.12.0') or not torch.cuda.is_available(): + self.skipTest('torch<1.12.0') + + def get_loss(linear, linear1, data): + y = linear(data) + y1 = linear1(data) + return (y - y1).square().sum() + + def infer(model, dataset): + for x in dataset: + model(x) + + for device in ['cpu']: + device = torch.device(device) + + # prepare + + linear = nn.Linear(12, 20, bias=False).to(device) + gptq_linear = gptq.GPTQLinear( + in_features=12, out_features=20, bias=False).to(device) + gptq_linear.load_state_dict(linear.state_dict(), strict=False) + + random_data = torch.rand([10, 5, 12]).to( + device) # [loader_batch,batch,feature] + data_0 = random_data[0] + + self.assertTrue(get_loss(linear, gptq_linear, data_0) == 0) + + # quant + + gptq_linear.init_hessian() + gptq_linear.register_hessian_hook() + infer(gptq_linear, random_data) + gptq_linear.remove_hessian_hook() + + qconfig = dict(bits=4, perchannel=True, sym=False) + quantizer = gptq.Quantizer() + quantizer.configure(**qconfig) + gptq_linear.quant(quantizer=quantizer) + + # compare + + print('norm:', linear(data_0).norm(2)) + print('distance:', get_loss(linear, gptq_linear, data_0)) + + @torch.no_grad() + def test_model(self): + if digit_version(torch.__version__) < digit_version( + '1.12.0') or not torch.cuda.is_available(): + self.skipTest('torch<1.12.0') + import torchvision + model = torchvision.models.resnet18() + + compressor = gptq.GPTQCompressor() + compressor.prepare(model, use_triton_ops=False) + + x = torch.rand(10, 3, 224, 224) + compressor.init_hessian() + compressor.register_hessian_hooks() + model(x) + compressor.remove_hessian_hooks() + compressor.quant_with_default_qconfig() + + model = compressor.to_static_model(model) + assert type(model.conv1) is nn.Conv2d
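The SparseGPT paths above all target 2:4 structured sparsity (`prune_24`, or `op.prune(0.5, prunen=2, prunem=4)` under FSDP). As a minimal sanity-check sketch, assuming the pattern is enforced in groups of four consecutive weights along the input dimension, the snippet below verifies that a pruned weight matrix keeps at most two non-zero values per group. It uses plain PyTorch only; `check_24_sparsity` is an illustrative helper, not an mmrazor API.

```python
import torch


def check_24_sparsity(weight: torch.Tensor, m: int = 4, n: int = 2) -> bool:
    """Check that every group of `m` consecutive input weights has at most `n` non-zeros."""
    out_features, in_features = weight.shape
    assert in_features % m == 0, 'in_features must be divisible by the group size'
    groups = weight.reshape(out_features, in_features // m, m)
    nonzeros_per_group = (groups != 0).sum(dim=-1)
    return bool((nonzeros_per_group <= n).all())


# Build a toy weight and prune it to 2:4 by keeping the two largest-magnitude
# values in every group of four input weights, then verify the pattern.
w = torch.randn(8, 16)
mask = torch.zeros_like(w)
keep = w.abs().reshape(8, 4, 4).topk(2, dim=-1).indices
mask.reshape(8, 4, 4).scatter_(-1, keep, 1.0)
print(check_24_sparsity(w * mask))  # expected: True
```

After calling `prune(0.5, prunen=2, prunem=4)` on a `SparseGptLinear`, its weight should pass the same check.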