diff --git a/configs/pruning/mmcls/dmcp/metafile.yml b/configs/pruning/mmcls/dmcp/metafile.yml new file mode 100644 index 000000000..4c1268093 --- /dev/null +++ b/configs/pruning/mmcls/dmcp/metafile.yml @@ -0,0 +1,19 @@ +Models: + - Name: dmcp_resnet50_subnet_32xb64 + In Collection: DMCP + Config: configs/pruning/mmcls/dmcp/dmcp_resnet50_subnet_32xb64.py + Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/2G/DMCP_R50_2G.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 76.11 + - Name: dmcp_mbv2_subnet_32xb64 + In Collection: DMCP + Config: configs/pruning/mmcls/dmcp/dmcp_mbv2_subnet_32xb64.py + Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/mobilenetv2/100M/DMCP_MBV2_100M.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 67.22 diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml b/configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml new file mode 100644 index 000000000..24f41eaae --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml @@ -0,0 +1,19 @@ +Models: + - Name: group_fisher_act_finetune_mobilenet-v2_8xb32_in1k + In Collection: GroupFisher + Config: configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.82 + - Name: group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k + In Collection: GroupFisher + Config: configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.87 diff --git a/configs/pruning/mmcls/group_fisher/resnet50/metafile.yml b/configs/pruning/mmcls/group_fisher/resnet50/metafile.yml new file mode 100644 index 000000000..fd670a3c2 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/metafile.yml @@ -0,0 +1,19 @@ +Models: + - Name: group_fisher_act_finetune_resnet50_8xb32_in1k + In Collection: GroupFisher + Config: configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 75.22 + - Name: group_fisher_flops_finetune_resnet50_8xb32_in1k + In Collection: GroupFisher + Config: configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 75.61 diff --git a/configs/pruning/mmcls/l1-norm/README.md b/configs/pruning/mmcls/l1-norm/README.md index 8dc2c1a20..2b6509298 100644 --- a/configs/pruning/mmcls/l1-norm/README.md +++ b/configs/pruning/mmcls/l1-norm/README.md @@ -18,3 +18,44 @@ We use ItePruneAlgorithm and L1MutableChannelUnit to implement l1-norm pruning. 
| ResNet34_Pruned_C | 73.89 | +0.27 | 3.40 | 7.6% | 2.02 | 7.3% | [config](./l1-norm_resnet34_8xb32_in1k_a.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_c.pth) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_c.json) |

**Note:** There is a different implementation from the original paper. We pruned the layers related to the shortcut with a shared pruning decision, while the original paper pruned them separately in *Pruned C*. This may be why our *Pruned C* outperforms *Prune A* and *Prune B*, while *Pruned C* is worst in the original paper.
+
+## Getting Started
+
+### Prune
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
+  {prune_config_path}.py 8 --work-dir $WORK_DIR
+```
+
+After the pruning process, you will get a checkpoint file in the work_dir. This checkpoint file contains all the parameters of the original model. In the next step, we will use this checkpoint to export a pruned checkpoint.
+
+### Get the pruned model
+
+```bash
+python ./tools/pruning/get_static_model_from_algorithm.py \
+    {prune_config_path}.py \
+    {checkpoint_file}.pth \
+    -o {output_folder}
+```
+
+This step will export a pruned checkpoint and a json file which records the pruning structure. These two files will be used to deploy the pruned model.
+
+### Deploy
+
+For a pruned model, you only need to use the pruning deploy config instead of the pretrain config to deploy the pruned version of your model. If you are not familiar with MMDeploy, please refer to [mmdeploy](https://github.com/open-mmlab/mmdeploy/tree/1.x).
+
+```bash
+python {mmdeploy}/tools/deploy.py \
+    {mmdeploy}/{mmdeploy_config}.py \
+    {pruning_deploy_config}.py \
+    {pruned_checkpoint}.pth \
+    {mmdeploy}/tests/data/tiger.jpeg
+```
+
+### Get the FLOPs and Parameters of a Pruned Model
+
+```bash
+python ./tools/pruning/get_flops.py \
+    {pruning_deploy_config}.py
+```
diff --git a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a_deploy.py b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a_deploy.py
new file mode 100644
index 000000000..c754d11fc
--- /dev/null
+++ b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a_deploy.py
@@ -0,0 +1,57 @@
+#############################################################################
+"""You have to fill these args.
+
+_base_(str): The path to your pretrain config file.
+fix_subnet (Union[dict,str]): The dict that stores the pruning structure or the
+    json file containing it.
+divisor (int): The divisor to make the channel number divisible.
+""" + +_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py'] +un_prune = 1.0 +stage_ratio_1 = 0.7 +stage_ratio_2 = 0.7 +stage_ratio_3 = 0.7 +stage_ratio_4 = un_prune + +# the config template of target_pruning_ratio can be got by +# python ./tools/get_channel_units.py {config_file} --choice +fix_subnet = { + # stage 1 + 'backbone.conv1_(0, 64)_64': un_prune, # short cut layers + 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.2.conv1_(0, 64)_64': un_prune, + # stage 2 + 'backbone.layer2.0.conv1_(0, 128)_128': un_prune, + 'backbone.layer2.0.conv2_(0, 128)_128': un_prune, # short cut layers + 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv1_(0, 128)_128': un_prune, + # stage 3 + 'backbone.layer3.0.conv1_(0, 256)_256': un_prune, + 'backbone.layer3.0.conv2_(0, 256)_256': un_prune, # short cut layers + 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.5.conv1_(0, 256)_256': un_prune, + # stage 4 + 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv2_(0, 512)_512': un_prune, # short cut layers + 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4 +} +divisor = 8 +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_b_deploy.py b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_b_deploy.py new file mode 100644 index 000000000..636ff0766 --- /dev/null +++ b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_b_deploy.py @@ -0,0 +1,57 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. 
+""" + +_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py'] + +un_prune = 1.0 +stage_ratio_1 = 0.5 +stage_ratio_2 = 0.4 +stage_ratio_3 = 0.6 +stage_ratio_4 = un_prune + +fix_subnet = { + # stage 1 + 'backbone.conv1_(0, 64)_64': un_prune, # short cut layers + 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, + 'backbone.layer1.2.conv1_(0, 64)_64': un_prune, + # stage 2 + 'backbone.layer2.0.conv1_(0, 128)_128': un_prune, + 'backbone.layer2.0.conv2_(0, 128)_128': un_prune, # short cut layers + 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.3.conv1_(0, 128)_128': un_prune, + # stage 3 + 'backbone.layer3.0.conv1_(0, 256)_256': un_prune, + 'backbone.layer3.0.conv2_(0, 256)_256': un_prune, # short cut layers + 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.5.conv1_(0, 256)_256': un_prune, + # stage 4 + 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv2_(0, 512)_512': un_prune, # short cut layers + 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4 +} + +divisor = 8 +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_c_deploy.py b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_c_deploy.py new file mode 100644 index 000000000..2c7a42e12 --- /dev/null +++ b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_c_deploy.py @@ -0,0 +1,54 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. 
+""" + +_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py'] +un_prune = 1.0 + +# the config template of target_pruning_ratio can be got by +# python ./tools/get_channel_units.py {config_file} --choice +fix_subnet = { + # stage 1 + 'backbone.conv1_(0, 64)_64': un_prune, # short cut layers + 'backbone.layer1.0.conv1_(0, 64)_64': un_prune, + 'backbone.layer1.1.conv1_(0, 64)_64': un_prune, + 'backbone.layer1.2.conv1_(0, 64)_64': un_prune, + # stage 2 + 'backbone.layer2.0.conv1_(0, 128)_128': un_prune, + 'backbone.layer2.0.conv2_(0, 128)_128': un_prune, # short cut layers + 'backbone.layer2.1.conv1_(0, 128)_128': un_prune, + 'backbone.layer2.2.conv1_(0, 128)_128': un_prune, + 'backbone.layer2.3.conv1_(0, 128)_128': un_prune, + # stage 3 + 'backbone.layer3.0.conv1_(0, 256)_256': un_prune, + 'backbone.layer3.0.conv2_(0, 256)_256': 0.8, # short cut layers + 'backbone.layer3.1.conv1_(0, 256)_256': un_prune, + 'backbone.layer3.2.conv1_(0, 256)_256': un_prune, + 'backbone.layer3.3.conv1_(0, 256)_256': un_prune, + 'backbone.layer3.4.conv1_(0, 256)_256': un_prune, + 'backbone.layer3.5.conv1_(0, 256)_256': un_prune, + # stage 4 + 'backbone.layer4.0.conv1_(0, 512)_512': un_prune, + 'backbone.layer4.0.conv2_(0, 512)_512': un_prune, # short cut layers + 'backbone.layer4.1.conv1_(0, 512)_512': un_prune, + 'backbone.layer4.2.conv1_(0, 512)_512': un_prune +} + +divisor = 8 +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmcls/l1-norm/metafile.yml b/configs/pruning/mmcls/l1-norm/metafile.yml new file mode 100644 index 000000000..3009fdc25 --- /dev/null +++ b/configs/pruning/mmcls/l1-norm/metafile.yml @@ -0,0 +1,28 @@ +Models: + - Name: l1-norm_resnet34_8xb32_in1k_a + In Collection: L1-norm + Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a.py + Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_a.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 73.61 + - Name: l1-norm_resnet34_8xb32_in1k_b + In Collection: L1-norm + Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_b.py + Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_b.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 73.20 + - Name: l1-norm_resnet34_8xb32_in1k_c + In Collection: L1-norm + Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_c.py + Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_c.pth + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 73.89 diff --git a/configs/pruning/mmcls/l1-norm/script.sh b/configs/pruning/mmcls/l1-norm/script.sh new file mode 100644 index 000000000..2bc1e9274 --- /dev/null +++ b/configs/pruning/mmcls/l1-norm/script.sh @@ -0,0 +1,25 @@ + +# export pruned checkpoint example + +python ./tools/pruning/get_static_model_from_algorithm.py configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_a.pth -o ./work_dirs/norm_resnet34_8xb32_in1k_a + +# deploy example + 
+razor_config=configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a_deploy.py
+deploy_config=mmdeploy/configs/mmcls/classification_onnxruntime_dynamic.py
+static_model_checkpoint_path=path/to/pruned/checkpoint
+
+python mmdeploy/tools/deploy.py $deploy_config \
+    $razor_config \
+    $static_model_checkpoint_path \
+    mmdeploy/tests/data/tiger.jpeg \
+    --work-dir ./work_dirs/mmdeploy
+
+python mmdeploy/tools/profiler.py $deploy_config \
+    $razor_config \
+    mmdeploy/demo/resources \
+    --model ./work_dirs/mmdeploy/end2end.onnx \
+    --shape 224x224 \
+    --device cpu \
+    --num-iter 1000 \
+    --warmup 100
diff --git a/configs/pruning/mmdet/group_fisher/retinanet/metafile.yml b/configs/pruning/mmdet/group_fisher/retinanet/metafile.yml
new file mode 100644
index 000000000..232f9cb97
--- /dev/null
+++ b/configs/pruning/mmdet/group_fisher/retinanet/metafile.yml
@@ -0,0 +1,19 @@
+Models:
+  - Name: group_fisher_act_finetune_retinanet_r50_fpn_1x_coco
+    In Collection: GroupFisher
+    Config: configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py
+    Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth
+    Results:
+      - Task: Object Detection
+        Dataset: COCO
+        Metrics:
+          box AP: 36.5
+  - Name: group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco
+    In Collection: GroupFisher
+    Config: configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py
+    Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth
+    Results:
+      - Task: Object Detection
+        Dataset: COCO
+        Metrics:
+          box AP: 36.6
diff --git a/model-index.yml b/model-index.yml
index 90969c7e3..15e4595cb 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -21,3 +21,8 @@ Import:
  - configs/distill/mmdet/pkd/metafile.yml
  - configs/distill/mmdet3d/pkd/metafile.yml
  - configs/distill/mmcls/deit/metafile.yml
+  - configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml
+  - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml
+  - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml
+  - configs/pruning/mmcls/l1-norm/metafile.yml
+  - configs/pruning/mmcls/dmcp/metafile.yml
diff --git a/tools/pruning/get_static_model_from_algorithm.py b/tools/pruning/get_static_model_from_algorithm.py
new file mode 100644
index 000000000..8d28842e7
--- /dev/null
+++ b/tools/pruning/get_static_model_from_algorithm.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
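+"""Export a static pruned checkpoint from a pruning algorithm checkpoint.
+
+Example usage (paths are placeholders; see configs/pruning/mmcls/l1-norm/README.md):
+
+    python ./tools/pruning/get_static_model_from_algorithm.py \
+        {prune_config_path}.py {checkpoint_file}.pth -o {output_folder}
+"""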
+import argparse +import json +import os + +import torch +from mmengine import Config, fileio +from mmengine.runner.checkpoint import load_checkpoint + +from mmrazor.models import BaseAlgorithm +from mmrazor.registry import MODELS +from mmrazor.utils import print_log + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Export a pruned model checkpoint.') + parser.add_argument('config', help='config of the model') + parser.add_argument( + 'checkpoint', + default=None, + type=str, + help='checkpoint path of the model') + parser.add_argument( + '-o', + type=str, + default='', + help='output path to store the pruned checkpoint.') + args = parser.parse_args() + return args + + +def get_save_path(config_path, checkpoint_path, target_path): + if target_path != '': + work_dir = target_path + else: + work_dir = 'work_dirs/' + os.path.basename(config_path).split('.')[0] + + checkpoint_name = os.path.basename(checkpoint_path).split( + '.')[0] + '_pruned' + + return work_dir, checkpoint_name + + +def get_static_model(algorithm): + from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet + pruning_structure = algorithm.mutator.choice_template + + # to static model + fix_mutable = export_fix_subnet(algorithm.architecture)[0] + load_fix_subnet(algorithm.architecture, fix_mutable) + model = algorithm.architecture + return model, pruning_structure + + +if __name__ == '__main__': + # init + args = parse_args() + config_path = args.config + checkpoint_path = args.checkpoint + target_path = args.o + + work_dir, checkpoint_name = get_save_path(config_path, checkpoint_path, + target_path) + os.makedirs(work_dir, exist_ok=True) + + # build model + config = Config.fromfile(config_path) + model = MODELS.build(config.model) + assert isinstance(model, BaseAlgorithm), 'Model must be a BaseAlgorithm' + load_checkpoint(model, checkpoint_path, map_location='cpu') + + pruned_model, structure = get_static_model(model) + + # save + torch.save(pruned_model.state_dict(), + os.path.join(work_dir, checkpoint_name + '.pth')) + fileio.dump( + structure, os.path.join(work_dir, checkpoint_name + '.json'), indent=4) + + print_log('Save pruned model to {}'.format(work_dir)) + print_log('Pruning Structure: {}'.format(json.dumps(structure, indent=4)))
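As a quick end-to-end check of the pieces added above, the following is a minimal, untested sketch (not part of the diff) that builds the static sub-model from one of the new `*_deploy.py` configs and loads a checkpoint exported by `tools/pruning/get_static_model_from_algorithm.py`. The checkpoint path is a placeholder, the final forward pass assumes an ImageNet-style classifier input, and only APIs already used in the diff (`Config.fromfile`, `MODELS.build`, `load_checkpoint`) are relied on.

```python
import torch
from mmengine import Config
from mmengine.runner.checkpoint import load_checkpoint

from mmrazor.registry import MODELS

# Build the static sub-model described by the deploy config. In these configs,
# `fix_subnet` may be either the inline dict shown above or the path of the
# exported pruning-structure json.
cfg = Config.fromfile(
    'configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a_deploy.py')
model = MODELS.build(cfg.model)

# Load the pruned checkpoint exported by get_static_model_from_algorithm.py
# (placeholder path).
load_checkpoint(model, 'path/to/pruned/checkpoint.pth', map_location='cpu')

# Rough sanity check: a single forward pass with an ImageNet-sized input.
model.eval()
with torch.no_grad():
    model(torch.rand(1, 3, 224, 224))
```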