Skip to content

Commit

Permalink
[Feature] Merge dev-large into main (#543)
Browse files Browse the repository at this point in the history
* add sparse gpt (#499)

init

Co-authored-by: liukai <your_email@abc.example>

* enhence sparsegpt (#505)

* update

* fix bug

* fix bug

* update opt

* add memory efficient forward for opt

* support to set device for pruning

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>

* Lk large (#510)

* update

* update

---------

Co-authored-by: liukai <your_email@abc.example>

* refine sparse gpt, support multiple gpus with fsdp (#520)

* add mmrazor large

* update readme

* add fsdp for opt

* update

* update

* rename

* update args

* support fsdp

* refine

* refine

* refine

* refine

* fix out of memorry bug

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>

* refine sparse gpt (#526)

* save cpu memory

* update

* update

* update

* update

* refine

* update

* update

---------

Co-authored-by: Your Name <you@example.com>

* merge main (#527)

* fix bug for autoslim (#511)

* fix bug for autoslim

* delete resnet50 for dmcp

---------

Co-authored-by: liukai <your_email@abc.example>

* Add timm (#512)

* add timm to optional.txt

* fix deit paths

* [Feature] Add MMRazor quantization (#513)

* [FEATURE] add quant algo `Learned Step Size Quantization` (#346)

* update

* Fix a bug in make_divisible. (#333)

fix bug in make_divisible

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Fix counter mapping bug (#331)

* fix counter mapping bug

* move judgment into get_counter_type & update UT

* [Docs]Add MMYOLO projects link (#334)

* [Doc] fix typos in en/usr_guides (#299)

* Update README.md

* Update README_zh-CN.md

Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>

* [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320)

* support MethodInputsRecorder and FunctionInputsRecorder

* fix bugs that the model can not be pickled

* WIP: add pytest for ema model

* fix bugs in recorder and delivery when ema_hook is used

* don't register the DummyDataset

* fix pytest

* updated

* retina loss & predict & tesnor DONE

* [Feature] Add deit-base (#332)

* WIP: support deit

* WIP: add deithead

* WIP: fix checkpoint hook

* fix data preprocessor

* fix cfg

* WIP: add readme

* reset single_teacher_distill

* add metafile

* add model to model-index

* fix configs and readme

* [Feature]Feature map visualization (#293)

* WIP: vis

* WIP: add visualization

* WIP: add visualization hook

* WIP: support razor visualizer

* WIP

* WIP: wrap draw_featmap

* support feature map visualization

* add a demo image for visualization

* fix typos

* change eps to 1e-6

* add pytest for visualization

* fix vis hook

* fix arguments' name

* fix img path

* support draw inference results

* add visualization doc

* fix figure url

* move files

Co-authored-by: weihan cao <HIT-cwh>

* [Feature] Add kd examples (#305)

* support kd for mbv2 and shufflenetv2

* WIP: fix ckpt path

* WIP: fix kd r34-r18

* add metafile

* fix metafile

* delete

* [Doc] add documents about pruning. (#313)

* init

* update user guide

* update images

* update

* update How to prune your model

* update how_to_use_config_tool_of_pruning.md

* update doc

* move location

* update

* update

* update

* add mutablechannels.md

* add references

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>

* [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304)

* add pkd

* add pytest for pkd

* fix cfg

* WIP: support fcos3d

* WIP: support fcos3d pkd

* support mmdet3d

* fix cfgs

* change eps to 1e-6 and add some comments

* fix docstring

* fix cfg

* add assert

* add type hint

* WIP: add readme and metafile

* fix readme

* update metafiles and readme

* fix metafile

* fix pipeline figure

* for RFC

* Customed FX initialize

* add UT init

* [Refactor] Refactor Mutables and Mutators (#324)

* refactor mutables

* update load fix subnet

* add DumpChosen Typehint

* adapt UTs

* fix lint

* Add GroupMixin to ChannelMutator (temporarily)

* fix type hints

* add GroupMixin doc-string

* modified by comments

* fix type hits

* update subnet format

* fix channel group bugs and add UTs

* fix doc string

* fix comments

* refactor diff module forward

* fix error in channel mutator doc

* fix comments

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Update readme (#341)

* update kl readme

* update dsnas readme

* fix url

* Bump version to 1.0.0rc1 (#338)

update version

* init demo

* add customer_tracer

* add quantizer

* add fake_quant, loop, config

* remove CPatcher in custome_tracer

* demo_try

* init version

* modified base.py

* pre-rebase

* wip of adaround series

* adaround experiment

* trasfer to s2

* update api

* point at sub_reconstruction

* pre-checkout

* export onnx

* add customtracer

* fix lint

* move custom tracer

* fix import

* TDO: UTs

* Successfully RUN

* update loop

* update loop docstrings

* update quantizer docstrings

* update qscheme docstrings

* update qobserver docstrings

* update tracer docstrings

* update UTs init

* update UTs init

* fix review comments

* fix CI

* fix UTs

* update torch requirements

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>

* [Features]Quantize pipeline (#350)

* init demo

* add customer_tracer

* add quantizer

* add fake_quant, loop, config

* remove CPatcher in custome_tracer

* demo_try

* init version

* modified base.py

* pre-rebase

* wip of adaround series

* adaround experiment

* trasfer to s2

* update api

* point at sub_reconstruction

* pre-checkout

* export onnx

* add customtracer

* fix lint

* move custom tracer

* fix import

* update

* updated

* retina loss & predict & tesnor DONE

* for RFC

* Customed FX initialize

* add UT init

* TDO: UTs

* Successfully RUN

* update loop

* update loop docstrings

* update quantizer docstrings

* update qscheme docstrings

* update qobserver docstrings

* update tracer docstrings

* update UTs init

* update UTs init

* fix bugs

* fix lsq

* refactor quantize pipeline

* fix quant

* WIP: debug qat

* fix lsq bugs

* fix qat, docstring in progress

* TDO: UTs

* fix bugs

* fix lsq

* refactor quantize pipeline

* fix quant

* WIP: debug qat

* fix lsq bugs

* fix qat, docstring in progress

* fixed DefaultQconfigs name

* fix bugs

* add comments and fix typos

* delete useless codes

* fix bugs and add comments

* rename prepare_module_dict

* update lsq config

Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>

* [Feature] Add `prepare_for_mmdeploy` interface  (#365)

* remove useless code

* fix build graph module import bug

* refactor general quant

* rename GeneralQuant to MMArchitectureQuant

* fix some dtype bugs

* add prepare_for_mmdeploy interface

* update prepare for mmdeploy args

* fix some comments

Co-authored-by: humu789 <humu@pjlab.org.cn>

* CodeCamp #132 add MinMaxFloorObserver (#376)

* add minmaxfloor_observer.py

* add MinMaxFloorObserver and normative docstring

* add test for MinMaxFloorObserver

* Quant go (#409)

* add torch observer

* add torch fakequant

* refactor base quantizer

* add QConfigHander and QSchemeHander & finish quantizer_refactor_beta

* passed ptq_pipeline

* tmp-commit

* fix loop and algorithm

* delete fakequant

* refactor code structure

* remove lsq

* valid ptq pipeline

* wip

* fix del functions

* fix

* fix lint and pytest

Co-authored-by: HIT-cwh <2892770585@qq.com>

* [Refactor & Doc] Refactor graph_utils and add docstring and pytest (#420)

* refactor graph_utils and add docstring and pytest

* fix del fakequant

* delete useless codes

* Merge dev-1.x into quantize (#430)

* Fix a bug in make_divisible. (#333)

fix bug in make_divisible

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Fix counter mapping bug (#331)

* fix counter mapping bug

* move judgment into get_counter_type & update UT

* [Docs]Add MMYOLO projects link (#334)

* [Doc] fix typos in en/usr_guides (#299)

* Update README.md

* Update README_zh-CN.md

Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>

* [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320)

* support MethodInputsRecorder and FunctionInputsRecorder

* fix bugs that the model can not be pickled

* WIP: add pytest for ema model

* fix bugs in recorder and delivery when ema_hook is used

* don't register the DummyDataset

* fix pytest

* [Feature] Add deit-base (#332)

* WIP: support deit

* WIP: add deithead

* WIP: fix checkpoint hook

* fix data preprocessor

* fix cfg

* WIP: add readme

* reset single_teacher_distill

* add metafile

* add model to model-index

* fix configs and readme

* [Feature]Feature map visualization (#293)

* WIP: vis

* WIP: add visualization

* WIP: add visualization hook

* WIP: support razor visualizer

* WIP

* WIP: wrap draw_featmap

* support feature map visualization

* add a demo image for visualization

* fix typos

* change eps to 1e-6

* add pytest for visualization

* fix vis hook

* fix arguments' name

* fix img path

* support draw inference results

* add visualization doc

* fix figure url

* move files

Co-authored-by: weihan cao <HIT-cwh>

* [Feature] Add kd examples (#305)

* support kd for mbv2 and shufflenetv2

* WIP: fix ckpt path

* WIP: fix kd r34-r18

* add metafile

* fix metafile

* delete

* [Doc] add documents about pruning. (#313)

* init

* update user guide

* update images

* update

* update How to prune your model

* update how_to_use_config_tool_of_pruning.md

* update doc

* move location

* update

* update

* update

* add mutablechannels.md

* add references

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>

* [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304)

* add pkd

* add pytest for pkd

* fix cfg

* WIP: support fcos3d

* WIP: support fcos3d pkd

* support mmdet3d

* fix cfgs

* change eps to 1e-6 and add some comments

* fix docstring

* fix cfg

* add assert

* add type hint

* WIP: add readme and metafile

* fix readme

* update metafiles and readme

* fix metafile

* fix pipeline figure

* [Refactor] Refactor Mutables and Mutators (#324)

* refactor mutables

* update load fix subnet

* add DumpChosen Typehint

* adapt UTs

* fix lint

* Add GroupMixin to ChannelMutator (temporarily)

* fix type hints

* add GroupMixin doc-string

* modified by comments

* fix type hits

* update subnet format

* fix channel group bugs and add UTs

* fix doc string

* fix comments

* refactor diff module forward

* fix error in channel mutator doc

* fix comments

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Update readme (#341)

* update kl readme

* update dsnas readme

* fix url

* Bump version to 1.0.0rc1 (#338)

update version

* [Feature] Add Autoformer algorithm (#315)

* update candidates

* update subnet_sampler_loop

* update candidate

* add readme

* rename variable

* rename variable

* clean

* update

* add doc string

* Revert "[Improvement] Support for candidate multiple dimensional search constraints."

* [Improvement] Update Candidate with multi-dim search constraints. (#322)

* update doc

* add support type

* clean code

* update candidates

* clean

* xx

* set_resource -> set_score

* fix ci bug

* py36 lint

* fix bug

* fix check constrain

* py36 ci

* redesign candidate

* fix pre-commit

* update cfg

* add build_resource_estimator

* fix ci bug

* remove runner.epoch in testcase

* [Feature] Autoformer architecture and dynamicOPs (#327)

* add DynamicSequential

* dynamiclayernorm

* add dynamic_pathchembed

* add DynamicMultiheadAttention and DynamicRelativePosition2D

* add channel-level dynamicOP

* add autoformer algo

* clean notes

* adapt channel_mutator

* vit fly

* fix import

* mutable init

* remove annotation

* add DynamicInputResizer

* add unittest for mutables

* add OneShotMutableChannelUnit_VIT

* clean code

* reset unit for vit

* remove attr

* add autoformer backbone UT

* add valuemutator UT

* clean code

* add autoformer algo UT

* update classifier UT

* fix test error

* ignore

* make lint

* update

* fix lint

* mutable_attrs

* fix test

* fix error

* remove DynamicInputResizer

* fix test ci

* remove InputResizer

* rename variables

* modify type

* Continued improvements of ChannelUnit

* fix lint

* fix lint

* remove OneShotMutableChannelUnit

* adjust derived type

* combination mixins

* clean code

* fix sample subnet

* search loop fly

* more annotations

* avoid counter warning and modify batch_augment cfg by gy

* restore

* source_value_mutables restriction

* simply arch_setting api

* update

* clean

* fix ut

* [Feature] Add performance predictor (#306)

* add predictor with 4 handlers

* [Improvement] Update Candidate with multi-dim search constraints. (#322)

* update doc

* add support type

* clean code

* update candidates

* clean

* xx

* set_resource -> set_score

* fix ci bug

* py36 lint

* fix bug

* fix check constrain

* py36 ci

* redesign candidate

* fix pre-commit

* update cfg

* add build_resource_estimator

* fix ci bug

* remove runner.epoch in testcase

* update metric_predictor:
1. update MetricPredictor;
2. add predictor config for searching;
3. add predictor in evolution_search_loop.

* add UT for predictor

* add MLPHandler

* patch optional.txt for predictors

* patch test_evolution_search_loop

* refactor apis of predictor and handlers

* fix ut and remove predictor_cfg in predictor

* adapt new mutable & mutator design

* fix ut

* remove unness assert after rebase

* move predictor-build in __init__ & simplify estimator-build

Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>

* [Feature] Add DCFF (#295)

* add ChannelGroup (#250)

* rebase new dev-1.x

* modification for adding config_template

* add docstring to channel_group.py

* add docstring to mutable_channel_group.py

* rm channel_group_cfg from Graph2ChannelGroups

* change choice type of SequentialChannelGroup from float to int

* add a warning about group-wise conv

* restore __init__ of dynamic op

* in_channel_mutable  ->  mutable_in_channel

* rm abstractproperty

* add a comment about VT

* rm registry for ChannelGroup

* MUTABLECHANNELGROUP -> ChannelGroupType

* refine docstring of IndexDict

* update docstring

* update docstring

* is_prunable -> is_mutable

* update docstring

* fix error in pre-commit

* update unittest

* add return type

* unify init_xxx apit

* add unitest about init of MutableChannelGroup

* update according to reviews

* sequential_channel_group -> sequential_mutable_channel_group

Co-authored-by: liukai <liukai@pjlab.org.cn>

* Add BaseChannelMutator and refactor Autoslim (#289)

* add BaseChannelMutator

* add autoslim

* tmp

* make SequentialMutableChannelGroup accpeted both of num and ratio  as choice. and supports divisior

* update OneShotMutableChannelGroup

* pass supernet training of autoslim

* refine autoslim

* fix bug in OneShotMutableChannelGroup

* refactor make_divisible

* fix spell error:  channl -> channel

* init_using_backward_tracer -> init_from_backward_tracer
init_from_fx_tracer -> init_from_fx_tracer

* refine SequentialMutableChannelGroup

* let mutator support models with dynamicop

* support define search space in model

* tracer_cfg -> parse_cfg

* refine

* using -> from

* update docstring

* update docstring

Co-authored-by: liukai <liukai@pjlab.org.cn>

* tmpsave

* migrate ut

* tmpsave2

* add loss collector

* refactor slimmable and add l1-norm (#291)

* refactor slimmable and add l1-norm

* make l1-norm support convnd

* update get_channel_groups

* add  l1-norm_resnet34_8xb32_in1k.py

* add pretrained to resnet34-l1

* remove old channel mutator

* BaseChannelMutator -> ChannelMutator

* update according to reviews

* add readme to l1-norm

* MBV2_slimmable -> MBV2_slimmable_config

Co-authored-by: liukai <liukai@pjlab.org.cn>

* update config

* fix md & pytorch support <1.9.0 in batchnorm init

* Clean old codes. (#296)

* remove old dynamic ops

* move dynamic ops

* clean old mutable_channels

* rm OneShotMutableChannel

* rm MutableChannel

* refine

* refine

* use SquentialMutableChannel to replace OneshotMutableChannel

* refactor dynamicops folder

* let SquentialMutableChannel support float

Co-authored-by: liukai <liukai@pjlab.org.cn>

* fix ci

* ci fix py3.6.x & add mmpose

* ci fix py3.6.9 in utils/index_dict.py

* fix mmpose

* minimum_version_cpu=3.7

* fix ci 3.7.13

* fix pruning &meta ci

* support python3.6.9

* fix py3.6 import caused by circular import patch in py3.7

* fix py3.6.9

* Add channel-flow (#301)

* base_channel_mutator -> channel_mutator

* init

* update docstring

* allow omitting redundant configs for channel

* add register_mutable_channel_to_a_module to MutableChannelContainer

* update according to reviews 1

* update according to reviews 2

* update according to reviews 3

* remove old docstring

* fix error

* using->from

* update according to reviews

* support self-define input channel number

* update docstring

* chanenl -> channel_elem

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>

* support >=3.7

* support py3.6.9

* Rename: ChannelGroup -> ChannelUnit (#302)

* refine repr of MutableChannelGroup

* rename folder name

* ChannelGroup -> ChannelUnit

* filename in units folder

* channel_group -> channel_unit

* groups -> units

* group -> unit

* update

* get_mutable_channel_groups -> get_mutable_channel_units

* fix bug

* refine docstring

* fix ci

* fix bug in tracer

Co-authored-by: liukai <liukai@pjlab.org.cn>

* update new channel config format

* update pruning refactor

* update merged pruning

* update commit

* fix dynamic_conv_mixin

* update comments: readme&dynamic_conv_mixins.py

* update readme

* move kl softmax channel pooling to op by comments

* fix comments: fix redundant & split README.md

* dcff in ItePruneAlgorithm

* partial dynamic params for fuseconv

* add step_freq & prune_time check

* update comments

* update comments

* update comments

* fix ut

* fix gpu ut & revise step_freq in ItePruneAlgorithm

* update readme

* revise ItePruneAlgorithm

* fix docs

* fix dynamic_conv attr

* fix ci

Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: jacky <jacky@xx.com>

* [Fix] Fix optional requirements (#357)

* fix optional requirements

* fix dcff ut

* fix import with get_placeholder

* supplement the previous commit

* [Fix] Fix configs of wrn models and ofd. (#361)

* 1.revise the configs of wrn22, wrn24, and wrn40. 2.revise the data_preprocessor of ofd_backbone_resnet50_resnet18_8xb16_cifar10

* 1.Add README for vanilla-wrm.

* 1.Revise readme of wrn

Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>

* [Fix] Fix bug on mmrazor visualization, mismatch argument in define and use. (#356)

fix bug on mmrazor visualization, mismatch argument in define and use.

Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>

* fix bug in benchmark_test (#364)

fix bug in configs

Co-authored-by: Your Name <you@example.com>

* [FIX] Fix wrn configs (#368)

* fix wrn configs

* fix wrn configs

* update online wrn model weight

* [Fix] fix bug on pkd config. Wrong import filename. (#373)

* [CI] Update ci to torch1.13 (#380)

update ci to torch1.13

* [Feature] Add BigNAS algorithm (#219)

* add calibrate-bn-statistics

* add test calibrate-bn-statistics

* fix mixins

* fix mixins

* fix mixin tests

* remove slimmable channel mutable and refactor dynamic op

* refact dynamic batch norm

* add progressive dynamic conv2d

* add center crop dynamic conv2d

* refactor dynamic directory

* refactor dynamic sequential

* rename length to depth in dynamic sequential

* add test for derived mutable

* refactor dynamic op

* refactor api of dynamic op

* add derive mutable mixin

* addbignas algorithm

* refactor bignas structure

* add input resizer

* add input resizer to bignas

* move input resizer from algorithm into classifier

* remove compnents

* add attentive mobilenet

* delete json file

* nearly(less 0.2) align inference accuracy with gml

* move mutate seperated in bignas mobilenet backbone

* add zero_init_residual

* add set_dropout

* set dropout in bignas algorithm

* fix registry

* add subnet yaml and nearly align inference accuracy with gml

* add rsb config for bignas

* remove base in config

* add gml bignas config

* convert to iter based

* bignas forward and backward fly

* fix merge conflict

* fix dynamicseq bug

* fix bug and refactor bignas

* arrange configs of bignas

* fix typo

* refactor attentive_mobilenet

* fix channel mismatch due to registion of DerivedMutable

* update bignas & fix se channel mismatch

* add AutoAugmentV2 & remove unness configs

* fix lint

* recover channel assertion in channel unit

* fix a group bug

* fix comments

* add docstring

* add norm in dynamic_embed

* fix search loop & other minor changes

* fix se expansion

* minor change

* add ut for bignas & attentive_mobilenet

* fix ut

* update bignas readme

* rm unness ut & supplement get_placeholder

* fix lint

* fix ut

* add subnet deployment in downstream tasks.

* minor change

* update ofa backbone

* minor fix

* Continued improvements of searchable backbone

* minor change

* drop ratio in backbone

* fix comments

* fix ci test

* fix test

* add dynamic shortcut UT

* modify strategy to fit bignas

* fix test

* fix bug in neck

* fix error

* fix error

* fix yaml

* save subnet ckpt

* merge autoslim_val/test_loop into subnet_val_loop

* move calibrate_bn_mixin to utils

* fix bugs and add docstring

* clean code

* fix register bug

* clean code

* update

Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>

* [Bug] Fix ckpt (#372)

fix ckpt

* [Feature] Add tools to convert distill ckpt to student-only ckpt. (#381)

* [Feature] Add tools to convert distill ckpt to student-only ckpt.

* fix bug.

* add --model-only to only save model.

* Make changes accroding to PR review.

* Enhance the Abilities of the Tracer for Pruning. (#371)

* tmp

* add new mmdet models

* add docstring

* pass test and pre-commit

* rm razor tracer

* update fx tracer, now it can automatically wrap methods and functions.

* update tracer passed models

* add warning for torch <1.12.0

fix bug for python3.6

update placeholder to support placeholder.XXX

* fix bug

* update docs

* fix lint

* fix parse_cfg in configs

* restore mutablechannel

* test ite prune algorithm when using dist

* add get_model_from_path to MMModelLibrrary

* add mm models to DefaultModelLibrary

* add uts

* fix bug

* fix bug

* add uts

* add uts

* add uts

* add uts

* fix bug

* restore ite_prune_algorithm

* update doc

* PruneTracer -> ChannelAnalyzer

* prune_tracer -> channel_analyzer

* add test for fxtracer

* fix bug

* fix bug

* PruneTracer -> ChannelAnalyzer

refine

* CustomFxTracer -> MMFxTracer

* fix bug when test with torch<1.12

* update print log

* fix lint

* rm unuseful code

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: liukai <your_email@abc.example>

* fix bug in placer holder (#395)

* fix bug in placer holder

* remove redundent comment

Co-authored-by: liukai <your_email@abc.example>

* Add get_prune_config and a demo config_pruning (#389)

* update tools and test

* add demo

* disable test doc

* add switch for test tools and test_doc

* fix bug

* update doc

* update tools name

* mv get_channel_units

Co-authored-by: liukai <your_email@abc.example>

* [Improvement] Adapt OFA series with SearchableMobileNetV3 (#385)

* fix mutable bug in AttentiveMobileNetV3

* remove unness code

* update ATTENTIVE_SUBNET_A0-A6.yaml with optimized names

* unify the sampling usage in sandwich_rule-based NAS

* use alias to export subnet

* update OFA configs

* fix attr bug

* fix comments

* update convert_supernet2subnet.py

* correct the way to dump DerivedMutable

* fix convert index bug

* update OFA configs & models

* fix dynamic2static

* generalize convert_ofa_ckpt.py

* update input_resizer

* update README.md

* fix ut

* update export_fix_subnet

* update _dynamic_to_static

* update fix_subnet UT & minor fix bugs

* fix ut

* add new autoaug compared to attentivenas

* clean

* fix act

* fix act_cfg

* update fix_subnet

* fix lint

* add docstring

Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>

* [Fix]Dcff Deploy Revision (#383)

* dcff deploy revision

* tempsave

* update fix_subnet

* update mutator load

* export/load_fix_subnet revision for mutator

* update fix_subnet with dev-1.x

* update comments

* update docs

* update registry

* [Fix] Fix commands in README to adapt branch 1.x (#400)

* update commands in README for 1.x

* fix commands

Co-authored-by: gaoyang07 <1546308416@qq.com>

* Set requires_grad to False if the teacher is not trainable (#398)

* add choice and mask of units to checkpoint (#397)

* add choice and mask of units to checkpoint

* update

* fix bug

* remove device operation

* fix bug

* fix circle ci error

* fix error in numpy for circle ci

* fix bug in requirements

* restore

* add a note

* a new solution

* save mutable_channel.mask as float for dist training

* refine

* mv meta file test

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: jacky <jacky@xx.com>

* [Bug]Fix fpn teacher distill (#388)

fix fpn distill

* [CodeCamp #122] Support KD algorithm MGD for detection. (#377)

* [Feature] Support KD algorithm MGD for detection.

* use connector to beauty mgd.

* fix typo, add unitest.

* fix mgd loss unitest.

* fix mgd connector unitest.

* add model pth and log file.

* add mAP.

* update l1 config (#405)

* add l1 config

* update l1 config

Co-authored-by: jacky <jacky@xx.com>

* [Feature] Add greedy search for AutoSlim (#336)

* WIP: add greedysearch

* fix greedy search and add bn_training_mode to autoslim

* fix cfg files

* fix autoslim configs

* fix bugs when converting dynamic bn to static bn

* change to test loop

* refactor greedy search

* rebase and fix greedysearch

* fix lint

* fix and delete useless codes

* fix pytest

* fix pytest and add bn_training_mode

* fix lint

* add reference to AutoSlimGreedySearchLoop's docstring

* sort candidate_choices

* fix save subnet

* delete useless codes in channel container

* change files' name: convert greedy_search_loop to autoslim_greedy_search_loop

* [Fix] Fix metafile (#422)

* fix ckpt path in metafile and readme

* fix darts file path

* fix docstring in ConfigurableDistiller

* fix darts

* fix error

* add darts of mmrazor version

* delete py36

Co-authored-by: liukai <your_email@abc.example>

* update bignas cfg (#412)

* check attentivenas training

* update ckpt link

* update supernet log

Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>

* Bump version to 1.0.0rc2 (#423)

bump version to 1.0.0rc2

Co-authored-by: liukai <your_email@abc.example>

* fix lint

* fix ci

* add tmp docstring for passed ci

* add tmp docstring for passed ci

* fix ci

* add get_placeholder for quant

* add skip for unittest

* fix package placeholder bug

* add version judgement in __init__

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>

* [Docs] Add docstring and unittest about backendconfig & observer & fakequant (#428)

* add ut about backendconfig

* add ut about observers and fakequants in torch

* fix torch1.13 ci

* [Docs] Add docstring for `MMArchitectureQuant` & `NativeQuantizer` (#425)

* add docstring on mm_architecture& native_quantizer

* add naive openvino r18 qat config & dist_ptq.sh

* Added a more accurate description

* unitest&doc

* checkpoint url

* unitest

* passed_pre_commit

* unitest on native_quantizer& fix bugs

* remove dist_ptq

* add get_placeholder&skipTest

* complete arg descriptions

* fix import bugs

* fix pre-commit

* add get_placeholder

* add typehint and doctring

* update docstring&typehint

* update docstring

* pre-commit

* fix some problems

* fix bug

* [Docs] Add docstring and unitest about custom tracer (#427)

* rename QConfigHandler and QSchemeHandler

* add docstring about custom tracer

* add ut about custom tracer

* fix torch1.13 ci

* fix lint

* fix ci

* fix ci

* [Docs & Refactor] Add docstring and UT of other quantizers (#439)

* add quantizer docstring and refactor the interface of AcademicQuantizer

* add AcademicQuantizer unittest

* add TensorRTQuantizer and OpenVINOQuantizer unittest & refactor prepare interface

* adapt torch113 ci

* fix import

* fix lint

* update some docstring

* fix ci

* [Feature&Doc]Modify ptq pipeline and support lsq (#435)

* modify ptq pipeline and support lsq

* use placeholder

* fix lsq && quantloop

* add lsq pytest

* add quant loop pytest

* test lsq observer

* fix bug under pt13

* fix reset_min_max_vals

* fix bugs under pt13

* fix configs

* add get_qconfig_mapping

* delete is_qat, add doc and fix pytest

* delete useless codes in custom_tracer

* skip pytest under pt13

* add todo: check freezebn

* fix pytest bugs

* fix pytest

* fix pytest

* fix pytest

* [Docs] Add customize_quantization_tutorial (#440)

* [Docs] Add quantization user guide (#441)

* add quantization user guide

* fix layout

* fix layout

* update README

* [Bug] Fix del redundant fakequant (#447)

fix del redundant fakequant

* [Feature] Add onnx exporters (#475)

* fix del redundant fakequant

* add onnx exporters

* fix onnx exporters and add docstring

* fix comments

* delete useless codes

* fix export_onnx in native quantizer

---------

Co-authored-by: pppppM <gjf_mail@126.com>

* [Feature]Rewrite the origin model during prepare (#488)

* add rewriter

* add deploy_cfg arg

* modify post_process_for_mmdeploy

* fix bugs

* add det config

* [Feature] Using rewriter in mmrazor when building qmodels. (#490)

* add rewriter

* add deploy_cfg arg

* modify post_process_for_mmdeploy

* fix bugs

* add det config

* replace deepcopy

* pop detectors' forward

* [Feature] Quantization global optimization (#491)

* add trtquantizer

* unify all fakequant before deploy

* move to aide

* add yolox config

* pre-rebase

* add unittest

* add a arg of post_process_for_deploy

* test trt yolox deploy

* opt quantizer interface

* fix rebase

* add trt r50 config

* update trt setting

* del redundant code

* fix lint

* fix ut of quantizers

* del redundant file

* fix lint

* fix some comments

* Fix code syntax in UT (#470)

Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>

* passed lint and pytest

* try to fix ci

* [Bug] Try to fix CI (#502)

fix lint

* [Feature] Support lsq (#501)

* support deploy_cfg=None

* replace fakequant before load ckpt

* add _load_from_state_dict to lsq fakequant

* fix pre-commit

* test lsq load state dict

* change github ci: ubuntu 18.04 to ubuntu 20.04

* get_deploy_model order change back

* sync before save ckpt

* delete strict=False

* test context rewriter

* fix pre commit config

* try to fix ci

* [Bug] Try to fix CI (#502)

fix lint

---------

Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>

* [Feature] Add exporter pytest (#504)

* add exporter pytest

* fix bugs

* delete useless codes

* handle onnx

* delete useless codes

* [Bug] Fix ci converage setting (#508)

fix ci converage

* [Bug] Fix codecov (#509)

* remove codecov in requirements

* try to fix ci

* del adaround loss

* [BUG] Fix quantization loop (#507)

* fix quantization loop

* fix quant loop

* fix quant loop

* fix qat configs

* [Bug] Fix ci converage setting (#508)

fix ci converage

* [Bug] Fix codecov (#509)

* remove codecov in requirements

* try to fix ci

* del adaround loss

* add freeze_bn_begin to lsq

* delete useless codes

---------

Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>

* add test ptq

* opt ptq pipeline

* refactor quant configs

* update config path

* add summary analyse tool

* fix benchmark_test:detnas_frcnn_shufflenet_subnet_coco_1x.py

* update quantization README.md

* update quantization metafile, readme, config path

* update quantization docs

* update git main link in workflow

* update benchmark_summary_analyse.py

* del dmcp results

* [Bug] fix a rebase error (#514)

fix a rebase error

* [Bug] Fix CI (#515)

* fix ci

* mmcv2.0 need torch1.8+

* Update CI config and Passed (#516)

* test ci

* update test.yml based on mmcv2.0.0

* [Docs] Fix cwd test accuary (#517)

* test ci

* update test.yml based on mmcv2.0.0

* update cwd_logits_pspnet result

---------

Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com>
Co-authored-by: HIT-cwh <2892770585@qq.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>
Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com>
Co-authored-by: wm901115nwpu <wmnwpu@gmail.com>
Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>

* [Docs&Feature] Prepare for checkouting default branch and releasing new version (#518)

* prepare for checkout default branch

* update README.md and model zoo

* update installation.md and update dev-1.x links

* update README_zh-CN

* add changelog

* update ci config

* update some links in quantization readme

* update quantization user guide

* update calibrate_dataloader

* add interface pop_rewriter_function_record

* Bump version to 1.0.0 (#521)

* update release time

* bump version to 1.0.0

* [CI] Fix merge stage test (#523)

fix merge_stage_test in ci

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com>
Co-authored-by: HIT-cwh <2892770585@qq.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>
Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com>
Co-authored-by: wm901115nwpu <wmnwpu@gmail.com>
Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>

* move folders and update readme (#528)

* move folders

* update readme

---------

Co-authored-by: liukai <your_email@abc.example>

* [Bug] Fix torch2 error (#536)

fix torch2 error

* [Feature] Add GPTQ and uniform interfaces (#538)

* add gptq implementation

* pre-checkout

* passed resnet example

* passed llama example

* aglin gptq acc

* add activation quantization

* uniform interfaces

* add gptq readme

* update mmrazor_large redame

* add gptq opt example

* fix sparse_gpt example for opt

* fix import Protocol from py37

* fix error function name

* fix bug in test

* fix bug

* fix bug

* limit sparsegpt test with torch>=1.12

* add docstring for gptq and sparse_gpt

* pre-commit

* align acc & add save load ckpt & add ut

* fix ut

* fix ut

* fix ut

* fix ut & add torch2.0 for ci

* del torch2.0 for ci

* fix ut

---------

Co-authored-by: FIRST_NAME LAST_NAME <MY_NAME@example.com>

---------

Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com>
Co-authored-by: HIT-cwh <2892770585@qq.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>
Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com>
Co-authored-by: wm901115nwpu <wmnwpu@gmail.com>
Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>
Co-authored-by: FIRST_NAME LAST_NAME <MY_NAME@example.com>
  • Loading branch information
35 people committed May 25, 2023
1 parent d3cd028 commit 454f397
Show file tree
Hide file tree
Showing 39 changed files with 4,404 additions and 6 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ repos:
| ^docs
| ^configs
| ^.*/configs*
| ^projects
)
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ English | [简体中文](README_zh-CN.md)

</div>

**:star: MMRazor for Large Models** is Available Now! Please refer to [MMRazorLarge](projects/mmrazor_large/README.md)

## Introduction

MMRazor is a model compression toolkit for model slimming and AutoML, which includes 4 mainstream technologies:
Expand Down
4 changes: 2 additions & 2 deletions mmrazor/implementations/pruning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import group_fisher
from . import group_fisher, sparse_gpt

__all__ = ['group_fisher']
__all__ = ['group_fisher', 'sparse_gpt']
9 changes: 9 additions & 0 deletions mmrazor/implementations/pruning/sparse_gpt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .compressor import SparseGptCompressor
from .ops import SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops

__all__ = [
'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
'SparseGptCompressor'
]
106 changes: 106 additions & 0 deletions mmrazor/implementations/pruning/sparse_gpt/compressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmrazor.utils import print_log
from .ops import SparseGptConv2d, SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops


def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
load_fix_subnet(model, fix_subnet)
return model


class SparseGptCompressor():
"""The compressor with SparseGPT."""

def __init__(self) -> None:
self.model: nn.Module = None

def prepare(self,
model: nn.Module,
prune_conv=True,
prune_linear=True) -> None:
"""Prepare for compressing model."""
self.model = model
prune_modules: dict = {}
if prune_conv:
prune_modules[nn.Conv2d] = SparseGptConv2d
if prune_linear:
prune_modules[nn.Linear] = SparseGptLinear
replace_with_dynamic_ops(model, prune_modules)

@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)

# hessian

def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.register_hessian_hook()

def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.remove_hessian_hook()

def init_hessian(self, device=None):
"""Init hessian."""
for op in self.sparse_ops:
op.init_hessian(device=device)

# prune
def prune(self,
sparsity,
prunen=0,
prunem=0,
blocksize=128,
percdamp=.01,
device=torch.device('cuda')):
"""Apply the compression algorithm to the model."""
for name, module in self.named_sparse_ops:
try:
original_device = next(module.parameters()).device
module: SparseGptMixIn = module.to(device)
error = module.prune(
sparsity=sparsity,
prunen=prunen,
prunem=prunem,
blocksize=blocksize,
percdamp=percdamp,
)
print_log(f'prune {name} success \t error = {error}')
module.to(original_device)
torch.cuda.empty_cache()
except Exception as e:
print_log(f'prune {name} failed as {e}')

def prune_24(self, device=torch.device('cuda:0')):
"""Apply the compression algorithm to the model with the specified
setting."""
self.prune(0.5, prunen=2, prunem=4, device=device)

# ops

@property
def sparse_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, SparseGptMixIn):
yield module

@property
def named_sparse_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, SparseGptMixIn):
yield name, module
Loading

0 comments on commit 454f397

Please sign in to comment.