diff --git a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py index 89ef4138f..c7c168cc3 100644 --- a/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py +++ b/configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py @@ -8,24 +8,24 @@ # the config template of target_pruning_ratio can be got by # python ./tools/get_channel_units.py {config_file} --choice target_pruning_ratio = { - 'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_1, + 'backbone.conv1_(0, 64)_64': stage_ratio_1, 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1, 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1, 'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1, 'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_2, - 'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2, + 'backbone.layer2.0.conv2_(0, 128)_128': stage_ratio_2, 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2, 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2, 'backbone.layer2.3.conv1_(0, 128)_128': stage_ratio_2, 'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_3, - 'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_3, + 'backbone.layer3.0.conv2_(0, 256)_256': stage_ratio_3, 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3, 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3, 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3, 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3, 'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_3, 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4, - 'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4, + 'backbone.layer4.0.conv2_(0, 512)_512': stage_ratio_4, 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4, 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4 } diff --git a/demo/config_pruning.ipynb b/demo/config_pruning.ipynb new file mode 100644 index 000000000..b5cff0088 --- /dev/null +++ b/demo/config_pruning.ipynb @@ -0,0 +1,816 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pruning ResNet34 with MMRazor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tutorial mainly shows how to write a pruning config by hand. We also provide a way to obtain a pruning config automatically; see [Search and prune / Prepare the pruning config](./search_and_prune.ipynb#prune_config)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Recap: MMCls" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare config path\n", + "MMCLS_PATH='/home/liukai/Documents/mmlab2/others/mmclassification/'\n", + "config_file=MMCLS_PATH+'configs/resnet/resnet34_8xb32_in1k.py'" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "# Run config\n", + "# !python ./tools/train.py $config_file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare the pruning config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Inherit the ResNet34 config across repositories\n", + "2. Add the pretrained parameter\n", + "3. Wrap the ResNet34 model in the pruning algorithm wrapper\n", + "4. Configure the pruning ratios\n", + "5. Run" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "from mmengine import Config\n", + "prune_config_path='./prune_resnet34.py'\n", + "def write_config(config_str,filename):\n", + " with open(filename,'w') as f:\n", + " f.write(config_str)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 
Inherit the ResNet34 config across repositories\n", + "\n", + "First, we inherit the ResNet34 config file across repositories. Cross-repo inheritance lets us reuse everything in the original config file." + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'type': 'ImageClassifier', 'backbone': {'type': 'ResNet', 'depth': 34, 'num_stages': 4, 'out_indices': (3,), 'style': 'pytorch'}, 'neck': {'type': 'GlobalAveragePooling'}, 'head': {'type': 'LinearClsHead', 'num_classes': 1000, 'in_channels': 512, 'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0}, 'topk': (1, 5)}, '_scope_': 'mmcls'}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-base_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 15359124480, 'Parameters': 88591464}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-large_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 34368026112, 'Parameters': 197767336}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-xlarge_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 60929820672, 'Parameters': 350196968}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'swinv2-base-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 8510000000, 'Parameters': 87920000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'swinv2-large-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 19040000000, 'Parameters': 196740000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-large-w12_3rdparty_in21k-192px_20220803-d9073fee.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n" + ] + } + ], + "source": [ + "config_string = \"\"\"\n", + "_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n", + "\"\"\"\n", + "write_config(config_string, prune_config_path)\n", + "print(Config.fromfile(prune_config_path)['model'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Add the pretrained parameter\n", + "We take the original 'model' field out, rename it architecture, and add an init_cfg field to architecture so that the pretrained model weights are loaded." + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "config_string += \"\"\"\\n\n", + "data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n", + "architecture = _base_.model\n", + "architecture.update({\n", + " 'init_cfg': {\n", + " 'type':\n", + " 'Pretrained',\n", + " 'checkpoint':\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa\n", + " }\n", + "})\n", + "\"\"\"\n", + "write_config(config_string, prune_config_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Wrap the ResNet34 model in the pruning algorithm wrapper\n", + "\n", + "We put the original model into the ItePruneAlgorithm as its architecture, and use the ItePruneAlgorithm as the new model field." + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "config_string+=\"\"\"\n", + "target_pruning_ratio={}\n", + "model = dict(\n", + " _delete_=True,\n", + " _scope_='mmrazor',\n", + " type='ItePruneAlgorithm',\n", + " architecture=architecture,\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio'))),\n", + " target_pruning_ratio=target_pruning_ratio,\n", + " step_epoch=1,\n", + " prune_times=1,\n", + ")\n", + "\"\"\"\n", + "write_config(config_string, prune_config_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At this point, our config file can already run. However, since we have not configured target_pruning_ratio yet, running it is no different from running the original config directly. Next, we introduce how to configure the pruning ratios." + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "#! python ./tools/train.py $prune_config_path" + ] + },
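 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Optionally, we can check that the assembled config still parses and that the model field is now the MMRazor algorithm wrapper (a quick sanity-check sketch reusing Config and prune_config_path from above):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sanity check: the config should parse, and 'model' should now be\n", + "# the ItePruneAlgorithm wrapper rather than the raw classifier.\n", + "print(Config.fromfile(prune_config_path)['model']['type'])" + ] + },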
配置剪枝比例" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们的模型使用tracer解析模型,进而获得剪枝节点,为了方便用户配置剪枝节点比例,我们提供了一个获得剪枝节点剪枝比例配置的工具。通过该工具,我们可以方便地对剪枝比例进行配置。" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"backbone.conv1_(0, 64)_64\":1.0,\n", + " \"backbone.layer1.0.conv1_(0, 64)_64\":1.0,\n", + " \"backbone.layer1.1.conv1_(0, 64)_64\":1.0,\n", + " \"backbone.layer1.2.conv1_(0, 64)_64\":1.0,\n", + " \"backbone.layer2.0.conv1_(0, 128)_128\":1.0,\n", + " \"backbone.layer2.0.conv2_(0, 128)_128\":1.0,\n", + " \"backbone.layer2.1.conv1_(0, 128)_128\":1.0,\n", + " \"backbone.layer2.2.conv1_(0, 128)_128\":1.0,\n", + " \"backbone.layer2.3.conv1_(0, 128)_128\":1.0,\n", + " \"backbone.layer3.0.conv1_(0, 256)_256\":1.0,\n", + " \"backbone.layer3.0.conv2_(0, 256)_256\":1.0,\n", + " \"backbone.layer3.1.conv1_(0, 256)_256\":1.0,\n", + " \"backbone.layer3.2.conv1_(0, 256)_256\":1.0,\n", + " \"backbone.layer3.3.conv1_(0, 256)_256\":1.0,\n", + " \"backbone.layer3.4.conv1_(0, 256)_256\":1.0,\n", + " \"backbone.layer3.5.conv1_(0, 256)_256\":1.0,\n", + " \"backbone.layer4.0.conv1_(0, 512)_512\":1.0,\n", + " \"backbone.layer4.0.conv2_(0, 512)_512\":1.0,\n", + " \"backbone.layer4.1.conv1_(0, 512)_512\":1.0,\n", + " \"backbone.layer4.2.conv1_(0, 512)_512\":1.0\n", + "}" + ] + } + ], + "source": [ + "! python ./tools/get_channel_units.py $prune_config_path --choice -o prune_ratio_templace.json &> /dev/null 2>&1\n", + "! cat prune_ratio_templace.json\n", + "! rm prune_ratio_templace.json" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们修改该配置模板如下,并且将替换到我们的剪枝配置文件中。\n", + "\n", + "(该配置来源于:Li, Hao, et al. \"Pruning filters for efficient convnets.\" arXiv preprint arXiv:1608.08710 (2016).)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "target_config = \"\"\"\n", + "stage_ratio_1 = 0.7\n", + "stage_ratio_2 = 0.7\n", + "stage_ratio_3 = 0.7\n", + "stage_ratio_4 = 1.0\n", + "\n", + "target_pruning_ratio = {\n", + " \"backbone.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer1.0.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer1.1.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer1.2.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer2.0.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.0.conv2_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.1.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.2.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.3.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer3.0.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.0.conv2_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.1.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.2.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.3.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.4.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.5.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer4.0.conv1_(0, 512)_512\": stage_ratio_4,\n", + " \"backbone.layer4.0.conv2_(0, 512)_512\": stage_ratio_4,\n", + " \"backbone.layer4.1.conv1_(0, 512)_512\": stage_ratio_4,\n", + " \"backbone.layer4.2.conv1_(0, 512)_512\": stage_ratio_4\n", + "}\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n", + "\n", + "\n", + "data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n", + "architecture = _base_.model\n", + "architecture.update({\n", + " 'init_cfg': {\n", + " 'type':\n", + " 'Pretrained',\n", + " 'checkpoint':\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa\n", + " }\n", + "})\n", + "\n", + "\n", + "stage_ratio_1 = 0.7\n", + "stage_ratio_2 = 0.7\n", + "stage_ratio_3 = 0.7\n", + "stage_ratio_4 = 1.0\n", + "\n", + "target_pruning_ratio = {\n", + " \"backbone.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer1.0.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer1.1.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer1.2.conv1_(0, 64)_64\": stage_ratio_1,\n", + " \"backbone.layer2.0.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.0.conv2_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.1.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.2.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer2.3.conv1_(0, 128)_128\": stage_ratio_2,\n", + " \"backbone.layer3.0.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.0.conv2_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.1.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.2.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.3.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.4.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer3.5.conv1_(0, 256)_256\": stage_ratio_3,\n", + " \"backbone.layer4.0.conv1_(0, 512)_512\": stage_ratio_4,\n", + " \"backbone.layer4.0.conv2_(0, 512)_512\": stage_ratio_4,\n", + " \"backbone.layer4.1.conv1_(0, 512)_512\": stage_ratio_4,\n", + " \"backbone.layer4.2.conv1_(0, 512)_512\": stage_ratio_4\n", + "}\n", + "\n", + "model = dict(\n", + " _delete_=True,\n", + " _scope_='mmrazor',\n", + " type='ItePruneAlgorithm',\n", + " architecture=architecture,\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio'))),\n", + " target_pruning_ratio=target_pruning_ratio,\n", + " step_epoch=1,\n", + " prune_times=1,\n", + ")\n" + ] + } + ], + "source": [ + "config_string=config_string.replace('target_pruning_ratio={}',target_config)\n", + "write_config(config_string,prune_config_path)\n", + "! cat $prune_config_path" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. 
运行" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-base_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 15359124480, 'Parameters': 88591464}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-large_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 34368026112, 'Parameters': 197767336}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-xlarge_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 60929820672, 'Parameters': 350196968}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'swinv2-base-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 8510000000, 'Parameters': 87920000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'swinv2-large-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 19040000000, 'Parameters': 196740000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-large-w12_3rdparty_in21k-192px_20220803-d9073fee.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', 'Code': 
'https://github.com/microsoft/Swin-Transformer'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "11/02 17:45:20 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "------------------------------------------------------------\n", + "System environment:\n", + " sys.platform: linux\n", + " Python: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0]\n", + " CUDA available: True\n", + " numpy_random_seed: 961503846\n", + " GPU 0: NVIDIA GeForce GTX 1660 Ti\n", + " CUDA_HOME: /usr/local/cuda\n", + " NVCC: Cuda compilation tools, release 11.3, V11.3.58\n", + " GCC: gcc (Ubuntu 11.2.0-19ubuntu1) 11.2.0\n", + " PyTorch: 1.12.1+cu113\n", + " PyTorch compiling details: PyTorch built with:\n", + " - GCC 9.3\n", + " - C++ Version: 201402\n", + " - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n", + " - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)\n", + " - OpenMP 201511 (a.k.a. OpenMP 4.5)\n", + " - LAPACK is enabled (usually provided by MKL)\n", + " - NNPACK is enabled\n", + " - CPU capability usage: AVX2\n", + " - CUDA Runtime 11.3\n", + " - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n", + " - CuDNN 8.3.2 (built against CUDA 11.5)\n", + " - Magma 2.5.2\n", + " - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n", + "\n", + " TorchVision: 0.13.1+cu113\n", + " OpenCV: 4.6.0\n", + " MMEngine: 0.1.0\n", + "\n", + "Runtime environment:\n", + " cudnn_benchmark: False\n", + " mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}\n", + " dist_cfg: {'backend': 'nccl'}\n", + " seed: None\n", + " Distributed launcher: none\n", + " Distributed training: False\n", + " GPU number: 1\n", + "------------------------------------------------------------\n", + "\n", + "11/02 17:45:21 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Config:\n", + "model = dict(\n", + " _scope_='mmrazor',\n", + " type='ItePruneAlgorithm',\n", + " architecture=dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " type='ResNet',\n", + " depth=34,\n", + " 
num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls',\n", + " init_cfg=dict(\n", + " type='Pretrained',\n", + " checkpoint=\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n", + " )),\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio'))),\n", + " target_pruning_ratio=dict({\n", + " 'backbone.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer1.0.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer1.1.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer1.2.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer2.0.conv1_(0, 128)_128': 0.7,\n", + " 'backbone.layer2.0.conv2_(0, 128)_128': 0.7,\n", + " 'backbone.layer2.1.conv1_(0, 128)_128': 0.7,\n", + " 'backbone.layer2.2.conv1_(0, 128)_128': 0.7,\n", + " 'backbone.layer2.3.conv1_(0, 128)_128': 0.7,\n", + " 'backbone.layer3.0.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.0.conv2_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.1.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.2.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.3.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.4.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.5.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer4.0.conv1_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.0.conv2_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.1.conv1_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.2.conv1_(0, 512)_512': 1.0\n", + " }),\n", + " step_epoch=1,\n", + " prune_times=1)\n", + "dataset_type = 'ImageNet'\n", + "data_preprocessor = dict(\n", + " mean=[123.675, 116.28, 103.53],\n", + " std=[58.395, 57.12, 57.375],\n", + " to_rgb=True,\n", + " type='mmcls.ClsDataPreprocessor')\n", + "train_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "test_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n", + " dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "train_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/train.txt',\n", + " data_prefix='train',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='RandomResizedCrop', scale=224),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " 
dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "test_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "optim_wrapper = dict(\n", + " optimizer=dict(\n", + " type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n", + " _scope_='mmcls'))\n", + "param_scheduler = dict(\n", + " type='MultiStepLR',\n", + " by_epoch=True,\n", + " milestones=[30, 60, 90],\n", + " gamma=0.1,\n", + " _scope_='mmcls')\n", + "train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n", + "val_cfg = dict()\n", + "test_cfg = dict()\n", + "auto_scale_lr = dict(base_batch_size=256)\n", + "default_scope = 'mmcls'\n", + "default_hooks = dict(\n", + " timer=dict(type='IterTimerHook', _scope_='mmcls'),\n", + " logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n", + " param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n", + " checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n", + " sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n", + " visualization=dict(\n", + " type='VisualizationHook', enable=False, _scope_='mmcls'))\n", + "env_cfg = dict(\n", + " cudnn_benchmark=False,\n", + " mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n", + " dist_cfg=dict(backend='nccl'))\n", + "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n", + "visualizer = dict(\n", + " type='ClsVisualizer',\n", + " vis_backends=[dict(type='LocalVisBackend')],\n", + " _scope_='mmcls')\n", + "log_level = 'INFO'\n", + "load_from = None\n", + "resume = False\n", + "architecture = dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " type='ResNet',\n", + " depth=34,\n", + " num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls',\n", + " init_cfg=dict(\n", + " type='Pretrained',\n", + " checkpoint=\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n", + " ))\n", + "stage_ratio_1 = 0.7\n", + "stage_ratio_2 = 0.7\n", + "stage_ratio_3 = 0.7\n", + "stage_ratio_4 = 1.0\n", + "target_pruning_ratio = dict({\n", + " 'backbone.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer1.0.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer1.1.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer1.2.conv1_(0, 64)_64': 0.7,\n", + " 'backbone.layer2.0.conv1_(0, 128)_128': 0.7,\n", + " 'backbone.layer2.0.conv2_(0, 128)_128': 0.7,\n", + " 'backbone.layer2.1.conv1_(0, 128)_128': 0.7,\n", + " 'backbone.layer2.2.conv1_(0, 128)_128': 0.7,\n", 
+ " 'backbone.layer2.3.conv1_(0, 128)_128': 0.7,\n", + " 'backbone.layer3.0.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.0.conv2_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.1.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.2.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.3.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.4.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer3.5.conv1_(0, 256)_256': 0.7,\n", + " 'backbone.layer4.0.conv1_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.0.conv2_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.1.conv1_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.2.conv1_(0, 512)_512': 1.0\n", + "})\n", + "launcher = 'none'\n", + "work_dir = './work_dirs/prune_resnet34'\n", + "\n", + "Result has been saved to /home/liukai/Documents/mmlab2/mmrazor_github2/work_dirs/prune_resnet34/modules_statistic_results.json\n", + "11/02 17:45:21 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a input before backbone.conv1(backbone.conv1), error: backbone.conv1(backbone.conv1)\n", + "11/02 17:45:21 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a output after head.fc(head.fc), error: head.fc(head.fc)\n", + "11/02 17:45:22 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.\n", + "11/02 17:45:24 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - load model from: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth\n", + "11/02 17:45:24 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - http loads checkpoint from path: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth\n", + "11/02 17:45:24 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to /home/liukai/Documents/mmlab2/mmrazor_github2/work_dirs/prune_resnet34 by HardDiskBackend.\n", + "11/02 17:45:24 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - The model is pruned at 0th epoch once.\n", + "11/02 17:45:25 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:25 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 1 epochs\n", + "11/02 17:45:26 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - The model is pruned at 0th epoch once.\n", + "11/02 17:45:26 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [1][1/1] accuracy/top1: 0.0000 accuracy/top5: 5.0000\n", + "11/02 17:45:27 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:27 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 2 epochs\n", + "11/02 17:45:28 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [2][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:28 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:28 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 3 epochs\n", + "11/02 17:45:29 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [3][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:30 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:30 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 4 epochs\n", + "11/02 17:45:31 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [4][1/1] accuracy/top1: 0.0000 
accuracy/top5: 0.0000\n", + "11/02 17:45:31 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:31 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 5 epochs\n", + "11/02 17:45:32 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [5][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:32 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:32 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 6 epochs\n", + "11/02 17:45:34 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [6][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:34 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:34 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 7 epochs\n", + "11/02 17:45:35 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [7][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:35 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:35 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 8 epochs\n", + "11/02 17:45:37 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [8][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:37 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:37 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 9 epochs\n", + "11/02 17:45:38 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [9][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:38 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:38 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 10 epochs\n", + "11/02 17:45:39 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [10][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:40 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:40 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 11 epochs\n", + "11/02 17:45:41 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [11][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:41 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:41 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 12 epochs\n", + "11/02 17:45:42 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [12][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:42 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:42 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 13 epochs\n", + "11/02 17:45:43 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [13][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:43 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:43 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 14 epochs\n", + "11/02 17:45:44 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [14][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:45 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - 
Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:45 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 15 epochs\n", + "11/02 17:45:46 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [15][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:46 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:46 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 16 epochs\n", + "11/02 17:45:47 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [16][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:47 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:47 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 17 epochs\n", + "11/02 17:45:48 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [17][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:48 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:48 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 18 epochs\n", + "11/02 17:45:49 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [18][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:50 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:50 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 19 epochs\n", + "11/02 17:45:51 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [19][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:51 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:51 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 20 epochs\n", + "11/02 17:45:52 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [20][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:52 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:52 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 21 epochs\n", + "11/02 17:45:53 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [21][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:53 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:53 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 22 epochs\n", + "11/02 17:45:54 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [22][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:55 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:55 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 23 epochs\n", + "11/02 17:45:56 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [23][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:56 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:56 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 24 epochs\n", + "11/02 17:45:57 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [24][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:57 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:57 - mmengine - 
\u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 25 epochs\n", + "11/02 17:45:58 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [25][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:45:58 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:45:58 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 26 epochs\n", + "11/02 17:45:59 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [26][1/1] accuracy/top1: 0.0000 accuracy/top5: 0.0000\n", + "11/02 17:46:00 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_resnet34_20221102_174520\n", + "11/02 17:46:00 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 27 epochs\n", + "^C\n", + "Traceback (most recent call last):\n", + " File \"/home/liukai/Documents/mmlab2/mmrazor_github2/./tools/train.py\", line 110, in <module>\n", + " main()\n", + " File \"/home/liukai/Documents/mmlab2/mmrazor_github2/./tools/train.py\", line 106, in main\n", + " runner.train()\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/runner/runner.py\", line 1631, in train\n", + " model = self.train_loop.run() # type: ignore\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/runner/loops.py\", line 88, in run\n", + " self.run_epoch()\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/runner/loops.py\", line 106, in run_epoch\n", + " self.runner.call_hook('after_train_epoch')\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/runner/runner.py\", line 1693, in call_hook\n", + " getattr(hook, fn_name)(self, **kwargs)\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/hooks/checkpoint_hook.py\", line 247, in after_train_epoch\n", + " self._save_checkpoint(runner)\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/dist/utils.py\", line 329, in wrapper\n", + " return func(*args, **kwargs)\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/hooks/checkpoint_hook.py\", line 286, in _save_checkpoint\n", + " runner.save_checkpoint(\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/dist/utils.py\", line 329, in wrapper\n", + " return func(*args, **kwargs)\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/runner/runner.py\", line 2092, in save_checkpoint\n", + " save_checkpoint(checkpoint, filepath)\n", + " File \"/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/runner/checkpoint.py\", line 698, in save_checkpoint\n", + " file_client.put(f.getvalue(), filename)\n", + "KeyboardInterrupt\n" + ] + } + ], + "source": [ + "! python ./tools/train.py $prune_config_path" + ] + },
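 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once at least one epoch has finished, the pruned model's checkpoints appear under ./work_dirs/prune_resnet34. As a rough check (a sketch; adjust the file name to a checkpoint that actually exists in your run), we can peek at the saved weights:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "# Rough check of a saved checkpoint (file name depends on finished epochs).\n", + "ckpt = torch.load('./work_dirs/prune_resnet34/epoch_1.pth', map_location='cpu')\n", + "print(list(ckpt['state_dict'].keys())[:5])" + ] + },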
rm prune_ratio_templace.json" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 ('lab2max')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demo/search_and_prune.ipynb b/demo/search_and_prune.ipynb new file mode 100644 index 000000000..3eaff9926 --- /dev/null +++ b/demo/search_and_prune.ipynb @@ -0,0 +1,2035 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir('../')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 使用MMRazor对ResNet34进行剪枝" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. 回顾MMCLs\n", + "2. 搜索最优剪枝结构\n", + "3. 剪枝" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 回顾MMCls" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare config path\n", + "work_dir='./prune_example/'\n", + "config_path=f\"{work_dir}/configs/\"\n", + "! mkdir -p $config_path\n", + "\n", + "pretrained_path='https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们直接使用mmcls的配置文件。(这里使用了mmengine的跨库调用功能,请参考[MMEngine文档](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/config.html#id11))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n" + ] + } + ], + "source": [ + "! echo \"_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\" > $config_path/pretrain.py\n", + "! cat $config_path/pretrain.py\n", + "\n", + "# Run config\n", + "# ! 
python ./tools/train.py $config_path/pretrain.py" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-base_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 15359124480, 'Parameters': 88591464}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-large_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 34368026112, 'Parameters': 197767336}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'convnext-xlarge_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 60929820672, 'Parameters': 350196968}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'swinv2-base-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 8510000000, 'Parameters': 87920000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "/home/liukai/Documents/mmlab2/others/max/mmengine/mmengine/config/utils.py:50: UserWarning: There is not `Config` define in {'Name': 'swinv2-large-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 19040000000, 'Parameters': 196740000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-large-w12_3rdparty_in21k-192px_20220803-d9073fee.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', 
'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n", + " warnings.warn(f'There is not `Config` define in {model_cfg}')\n", + "11/08 14:14:45 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "------------------------------------------------------------\n", + "System environment:\n", + " sys.platform: linux\n", + " Python: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0]\n", + " CUDA available: True\n", + " numpy_random_seed: 305217533\n", + " GPU 0: NVIDIA GeForce GTX 1660 Ti\n", + " CUDA_HOME: /usr/local/cuda\n", + " NVCC: Cuda compilation tools, release 11.3, V11.3.58\n", + " GCC: gcc (Ubuntu 11.2.0-19ubuntu1) 11.2.0\n", + " PyTorch: 1.12.1+cu113\n", + " PyTorch compiling details: PyTorch built with:\n", + " - GCC 9.3\n", + " - C++ Version: 201402\n", + " - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n", + " - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)\n", + " - OpenMP 201511 (a.k.a. OpenMP 4.5)\n", + " - LAPACK is enabled (usually provided by MKL)\n", + " - NNPACK is enabled\n", + " - CPU capability usage: AVX2\n", + " - CUDA Runtime 11.3\n", + " - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n", + " - CuDNN 8.3.2 (built against CUDA 11.5)\n", + " - Magma 2.5.2\n", + " - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n", + "\n", + " TorchVision: 0.13.1+cu113\n", + " OpenCV: 4.6.0\n", + " MMEngine: 0.1.0\n", + "\n", + "Runtime environment:\n", + " cudnn_benchmark: False\n", + " mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}\n", + " dist_cfg: {'backend': 'nccl'}\n", + " seed: None\n", + " Distributed launcher: none\n", + " Distributed training: False\n", + " GPU number: 1\n", + "------------------------------------------------------------\n", + "\n", + "11/08 14:14:45 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Config:\n", + "model = dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " type='ResNet',\n", + " depth=34,\n", + " num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " 
neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls')\n", + "dataset_type = 'ImageNet'\n", + "data_preprocessor = dict(\n", + " mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n", + "train_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "test_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n", + " dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "train_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/train.txt',\n", + " data_prefix='train',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='RandomResizedCrop', scale=224),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "test_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "optim_wrapper = dict(\n", + " optimizer=dict(\n", + " type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n", + " _scope_='mmcls'))\n", + "param_scheduler = dict(\n", + " type='MultiStepLR',\n", + " by_epoch=True,\n", + " milestones=[30, 60, 90],\n", + " gamma=0.1,\n", + " _scope_='mmcls')\n", + "train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n", + "val_cfg = dict()\n", + "test_cfg = dict()\n", + "auto_scale_lr = dict(base_batch_size=256)\n", + "default_scope = 'mmcls'\n", + "default_hooks = dict(\n", + " timer=dict(type='IterTimerHook', _scope_='mmcls'),\n", + " logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n", + " param_scheduler=dict(type='ParamSchedulerHook', 
_scope_='mmcls'),\n", + " checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n", + " sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n", + " visualization=dict(\n", + " type='VisualizationHook', enable=False, _scope_='mmcls'))\n", + "env_cfg = dict(\n", + " cudnn_benchmark=False,\n", + " mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n", + " dist_cfg=dict(backend='nccl'))\n", + "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n", + "visualizer = dict(\n", + " type='ClsVisualizer',\n", + " vis_backends=[dict(type='LocalVisBackend')],\n", + " _scope_='mmcls')\n", + "log_level = 'INFO'\n", + "load_from = None\n", + "resume = False\n", + "launcher = 'none'\n", + "work_dir = './work_dirs/pretrain'\n", + "\n", + "Result has been saved to /home/liukai/Documents/mmlab2/mmrazor_github2/work_dirs/pretrain/modules_statistic_results.json\n", + "11/08 14:14:46 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.\n", + "11/08 14:14:46 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to /home/liukai/Documents/mmlab2/mmrazor_github2/work_dirs/pretrain by HardDiskBackend.\n", + "11/08 14:14:48 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: pretrain_20221108_141444\n", + "11/08 14:14:48 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 1 epochs\n", + "11/08 14:14:49 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [1][7/7] accuracy/top1: 1.0000 accuracy/top5: 8.0000\n", + "11/08 14:14:50 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: pretrain_20221108_141444\n", + "11/08 14:14:50 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 2 epochs\n", + "11/08 14:14:51 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [2][7/7] accuracy/top1: 3.5000 accuracy/top5: 12.5000\n", + "11/08 14:14:52 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: pretrain_20221108_141444\n", + "11/08 14:14:52 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 3 epochs\n" + ] + } + ], + "source": [ + "! timeout 10 python ./tools/train.py $config_path/pretrain.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare the structure-search config\n", + "1. Generate the search config\n", + "2. Run the search config\n", + "3. Obtain the best searched structure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Generate the search config\n", + "We provide a tool that generates the search-algorithm config with a single command. It takes two arguments: the config path and the checkpoint path." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "usage: get_search_config.py [-h] [--flops-min FLOPS_MIN]\n", + " [--flops-max FLOPS_MAX] [-o O]\n", + " config checkpoint\n", + "\n", + "Get the config to search the pruning structure of a model\n", + "\n", + "positional arguments:\n", + " config config of the model\n", + " checkpoint checkpoint path of the model\n", + "\n", + "optional arguments:\n", + " -h, --help show this help message and exit\n", + " --flops-min FLOPS_MIN\n", + " minimal flops\n", + " --flops-max FLOPS_MAX\n", + " maximal flops\n", + " -o O output path to store the search config.\n" + ] + } + ], + "source": [ + "! 
python tools/get_search_config.py -h" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model = dict(\n", + " _scope_='mmrazor',\n", + " type='SearchWrapper',\n", + " architecture=dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " type='ResNet',\n", + " depth=34,\n", + " num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls',\n", + " init_cfg=dict(\n", + " type='Pretrained',\n", + " checkpoint=\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n", + " ),\n", + " data_preprocessor=dict(\n", + " mean=[123.675, 116.28, 103.53],\n", + " std=[58.395, 57.12, 57.375],\n", + " to_rgb=True)),\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio')),\n", + " parse_cfg=dict(\n", + " type='BackwardTracer',\n", + " loss_calculator=dict(\n", + " type='ImageClassifierPseudoLoss',\n", + " input_shape=(2, 3, 224, 224)))))\n", + "dataset_type = 'ImageNet'\n", + "data_preprocessor = None\n", + "train_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "test_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n", + " dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "train_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/train.txt',\n", + " data_prefix='train',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='RandomResizedCrop', scale=224),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='mmcls.ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "test_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " 
dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "optim_wrapper = dict(\n", + " optimizer=dict(\n", + " type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n", + " _scope_='mmcls'))\n", + "param_scheduler = dict(\n", + " type='MultiStepLR',\n", + " by_epoch=True,\n", + " milestones=[30, 60, 90],\n", + " gamma=0.1,\n", + " _scope_='mmcls')\n", + "train_cfg = dict(\n", + " type='mmrazor.PruneEvolutionSearchLoop',\n", + " dataloader=dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='mmcls.ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True),\n", + " evaluator=dict(type='mmcls.Accuracy', topk=(1, 5), _scope_='mmcls'),\n", + " max_epochs=20,\n", + " num_candidates=20,\n", + " top_k=5,\n", + " num_mutation=10,\n", + " num_crossover=10,\n", + " mutate_prob=0.2,\n", + " flops_range=(0.45, 0.55),\n", + " score_key='accuracy/top1')\n", + "val_cfg = dict()\n", + "test_cfg = dict()\n", + "auto_scale_lr = dict(base_batch_size=256)\n", + "default_scope = 'mmcls'\n", + "default_hooks = dict(\n", + " timer=dict(type='IterTimerHook', _scope_='mmcls'),\n", + " logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n", + " param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n", + " checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n", + " sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n", + " visualization=dict(\n", + " type='VisualizationHook', enable=False, _scope_='mmcls'))\n", + "env_cfg = dict(\n", + " cudnn_benchmark=False,\n", + " mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n", + " dist_cfg=dict(backend='nccl'))\n", + "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n", + "visualizer = dict(\n", + " type='ClsVisualizer',\n", + " vis_backends=[dict(type='LocalVisBackend')],\n", + " _scope_='mmcls')\n", + "log_level = 'INFO'\n", + "load_from = None\n", + "resume = False\n" + ] + } + ], + "source": [ + "! python tools/get_search_config.py $config_path/pretrain.py $pretrained_path -o $config_path/search.py &> /dev/null \n", + "! cat $config_path/search.py " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated config can be edited further to fit your own needs. The \"model\" and \"train_cfg\" fields are the two most important ones." + ] + },
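 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a minimal sketch (the `config_path` value below is an assumption matching the directory used in this demo), the two fields can also be tweaked programmatically with mmengine's `Config` API before launching the search:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# A minimal sketch: tweak the two key fields of the generated search config.\n", + "# `config_path` is an assumption; point it at the directory used above.\n", + "from mmengine import Config\n", + "\n", + "config_path = './prune_example'\n", + "cfg = Config.fromfile(config_path + '/search.py')\n", + "\n", + "# 'train_cfg' drives the evolution search, e.g. narrow the FLOPs window\n", + "# and shorten the search.\n", + "cfg.train_cfg.flops_range = (0.48, 0.52)\n", + "cfg.train_cfg.max_epochs = 10\n", + "\n", + "# 'model' wraps the architecture to be searched, e.g. check which channel\n", + "# unit type is used for pruning.\n", + "print(cfg.model.mutator_cfg.channel_unit_cfg.type)\n", + "\n", + "cfg.dump(config_path + '/search_modified.py')" + ] + },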
 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Run the search config" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/08 14:14:57 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "------------------------------------------------------------\n", + "System environment:\n", + " sys.platform: linux\n", + " Python: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0]\n", + " CUDA available: True\n", + " numpy_random_seed: 1574091845\n", + " GPU 0: NVIDIA GeForce GTX 1660 Ti\n", + " CUDA_HOME: /usr/local/cuda\n", + " NVCC: Cuda compilation tools, release 11.3, V11.3.58\n", + " GCC: gcc (Ubuntu 11.2.0-19ubuntu1) 11.2.0\n", + " PyTorch: 1.12.1+cu113\n", + " PyTorch compiling details: PyTorch built with:\n", + " - GCC 9.3\n", + " - C++ Version: 201402\n", + " - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n", + " - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)\n", + " - OpenMP 201511 (a.k.a. OpenMP 4.5)\n", + " - LAPACK is enabled (usually provided by MKL)\n", + " - NNPACK is enabled\n", + " - CPU capability usage: AVX2\n", + " - CUDA Runtime 11.3\n", + " - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n", + " - CuDNN 8.3.2 (built against CUDA 11.5)\n", + " - Magma 2.5.2\n", + " - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n", + "\n", + " TorchVision: 0.13.1+cu113\n", + " OpenCV: 4.6.0\n", + " MMEngine: 0.1.0\n", + "\n", + "Runtime environment:\n", + " cudnn_benchmark: False\n", + " mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}\n", + " dist_cfg: {'backend': 'nccl'}\n", + " seed: None\n", + " Distributed launcher: none\n", + " Distributed training: False\n", + " GPU number: 1\n", + "------------------------------------------------------------\n", + "\n", + "11/08 14:14:58 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Config:\n", + "model = dict(\n", + " _scope_='mmrazor',\n", + " type='SearchWrapper',\n", + " architecture=dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " 
type='ResNet',\n", + " depth=34,\n", + " num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls',\n", + " init_cfg=dict(\n", + " type='Pretrained',\n", + " checkpoint=\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n", + " ),\n", + " data_preprocessor=dict(\n", + " mean=[123.675, 116.28, 103.53],\n", + " std=[58.395, 57.12, 57.375],\n", + " to_rgb=True)),\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio')),\n", + " parse_cfg=dict(\n", + " type='BackwardTracer',\n", + " loss_calculator=dict(\n", + " type='ImageClassifierPseudoLoss',\n", + " input_shape=(2, 3, 224, 224)))))\n", + "dataset_type = 'ImageNet'\n", + "data_preprocessor = None\n", + "train_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "test_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n", + " dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "train_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/train.txt',\n", + " data_prefix='train',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='RandomResizedCrop', scale=224),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='mmcls.ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "test_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + 
"optim_wrapper = dict(\n", + " optimizer=dict(\n", + " type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n", + " _scope_='mmcls'))\n", + "param_scheduler = dict(\n", + " type='MultiStepLR',\n", + " by_epoch=True,\n", + " milestones=[30, 60, 90],\n", + " gamma=0.1,\n", + " _scope_='mmcls')\n", + "train_cfg = dict(\n", + " type='mmrazor.PruneEvolutionSearchLoop',\n", + " dataloader=dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='mmcls.ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True),\n", + " evaluator=dict(type='mmcls.Accuracy', topk=(1, 5), _scope_='mmcls'),\n", + " max_epochs=20,\n", + " num_candidates=20,\n", + " top_k=5,\n", + " num_mutation=10,\n", + " num_crossover=10,\n", + " mutate_prob=0.2,\n", + " flops_range=(0.45, 0.55),\n", + " score_key='accuracy/top1')\n", + "val_cfg = dict()\n", + "test_cfg = dict()\n", + "auto_scale_lr = dict(base_batch_size=256)\n", + "default_scope = 'mmcls'\n", + "default_hooks = dict(\n", + " timer=dict(type='IterTimerHook', _scope_='mmcls'),\n", + " logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n", + " param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n", + " checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n", + " sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n", + " visualization=dict(\n", + " type='VisualizationHook', enable=False, _scope_='mmcls'))\n", + "env_cfg = dict(\n", + " cudnn_benchmark=False,\n", + " mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n", + " dist_cfg=dict(backend='nccl'))\n", + "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n", + "visualizer = dict(\n", + " type='ClsVisualizer',\n", + " vis_backends=[dict(type='LocalVisBackend')],\n", + " _scope_='mmcls')\n", + "log_level = 'INFO'\n", + "load_from = None\n", + "resume = False\n", + "launcher = 'none'\n", + "work_dir = './prune_example//search/'\n", + "\n", + "Result has been saved to /home/liukai/Documents/mmlab2/mmrazor_github2/prune_example/search/modules_statistic_results.json\n", + "11/08 14:14:58 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a input before backbone.conv1(backbone.conv1), error: backbone.conv1(backbone.conv1)\n", + "11/08 14:14:58 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a output after head.fc(head.fc), error: head.fc(head.fc)\n", + "11/08 14:14:59 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.\n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - load model from: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth\n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - http loads checkpoint from path: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth\n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.conv1.weight - torch.Size([64, 3, 7, 7]): \n", + 
"PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.bn1.weight - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.bn1.bias - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.0.conv1.weight - torch.Size([64, 64, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.0.bn1.weight - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.0.bn1.bias - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.0.conv2.weight - torch.Size([64, 64, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.0.bn2.weight - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.0.bn2.bias - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.1.conv1.weight - torch.Size([64, 64, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.1.bn1.weight - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.1.bn1.bias - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.1.conv2.weight - torch.Size([64, 64, 3, 3]): \n", + "PretrainedInit: load from 
https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.1.bn2.weight - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.1.bn2.bias - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.2.conv1.weight - torch.Size([64, 64, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.2.bn1.weight - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.2.bn1.bias - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.2.conv2.weight - torch.Size([64, 64, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.2.bn2.weight - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer1.2.bn2.bias - torch.Size([64]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.conv1.weight - torch.Size([128, 64, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.bn1.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.bn1.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.conv2.weight - torch.Size([128, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth 
\n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.bn2.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.bn2.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.downsample.0.weight - torch.Size([128, 64, 1, 1]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.downsample.1.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.0.downsample.1.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.1.conv1.weight - torch.Size([128, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.1.bn1.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.1.bn1.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.1.conv2.weight - torch.Size([128, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.1.bn2.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.1.bn2.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.2.conv1.weight - torch.Size([128, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - 
\u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.2.bn1.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.2.bn1.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.2.conv2.weight - torch.Size([128, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.2.bn2.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.2.bn2.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.3.conv1.weight - torch.Size([128, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.3.bn1.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.3.bn1.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.3.conv2.weight - torch.Size([128, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.3.bn2.weight - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer2.3.bn2.bias - torch.Size([128]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.conv1.weight - torch.Size([256, 128, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.bn1.weight - 
torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.bn1.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.conv2.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.bn2.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.bn2.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.downsample.0.weight - torch.Size([256, 128, 1, 1]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.downsample.1.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.0.downsample.1.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.1.conv1.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.1.bn1.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.1.bn1.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.1.conv2.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.1.bn2.weight - torch.Size([256]): \n", + "PretrainedInit: load from 
https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.1.bn2.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.2.conv1.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.2.bn1.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.2.bn1.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.2.conv2.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.2.bn2.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.2.bn2.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.3.conv1.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.3.bn1.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.3.bn1.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.3.conv2.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.3.bn2.weight - torch.Size([256]): \n", + "PretrainedInit: load from 
https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.3.bn2.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.4.conv1.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.4.bn1.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.4.bn1.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.4.conv2.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.4.bn2.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.4.bn2.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.5.conv1.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.5.bn1.weight - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.5.bn1.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.5.conv2.weight - torch.Size([256, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.5.bn2.weight - torch.Size([256]): \n", + "PretrainedInit: load from 
https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer3.5.bn2.bias - torch.Size([256]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.conv1.weight - torch.Size([512, 256, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.bn1.weight - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.bn1.bias - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.conv2.weight - torch.Size([512, 512, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.bn2.weight - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.bn2.bias - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.downsample.0.weight - torch.Size([512, 256, 1, 1]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.downsample.1.weight - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.0.downsample.1.bias - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.1.conv1.weight - torch.Size([512, 512, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.1.bn1.weight - torch.Size([512]): \n", + "PretrainedInit: load from 
https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.1.bn1.bias - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.1.conv2.weight - torch.Size([512, 512, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.1.bn2.weight - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.1.bn2.bias - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.2.conv1.weight - torch.Size([512, 512, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.2.bn1.weight - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.2.bn1.bias - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.2.conv2.weight - torch.Size([512, 512, 3, 3]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.2.bn2.weight - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.backbone.layer4.2.bn2.bias - torch.Size([512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.head.fc.weight - torch.Size([1000, 512]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "architecture.head.fc.bias - torch.Size([1000]): \n", + "PretrainedInit: load from https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth \n", + " \n", + "11/08 
14:15:01 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to /home/liukai/Documents/mmlab2/mmrazor_github2/prune_example/search by HardDiskBackend.\n", + "11/08 14:15:15 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[1/20] Score:0.0\n", + "11/08 14:15:17 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[2/20] Score:0.0\n", + "11/08 14:15:18 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[3/20] Score:0.0\n", + "11/08 14:15:19 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[4/20] Score:1.0\n", + "11/08 14:15:20 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[5/20] Score:0.0\n", + "11/08 14:15:21 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[6/20] Score:0.0\n", + "11/08 14:15:21 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[7/20] Score:1.0\n", + "11/08 14:15:22 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[8/20] Score:0.0\n", + "11/08 14:15:23 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[9/20] Score:0.5\n", + "11/08 14:15:24 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[10/20] Score:0.0\n", + "11/08 14:15:25 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[11/20] Score:1.0\n", + "11/08 14:15:26 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[12/20] Score:0.0\n", + "11/08 14:15:27 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[13/20] Score:0.0\n", + "11/08 14:15:28 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[14/20] Score:0.0\n", + "11/08 14:15:29 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[15/20] Score:0.0\n", + "11/08 14:15:30 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[16/20] Score:0.0\n", + "11/08 14:15:31 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[17/20] Score:0.0\n", + "11/08 14:15:32 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[18/20] Score:0.5\n", + "11/08 14:15:33 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[19/20] Score:0.0\n", + "11/08 14:15:34 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[0/20] Candidate:[20/20] Score:0.5\n", + "11/08 14:15:34 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - top k scores before update: []\n", + "11/08 14:15:34 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - top k scores after update: [1.0, 1.0, 1.0, 0.5, 0.5]\n", + "11/08 14:15:40 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Search finished and best_fix_subnet.json saved in /home/liukai/Documents/mmlab2/mmrazor_github2/prune_example/search.\n", + "11/08 14:15:40 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20], top1_score: 1.0\n", + "11/08 14:15:41 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[1/20] Score:0.5\n", + "11/08 14:15:42 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[2/20] Score:0.0\n", + "11/08 14:15:43 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[3/20] Score:0.0\n", + "11/08 14:15:44 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[4/20] Score:1.0\n", + "11/08 14:15:45 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[5/20] Score:0.0\n", + "11/08 14:15:46 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[6/20] 
Score:0.0\n", + "11/08 14:15:47 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[7/20] Score:1.0\n", + "11/08 14:15:48 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[8/20] Score:0.0\n", + "11/08 14:15:49 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[9/20] Score:0.0\n", + "11/08 14:15:50 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[10/20] Score:0.0\n", + "11/08 14:15:51 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[11/20] Score:1.5\n", + "11/08 14:15:52 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[12/20] Score:0.5\n", + "11/08 14:15:53 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[13/20] Score:0.0\n", + "11/08 14:15:54 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[14/20] Score:0.5\n", + "11/08 14:15:55 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch:[1/20] Candidate:[15/20] Score:0.0\n" + ] + } + ], + "source": [ + "! timeout 60 python ./tools/train.py $config_path/search.py --work-dir $work_dir/search/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Get the search result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The search result is stored under the work_dir directory, and we can inspect the final result with the command below. It is a dict whose keys are the names of the searchable channel units and whose values are the ratios of channels kept in each unit." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"backbone.conv1_(0, 64)_64\": 1.0,\n", + " \"backbone.layer1.0.conv1_(0, 64)_64\": 0.3426007989589187,\n", + " \"backbone.layer1.1.conv1_(0, 64)_64\": 0.2969206924310629,\n", + " \"backbone.layer1.2.conv1_(0, 64)_64\": 0.04568010652785583,\n", + " \"backbone.layer2.0.conv1_(0, 128)_128\": 1.0,\n", + " \"backbone.layer2.0.conv2_(0, 128)_128\": 1.0,\n", + " \"backbone.layer2.1.conv1_(0, 128)_128\": 0.19414045274338726,\n", + " \"backbone.layer2.2.conv1_(0, 128)_128\": 1.0,\n", + " \"backbone.layer2.3.conv1_(0, 128)_128\": 1.0,\n", + " \"backbone.layer3.0.conv1_(0, 256)_256\": 0.19414045274338726,\n", + " \"backbone.layer3.0.conv2_(0, 256)_256\": 0.8393719574493509,\n", + " \"backbone.layer3.1.conv1_(0, 256)_256\": 0.39970093211873853,\n", + " \"backbone.layer3.2.conv1_(0, 256)_256\": 0.13133030626758552,\n", + " \"backbone.layer3.3.conv1_(0, 256)_256\": 0.21127049269133322,\n", + " \"backbone.layer3.4.conv1_(0, 256)_256\": 1.0,\n", + " \"backbone.layer3.5.conv1_(0, 256)_256\": 0.017130039947945933,\n", + " \"backbone.layer4.0.conv1_(0, 512)_512\": 0.3340357789849457,\n", + " \"backbone.layer4.0.conv2_(0, 512)_512\": 0.7337367111036843,\n", + " \"backbone.layer4.1.conv1_(0, 512)_512\": 0.9078921172411345,\n", + " \"backbone.layer4.2.conv1_(0, 512)_512\": 0.9764122770329183\n", + "}" + ] + } + ], + "source": [ + "! cat $work_dir/search/best_fix_subnet.json" + ] + },
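 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a minimal sketch of consuming this result in Python (the `work_dir` value below is an assumption matching the `--work-dir` used above), the dict can be loaded with the standard `json` module and inspected before being filled into the pruning config:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# A minimal sketch: load the searched structure and inspect it.\n", + "# `work_dir` is an assumption; use the --work-dir passed to the search above.\n", + "import json\n", + "\n", + "work_dir = './prune_example'\n", + "with open(work_dir + '/search/best_fix_subnet.json') as f:\n", + "    subnet = json.load(f)\n", + "\n", + "# Each key names a searchable channel unit; each value is the kept ratio.\n", + "for name, ratio in subnet.items():\n", + "    print(f'{name}: keep {ratio:.2%} of the channels')\n", + "\n", + "# This dict is exactly what `target_pruning_ratio` in the pruning config\n", + "# (next section) gets filled with." + ] + },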
 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare the pruning config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Generate a pruning config template\n", + "2. Fill in the pruning ratios" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Generate a pruning config template\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "usage: get_prune_config.py [-h] [-o O] config checkpoint\n", + "\n", + "Get the config to prune a model.\n", + "\n", + "positional arguments:\n", + " config config of the model\n", + " checkpoint checkpoint path of the model\n", + "\n", + "optional arguments:\n", + " -h, --help show this help message and exit\n", + " -o O output path to store the pruning config.\n" + ] + } + ], + "source": [ + "! python ./tools/get_prune_config.py -h" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model = dict(\n", + " _scope_='mmrazor',\n", + " type='ItePruneAlgorithm',\n", + " architecture=dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " type='ResNet',\n", + " depth=34,\n", + " num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls',\n", + " init_cfg=dict(\n", + " type='Pretrained',\n", + " checkpoint=\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n", + " ),\n", + " data_preprocessor=dict(\n", + " mean=[123.675, 116.28, 103.53],\n", + " std=[58.395, 57.12, 57.375],\n", + " to_rgb=True)),\n", + " target_pruning_ratio=dict({\n", + " 'backbone.conv1_(0, 64)_64': 1.0,\n", + " 'backbone.layer1.0.conv1_(0, 64)_64': 1.0,\n", + " 'backbone.layer1.1.conv1_(0, 64)_64': 1.0,\n", + " 'backbone.layer1.2.conv1_(0, 64)_64': 1.0,\n", + " 'backbone.layer2.0.conv1_(0, 128)_128': 1.0,\n", + " 'backbone.layer2.0.conv2_(0, 128)_128': 1.0,\n", + " 'backbone.layer2.1.conv1_(0, 128)_128': 1.0,\n", + " 'backbone.layer2.2.conv1_(0, 128)_128': 1.0,\n", + " 'backbone.layer2.3.conv1_(0, 128)_128': 1.0,\n", + " 'backbone.layer3.0.conv1_(0, 256)_256': 1.0,\n", + " 'backbone.layer3.0.conv2_(0, 256)_256': 1.0,\n", + " 'backbone.layer3.1.conv1_(0, 256)_256': 1.0,\n", + " 'backbone.layer3.2.conv1_(0, 256)_256': 1.0,\n", + " 'backbone.layer3.3.conv1_(0, 256)_256': 1.0,\n", + " 'backbone.layer3.4.conv1_(0, 256)_256': 1.0,\n", + " 'backbone.layer3.5.conv1_(0, 256)_256': 1.0,\n", + " 'backbone.layer4.0.conv1_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.0.conv2_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.1.conv1_(0, 512)_512': 1.0,\n", + " 'backbone.layer4.2.conv1_(0, 512)_512': 1.0\n", + " }),\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio')),\n", + " parse_cfg=dict(\n", + " type='BackwardTracer',\n", + " loss_calculator=dict(\n", + " type='ImageClassifierPseudoLoss',\n", + " input_shape=(2, 3, 32, 32)))))\n", + "dataset_type = 'ImageNet'\n", + "data_preprocessor = None\n", + "train_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "test_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='ResizeEdge', scale=256, 
edge='short', _scope_='mmcls'),\n", + " dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "train_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/train.txt',\n", + " data_prefix='train',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='RandomResizedCrop', scale=224),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "test_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "optim_wrapper = dict(\n", + " optimizer=dict(\n", + " type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n", + " _scope_='mmcls'))\n", + "param_scheduler = dict(\n", + " type='MultiStepLR',\n", + " by_epoch=True,\n", + " milestones=[30, 60, 90],\n", + " gamma=0.1,\n", + " _scope_='mmcls')\n", + "train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n", + "val_cfg = dict()\n", + "test_cfg = dict()\n", + "auto_scale_lr = dict(base_batch_size=256)\n", + "default_scope = 'mmcls'\n", + "default_hooks = dict(\n", + " timer=dict(type='IterTimerHook', _scope_='mmcls'),\n", + " logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n", + " param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n", + " checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n", + " sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n", + " visualization=dict(\n", + " type='VisualizationHook', enable=False, _scope_='mmcls'))\n", + "env_cfg = dict(\n", + " cudnn_benchmark=False,\n", + " mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n", + " dist_cfg=dict(backend='nccl'))\n", + "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n", + "visualizer = dict(\n", + " type='ClsVisualizer',\n", + " vis_backends=[dict(type='LocalVisBackend')],\n", + " _scope_='mmcls')\n", + "log_level = 'INFO'\n", + "load_from = None\n", + "resume = False\n" + ] + } + ], + "source": [ + "! 
python ./tools/get_prune_config.py $config_path/pretrain.py $pretrained_path -o $config_path/prune.py &> /dev/null\n", + "! cat $config_path/prune.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 替换剪枝比例" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model = dict(\n", + " _scope_='mmrazor',\n", + " type='ItePruneAlgorithm',\n", + " architecture=dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " type='ResNet',\n", + " depth=34,\n", + " num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls',\n", + " init_cfg=dict(\n", + " type='Pretrained',\n", + " checkpoint=\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n", + " ),\n", + " data_preprocessor=dict(\n", + " mean=[123.675, 116.28, 103.53],\n", + " std=[58.395, 57.12, 57.375],\n", + " to_rgb=True)),\n", + " target_pruning_ratio=dict({\n", + " 'backbone.conv1_(0, 64)_64':\n", + " 1.0,\n", + " 'backbone.layer1.0.conv1_(0, 64)_64':\n", + " 0.3426007989589187,\n", + " 'backbone.layer1.1.conv1_(0, 64)_64':\n", + " 0.2969206924310629,\n", + " 'backbone.layer1.2.conv1_(0, 64)_64':\n", + " 0.04568010652785583,\n", + " 'backbone.layer2.0.conv1_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer2.0.conv2_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer2.1.conv1_(0, 128)_128':\n", + " 0.19414045274338726,\n", + " 'backbone.layer2.2.conv1_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer2.3.conv1_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer3.0.conv1_(0, 256)_256':\n", + " 0.19414045274338726,\n", + " 'backbone.layer3.0.conv2_(0, 256)_256':\n", + " 0.8393719574493509,\n", + " 'backbone.layer3.1.conv1_(0, 256)_256':\n", + " 0.39970093211873853,\n", + " 'backbone.layer3.2.conv1_(0, 256)_256':\n", + " 0.13133030626758552,\n", + " 'backbone.layer3.3.conv1_(0, 256)_256':\n", + " 0.21127049269133322,\n", + " 'backbone.layer3.4.conv1_(0, 256)_256':\n", + " 1.0,\n", + " 'backbone.layer3.5.conv1_(0, 256)_256':\n", + " 0.017130039947945933,\n", + " 'backbone.layer4.0.conv1_(0, 512)_512':\n", + " 0.3340357789849457,\n", + " 'backbone.layer4.0.conv2_(0, 512)_512':\n", + " 0.7337367111036843,\n", + " 'backbone.layer4.1.conv1_(0, 512)_512':\n", + " 0.9078921172411345,\n", + " 'backbone.layer4.2.conv1_(0, 512)_512':\n", + " 0.9764122770329183\n", + " }),\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio')),\n", + " parse_cfg=dict(\n", + " type='BackwardTracer',\n", + " loss_calculator=dict(\n", + " type='ImageClassifierPseudoLoss',\n", + " input_shape=(2, 3, 32, 32)))))\n", + "dataset_type = 'ImageNet'\n", + "data_preprocessor = None\n", + "train_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "test_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='ResizeEdge', scale=256, 
edge='short', _scope_='mmcls'),\n", + " dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "train_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/train.txt',\n", + " data_prefix='train',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='RandomResizedCrop', scale=224),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "test_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "optim_wrapper = dict(\n", + " optimizer=dict(\n", + " type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n", + " _scope_='mmcls'))\n", + "param_scheduler = dict(\n", + " type='MultiStepLR',\n", + " by_epoch=True,\n", + " milestones=[30, 60, 90],\n", + " gamma=0.1,\n", + " _scope_='mmcls')\n", + "train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n", + "val_cfg = dict()\n", + "test_cfg = dict()\n", + "auto_scale_lr = dict(base_batch_size=256)\n", + "default_scope = 'mmcls'\n", + "default_hooks = dict(\n", + " timer=dict(type='IterTimerHook', _scope_='mmcls'),\n", + " logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n", + " param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n", + " checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n", + " sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n", + " visualization=dict(\n", + " type='VisualizationHook', enable=False, _scope_='mmcls'))\n", + "env_cfg = dict(\n", + " cudnn_benchmark=False,\n", + " mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n", + " dist_cfg=dict(backend='nccl'))\n", + "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n", + "visualizer = dict(\n", + " type='ClsVisualizer',\n", + " vis_backends=[dict(type='LocalVisBackend')],\n", + " _scope_='mmcls')\n", + "log_level = 'INFO'\n", + "load_from = None\n", + "resume = False\n" + ] + } + ], + "source": [ + "# change prune config\n", + "from mmengine import fileio,Config\n", + 
"pruning_subnet=fileio.load(f'{work_dir}/search/best_fix_subnet.json')\n", + "pruning_config=Config.fromfile(f'{config_path}/prune.py')\n", + "pruning_config['model']['target_pruning_ratio']=pruning_subnet\n", + "pruning_config.dump(f'{config_path}/prune.py')\n", + "! cat $config_path/prune.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 剪枝" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/08 14:16:03 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - \n", + "------------------------------------------------------------\n", + "System environment:\n", + " sys.platform: linux\n", + " Python: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0]\n", + " CUDA available: True\n", + " numpy_random_seed: 1825000398\n", + " GPU 0: NVIDIA GeForce GTX 1660 Ti\n", + " CUDA_HOME: /usr/local/cuda\n", + " NVCC: Cuda compilation tools, release 11.3, V11.3.58\n", + " GCC: gcc (Ubuntu 11.2.0-19ubuntu1) 11.2.0\n", + " PyTorch: 1.12.1+cu113\n", + " PyTorch compiling details: PyTorch built with:\n", + " - GCC 9.3\n", + " - C++ Version: 201402\n", + " - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n", + " - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)\n", + " - OpenMP 201511 (a.k.a. OpenMP 4.5)\n", + " - LAPACK is enabled (usually provided by MKL)\n", + " - NNPACK is enabled\n", + " - CPU capability usage: AVX2\n", + " - CUDA Runtime 11.3\n", + " - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n", + " - CuDNN 8.3.2 (built against CUDA 11.5)\n", + " - Magma 2.5.2\n", + " - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n", + "\n", + " TorchVision: 0.13.1+cu113\n", + " OpenCV: 4.6.0\n", + " MMEngine: 0.1.0\n", + "\n", + "Runtime environment:\n", + " cudnn_benchmark: False\n", + " mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}\n", + " dist_cfg: {'backend': 'nccl'}\n", + " seed: None\n", + " Distributed launcher: none\n", + " Distributed 
training: False\n", + " GPU number: 1\n", + "------------------------------------------------------------\n", + "\n", + "11/08 14:16:04 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Config:\n", + "model = dict(\n", + " _scope_='mmrazor',\n", + " type='ItePruneAlgorithm',\n", + " architecture=dict(\n", + " type='ImageClassifier',\n", + " backbone=dict(\n", + " type='ResNet',\n", + " depth=34,\n", + " num_stages=4,\n", + " out_indices=(3, ),\n", + " style='pytorch'),\n", + " neck=dict(type='GlobalAveragePooling'),\n", + " head=dict(\n", + " type='LinearClsHead',\n", + " num_classes=1000,\n", + " in_channels=512,\n", + " loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n", + " topk=(1, 5)),\n", + " _scope_='mmcls',\n", + " init_cfg=dict(\n", + " type='Pretrained',\n", + " checkpoint=\n", + " 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n", + " ),\n", + " data_preprocessor=dict(\n", + " mean=[123.675, 116.28, 103.53],\n", + " std=[58.395, 57.12, 57.375],\n", + " to_rgb=True)),\n", + " target_pruning_ratio=dict({\n", + " 'backbone.conv1_(0, 64)_64':\n", + " 1.0,\n", + " 'backbone.layer1.0.conv1_(0, 64)_64':\n", + " 0.3426007989589187,\n", + " 'backbone.layer1.1.conv1_(0, 64)_64':\n", + " 0.2969206924310629,\n", + " 'backbone.layer1.2.conv1_(0, 64)_64':\n", + " 0.04568010652785583,\n", + " 'backbone.layer2.0.conv1_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer2.0.conv2_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer2.1.conv1_(0, 128)_128':\n", + " 0.19414045274338726,\n", + " 'backbone.layer2.2.conv1_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer2.3.conv1_(0, 128)_128':\n", + " 1.0,\n", + " 'backbone.layer3.0.conv1_(0, 256)_256':\n", + " 0.19414045274338726,\n", + " 'backbone.layer3.0.conv2_(0, 256)_256':\n", + " 0.8393719574493509,\n", + " 'backbone.layer3.1.conv1_(0, 256)_256':\n", + " 0.39970093211873853,\n", + " 'backbone.layer3.2.conv1_(0, 256)_256':\n", + " 0.13133030626758552,\n", + " 'backbone.layer3.3.conv1_(0, 256)_256':\n", + " 0.21127049269133322,\n", + " 'backbone.layer3.4.conv1_(0, 256)_256':\n", + " 1.0,\n", + " 'backbone.layer3.5.conv1_(0, 256)_256':\n", + " 0.017130039947945933,\n", + " 'backbone.layer4.0.conv1_(0, 512)_512':\n", + " 0.3340357789849457,\n", + " 'backbone.layer4.0.conv2_(0, 512)_512':\n", + " 0.7337367111036843,\n", + " 'backbone.layer4.1.conv1_(0, 512)_512':\n", + " 0.9078921172411345,\n", + " 'backbone.layer4.2.conv1_(0, 512)_512':\n", + " 0.9764122770329183\n", + " }),\n", + " mutator_cfg=dict(\n", + " type='ChannelMutator',\n", + " channel_unit_cfg=dict(\n", + " type='L1MutableChannelUnit',\n", + " default_args=dict(choice_mode='ratio')),\n", + " parse_cfg=dict(\n", + " type='BackwardTracer',\n", + " loss_calculator=dict(\n", + " type='ImageClassifierPseudoLoss',\n", + " input_shape=(2, 3, 32, 32)))))\n", + "dataset_type = 'ImageNet'\n", + "data_preprocessor = None\n", + "train_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "test_pipeline = [\n", + " dict(type='LoadImageFromFile', _scope_='mmcls'),\n", + " dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n", + " dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n", + " dict(type='PackClsInputs', _scope_='mmcls')\n", + "]\n", + "train_dataloader = dict(\n", + " 
batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/train.txt',\n", + " data_prefix='train',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='RandomResizedCrop', scale=224),\n", + " dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "test_dataloader = dict(\n", + " batch_size=32,\n", + " num_workers=5,\n", + " dataset=dict(\n", + " type='ImageNet',\n", + " data_root='data/imagenet',\n", + " ann_file='meta/val.txt',\n", + " data_prefix='val',\n", + " pipeline=[\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='ResizeEdge', scale=256, edge='short'),\n", + " dict(type='CenterCrop', crop_size=224),\n", + " dict(type='PackClsInputs')\n", + " ],\n", + " _scope_='mmcls'),\n", + " sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n", + " persistent_workers=True)\n", + "test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n", + "optim_wrapper = dict(\n", + " optimizer=dict(\n", + " type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n", + " _scope_='mmcls'))\n", + "param_scheduler = dict(\n", + " type='MultiStepLR',\n", + " by_epoch=True,\n", + " milestones=[30, 60, 90],\n", + " gamma=0.1,\n", + " _scope_='mmcls')\n", + "train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n", + "val_cfg = dict()\n", + "test_cfg = dict()\n", + "auto_scale_lr = dict(base_batch_size=256)\n", + "default_scope = 'mmcls'\n", + "default_hooks = dict(\n", + " timer=dict(type='IterTimerHook', _scope_='mmcls'),\n", + " logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n", + " param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n", + " checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n", + " sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n", + " visualization=dict(\n", + " type='VisualizationHook', enable=False, _scope_='mmcls'))\n", + "env_cfg = dict(\n", + " cudnn_benchmark=False,\n", + " mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n", + " dist_cfg=dict(backend='nccl'))\n", + "vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n", + "visualizer = dict(\n", + " type='ClsVisualizer',\n", + " vis_backends=[dict(type='LocalVisBackend')],\n", + " _scope_='mmcls')\n", + "log_level = 'INFO'\n", + "load_from = None\n", + "resume = False\n", + "launcher = 'none'\n", + "work_dir = './prune_example//prune'\n", + "\n", + "Result has been saved to /home/liukai/Documents/mmlab2/mmrazor_github2/prune_example/prune/modules_statistic_results.json\n", + "11/08 14:16:04 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a input before 
backbone.conv1(backbone.conv1), error: backbone.conv1(backbone.conv1)\n", + "11/08 14:16:04 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a output after head.fc(head.fc), error: head.fc(head.fc)\n", + "11/08 14:16:05 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.\n", + "11/08 14:16:06 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - load model from: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth\n", + "11/08 14:16:06 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - http loads checkpoint from path: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth\n", + "11/08 14:16:06 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to /home/liukai/Documents/mmlab2/mmrazor_github2/prune_example/prune by HardDiskBackend.\n", + "11/08 14:16:06 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - The model is pruned at 0th epoch once.\n", + "11/08 14:16:08 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_20221108_141603\n", + "11/08 14:16:08 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 1 epochs\n", + "11/08 14:16:09 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [1][7/7] accuracy/top1: 1.0000 accuracy/top5: 1.0000\n", + "11/08 14:16:10 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Exp name: prune_20221108_141603\n", + "11/08 14:16:10 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Saving checkpoint at 2 epochs\n", + "11/08 14:16:11 - mmengine - \u001b[4m\u001b[37mINFO\u001b[0m - Epoch(val) [2][7/7] accuracy/top1: 0.5000 accuracy/top5: 0.5000\n" + ] + } + ], + "source": [ + "# run your prune config\n", + "! timeout 10 python ./tools/train.py $config_path/prune.py --work-dir $work_dir/prune" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# ! rm -r $work_dir" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 ('lab2max')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/en/user_guides/pruning_user_guide.md b/docs/en/user_guides/pruning_user_guide.md index a1600c887..dd8f73ae1 100644 --- a/docs/en/user_guides/pruning_user_guide.md +++ b/docs/en/user_guides/pruning_user_guide.md @@ -146,3 +146,6 @@ Please refer to the following documents for more details. 
 - [MutableChannel](../../../mmrazor/models/mutables/mutable_channel/MutableChannel.md)
-- [ChannelMutator](../../../mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb)
-- [MutableChannelUnit](../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb)
+- [ChannelMutator](../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb)
+- [MutableChannelUnit](../../../mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb)
+- Examples
+  - [Search and prune](../../../demo/search_and_prune.ipynb)
+  - [Config pruning](../../../demo/config_pruning.ipynb)
diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py
index 9715a4e6b..a510212a9 100644
--- a/mmrazor/engine/runner/__init__.py
+++ b/mmrazor/engine/runner/__init__.py
@@ -3,11 +3,13 @@
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
 from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
 from .evolution_search_loop import EvolutionSearchLoop
+from .prune_evolution_search_loop import PruneEvolutionSearchLoop
 from .slimmable_val_loop import SlimmableValLoop
 from .subnet_sampler_loop import GreedySamplerTrainLoop
 
 __all__ = [
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
+    'PruneEvolutionSearchLoop'
 ]
diff --git a/mmrazor/engine/runner/prune_evolution_search_loop.py b/mmrazor/engine/runner/prune_evolution_search_loop.py
new file mode 100644
index 000000000..76d5c4afd
--- /dev/null
+++ b/mmrazor/engine/runner/prune_evolution_search_loop.py
@@ -0,0 +1,230 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import os.path as osp
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from mmengine import fileio
+from mmengine.dist import broadcast_object_list
+from mmengine.evaluator import Evaluator
+from torch.utils.data import DataLoader
+
+from mmrazor.models.task_modules import ResourceEstimator
+from mmrazor.registry import LOOPS
+from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
+from mmrazor.utils import SupportRandomSubnet
+from .evolution_search_loop import EvolutionSearchLoop
+
+
+def get_flops(model: nn.Module, subnet: SupportRandomSubnet,
+              estimator: ResourceEstimator):
+    """Estimate the flops of ``subnet`` when applied to ``model``.
+
+    Returns:
+        float: The flops of the subnet.
+    """
+
+    assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture')
+    model.set_subnet(subnet)
+    fix_mutable = export_fix_subnet(model)
+    copied_model = copy.deepcopy(model)
+    load_fix_subnet(copied_model, fix_mutable)
+
+    # Estimate on the fixed copy rather than the dynamic supernet, so that
+    # the result reflects the pruned subnet instead of the full model.
+    model_to_check = copied_model.architecture
+
+    results = estimator.estimate(model=model_to_check)
+
+    flops = results['flops']
+    return flops
+
+
+def auto_scale(subnet, target, now):
+    """Rescale all channel ratios of ``subnet`` so that its flops move from
+    ``now`` towards ``target``, clamping each ratio to [0.01, 1.0]."""
+    new_subnet = copy.deepcopy(subnet)
+    scale = math.sqrt(target / now)
+    for key in new_subnet:
+        new_subnet[key] = max(min(new_subnet[key] * scale, 1.0), 0.01)
+    return new_subnet
+
+
+@LOOPS.register_module()
+class PruneEvolutionSearchLoop(EvolutionSearchLoop):
+    """Loop for the evolution search of channel pruning structures.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        max_epochs (int): Total searching epochs. Defaults to 20.
+        max_keep_ckpts (int): The maximum checkpoints of searcher to keep.
+            Defaults to 3.
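+        bn_dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader, used to recalibrate the BatchNorm statistics
+            of each candidate subnet before it is evaluated.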
+        resume_from (str, optional): Specify the path of saved .pkl file for
+            resuming searching.
+        num_candidates (int): The length of candidate pool. Defaults to 50.
+        top_k (int): Specify top k candidates based on scores. Defaults to 10.
+        num_mutation (int): The number of candidates got by mutation.
+            Defaults to 25.
+        num_crossover (int): The number of candidates got by crossover.
+            Defaults to 25.
+        mutate_prob (float): The probability of mutation. Defaults to 0.1.
+        flops_range (tuple): The minimum and maximum flops of a valid
+            candidate, given as ratios of the supernet's flops.
+            Defaults to (0.1, 0.9).
+        resource_estimator_cfg (dict): The config for building the estimator,
+            which is used to estimate the flops of sampled subnets. Defaults
+            to None, which means the default config is used.
+        score_key (str): Specify one metric in evaluation results to score
+            candidates. Defaults to 'accuracy/top1'.
+        init_candidates (str, optional): The candidates file path, which is
+            used to init `self.candidates`. Its format is usually in .yaml
+            format. Defaults to None.
+    """
+
+    def __init__(self,
+                 runner,
+                 dataloader: Union[DataLoader, Dict],
+                 bn_dataloader: Union[DataLoader, Dict],
+                 evaluator: Union[Evaluator, Dict, List],
+                 max_epochs: int = 20,
+                 max_keep_ckpts: int = 3,
+                 resume_from: Optional[str] = None,
+                 num_candidates: int = 50,
+                 top_k: int = 10,
+                 num_mutation: int = 25,
+                 num_crossover: int = 25,
+                 mutate_prob: float = 0.1,
+                 flops_range: Tuple[float, float] = (0.1, 0.9),
+                 resource_estimator_cfg: Optional[dict] = None,
+                 score_key: str = 'accuracy/top1',
+                 init_candidates: Optional[str] = None) -> None:
+        # BatchNorm statistics cannot be recalibrated with a batch size of 1.
+        if isinstance(bn_dataloader,
+                      dict) and bn_dataloader['batch_size'] < 2:
+            bn_dataloader['batch_size'] = 2
+
+        super().__init__(runner, dataloader, evaluator, max_epochs,
+                         max_keep_ckpts, resume_from, num_candidates, top_k,
+                         num_mutation, num_crossover, mutate_prob, flops_range,
+                         resource_estimator_cfg, score_key, init_candidates)
+        if isinstance(bn_dataloader, dict):
+            # Determine whether or not different ranks use different seed.
+            diff_rank_seed = runner._randomness_cfg.get(
+                'diff_rank_seed', False)
+            self.bn_dataloader = runner.build_dataloader(
+                bn_dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
+        else:
+            self.bn_dataloader = bn_dataloader
+        self.flops_range: Tuple[float, float] = self._update_flop_range()
+
+    def run_epoch(self) -> None:
+        """Run an epoch of evolution search and save the current best
+        subnet."""
+        super().run_epoch()
+        self._save_best_fix_subnet()
+
+    def sample_candidates(self) -> None:
+        """Update the candidate pool until it contains the specified number
+        of candidates."""
+        if self.runner.rank == 0:
+            while len(self.candidates) < self.num_candidates:
+                candidate = self.model.sample_subnet()
+                passed, candidate = self._scale_and_check_subnet_constraints(
+                    random_subnet=candidate)
+                if passed:
+                    self.candidates.append(candidate)
+        else:
+            self.candidates = Candidates([None] * self.num_candidates)
+        # broadcast candidates from rank 0 for validation with multiple GPUs.
+        broadcast_object_list(self.candidates.data)
+
+    def gen_mutation_candidates(self) -> List:
+        """Generate the specified number of mutation candidates."""
+        mutation_candidates: List = []
+        max_mutate_iters = self.num_mutation * 10
+        mutate_iter = 0
+        while len(mutation_candidates) < self.num_mutation:
+            mutate_iter += 1
+            if mutate_iter > max_mutate_iters:
+                break
+
+            mutation_candidate = self._mutation()
+
+            passed, candidate = self._scale_and_check_subnet_constraints(
+                random_subnet=mutation_candidate)
+            if passed:
+                mutation_candidates.append(candidate)
+        return mutation_candidates
+
+    def gen_crossover_candidates(self) -> List:
+        """Generate the specified number of crossover candidates."""
+        crossover_candidates: List = []
+        crossover_iter = 0
+        max_crossover_iters = self.num_crossover * 10
+        while len(crossover_candidates) < self.num_crossover:
+            crossover_iter += 1
+            if crossover_iter > max_crossover_iters:
+                break
+
+            crossover_candidate = self._crossover()
+
+            passed, candidate = self._scale_and_check_subnet_constraints(
+                random_subnet=crossover_candidate)
+            if passed:
+                crossover_candidates.append(candidate)
+        return crossover_candidates
+
+    def _save_best_fix_subnet(self):
+        """Save the best subnet among the searched top-k candidates."""
+        if self.runner.rank == 0:
+            best_random_subnet = self.top_k_candidates.subnets[0]
+            self.model.set_subnet(best_random_subnet)
+            save_name = 'best_fix_subnet.json'
+            fileio.dump(
+                best_random_subnet,
+                osp.join(self.runner.work_dir, save_name),
+                indent=4)
+            self.runner.logger.info(
+                'Search finished and '
+                f'{save_name} saved in {self.runner.work_dir}.')
+
+    @torch.no_grad()
+    def _val_candidate(self) -> Dict:
+        """Recalibrate the BatchNorm statistics of the current subnet with
+        about 1000 training images, then run the parent validation."""
+        len_img = 0
+        self.runner.model.train()
+        for _, data_batch in enumerate(self.bn_dataloader):
+            data = self.runner.model.data_preprocessor(data_batch, True)
+            self.runner.model._run_forward(data, mode='tensor')  # type: ignore
+            len_img += len(data_batch['data_samples'])
+            if len_img > 1000:
+                break
+        return super()._val_candidate()
+
+    def _scale_and_check_subnet_constraints(
+            self,
+            random_subnet: SupportRandomSubnet,
+            auto_scale_times=5) -> Tuple[bool, SupportRandomSubnet]:
+        """Check whether a subnet satisfies the flops constraints, rescaling
+        it when it does not.
+
+        Returns:
+            Tuple[bool, SupportRandomSubnet]: whether the (possibly rescaled)
+                subnet satisfies the constraints, and the subnet itself.
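+
+        Note:
+            A failing subnet is rescaled with ``auto_scale`` towards the
+            midpoint of ``flops_range`` and re-checked, at most
+            ``auto_scale_times`` times.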
+ """ + is_pass = False + assert auto_scale_times >= 0 + for _ in range(auto_scale_times + 1): + flops = get_flops(self.model, random_subnet, self.estimator) + if self.check_subnet_flops(flops): + is_pass = True + break + else: + random_subnet = auto_scale( + random_subnet, + (self.flops_range[1] + self.flops_range[0]) / 2, flops) + continue + + return is_pass, random_subnet + + def _update_flop_range(self): + flops = get_flops(self.model, self.model.curent_subnet(), + self.estimator) + flops_range = [ratio * flops for ratio in self.flops_range] + return flops_range + + def check_subnet_flops(self, flops): + return self.flops_range[0] <= flops <= self.flops_range[ + 1] # type: ignore diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index e6258b012..da3f7c5cc 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -4,25 +4,13 @@ FpnTeacherDistill, OverhaulFeatureDistillation, SelfDistill, SingleTeacherDistill) from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP -from .pruning import SlimmableNetwork, SlimmableNetworkDDP +from .pruning import SearchWrapper, SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm __all__ = [ - 'SingleTeacherDistill', - 'BaseAlgorithm', - 'FpnTeacherDistill', - 'SPOS', - 'SlimmableNetwork', - 'SlimmableNetworkDDP', - 'AutoSlim', - 'AutoSlimDDP', - 'Darts', - 'DartsDDP', - 'SelfDistill', - 'DataFreeDistillation', - 'DAFLDataFreeDistillation', - 'OverhaulFeatureDistillation', - 'ItePruneAlgorithm', - 'Dsnas', - 'DsnasDDP', + 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', + 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP', + 'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation', + 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', + 'ItePruneAlgorithm', 'Dsnas', 'DsnasDDP', 'SearchWrapper' ] diff --git a/mmrazor/models/algorithms/pruning/__init__.py b/mmrazor/models/algorithms/pruning/__init__.py index 0b426146b..a416dddec 100644 --- a/mmrazor/models/algorithms/pruning/__init__.py +++ b/mmrazor/models/algorithms/pruning/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .search_wrapper import SearchWrapper
 from .slimmable_network import SlimmableNetwork, SlimmableNetworkDDP
 
-__all__ = ['SlimmableNetwork', 'SlimmableNetworkDDP']
+__all__ = ['SlimmableNetwork', 'SlimmableNetworkDDP', 'SearchWrapper']
diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
index cca03a71f..553194c3a 100644
--- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
+++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
@@ -113,6 +113,13 @@ def __init__(self,
                  init_cfg: Optional[Dict] = None) -> None:
 
         super().__init__(architecture, data_preprocessor, init_cfg)
+        # Use SyncBatchNorm when running distributed and plain BatchNorm
+        # otherwise, so the algorithm works under both launchers.
+        import torch.distributed as dist
+        if dist.is_initialized():
+            self.architecture = nn.SyncBatchNorm.convert_sync_batchnorm(
+                self.architecture)
+        else:
+            from mmengine.model import revert_sync_batchnorm
+            self.architecture = revert_sync_batchnorm(self.architecture)
 
         # mutator
         self.mutator: ChannelMutator = MODELS.build(mutator_cfg)
@@ -136,7 +143,6 @@ def forward(self,
                 data_samples: Optional[List[BaseDataElement]] = None,
                 mode: str = 'tensor') -> ForwardResults:
         """Forward."""
-        print(self._epoch, self._iteration)
 
         if self.prune_config_manager.is_prune_time(self._epoch,
                                                    self._iteration):
diff --git a/mmrazor/models/algorithms/pruning/search_wrapper.py b/mmrazor/models/algorithms/pruning/search_wrapper.py
new file mode 100644
index 000000000..04a7bbfaf
--- /dev/null
+++ b/mmrazor/models/algorithms/pruning/search_wrapper.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModel
+from mmengine.structures import BaseDataElement
+
+from mmrazor.models.mutators import ChannelMutator
+from mmrazor.registry import MODELS
+from ..base import BaseAlgorithm
+
+LossResults = Dict[str, torch.Tensor]
+TensorResults = Union[Tuple[torch.Tensor], torch.Tensor]
+PredictResults = List[BaseDataElement]
+ForwardResults = Union[LossResults, TensorResults, PredictResults]
+
+
+@MODELS.register_module()
+class SearchWrapper(BaseAlgorithm):
+    """Wrap an architecture with a ChannelMutator so that subnets can be
+    sampled, applied and queried during the pruning-structure search."""
+
+    def __init__(self,
+                 architecture: Union[BaseModel, Dict],
+                 mutator_cfg: Union[Dict, ChannelMutator] = dict(
+                     type='ChannelMutator',
+                     channel_unit_cfg=dict(
+                         type='SequentialMutableChannelUnit')),
+                 data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
+                 init_cfg: Optional[Dict] = None) -> None:
+
+        super().__init__(architecture, data_preprocessor, init_cfg)
+
+        # mutator
+        self.mutator: ChannelMutator = MODELS.build(mutator_cfg)
+        self.mutator.prepare_from_supernet(self.architecture)
+
+    def sample_subnet(self):
+        """Sample a random subnet (choices of all channel units)."""
+        return self.mutator.sample_choices()
+
+    def set_subnet(self, choices):
+        """Apply the given choices to the wrapped architecture."""
+        self.mutator.set_choices(choices)
+
+    def current_subnet(self):
+        """Return the current choices of all channel units."""
+        return self.mutator.current_choices
diff --git a/mmrazor/models/architectures/dynamic_ops/__init__.py b/mmrazor/models/architectures/dynamic_ops/__init__.py
index 620c9e4c8..94ef3f308 100644
--- a/mmrazor/models/architectures/dynamic_ops/__init__.py
+++ b/mmrazor/models/architectures/dynamic_ops/__init__.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .bricks.dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d +from .bricks.dynamic_conv import (BigNasConv2d, DynamicConv2d, + DynamicConv2dAdaptivePadding, OFAConv2d) from .bricks.dynamic_linear import DynamicLinear from .bricks.dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, - DynamicBatchNorm3d, SwitchableBatchNorm2d) + DynamicBatchNorm3d, DynamicBatchNormXd, + DynamicSyncBatchNorm, SwitchableBatchNorm2d) from .mixins.dynamic_conv_mixins import DynamicConvMixin from .mixins.dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, DynamicLinearMixin, DynamicMixin) @@ -11,5 +13,7 @@ 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', 'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', - 'DynamicLinearMixin', 'SwitchableBatchNorm2d', 'DynamicConvMixin' + 'DynamicLinearMixin', 'SwitchableBatchNorm2d', 'DynamicConvMixin', + 'DynamicConv2dAdaptivePadding', 'DynamicSyncBatchNorm', + 'DynamicBatchNormXd' ] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py index 71fc7ab98..0e2de465b 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py @@ -1,12 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math from typing import Callable, Dict +import torch import torch.nn as nn import torch.nn.functional as F +from mmengine.registry import MODELS from torch import Tensor from mmrazor.models.mutables.base_mutable import BaseMutable -from mmrazor.registry import MODELS from ..mixins.dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, OFAConvMixin) @@ -39,31 +41,17 @@ def __init__(self, *args, **kwargs) -> None: def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d': """Convert an instance of nn.Conv2d to a new instance of DynamicConv2d.""" - # a group-wise conv will not be converted to dynamic conv - if module.groups > 1 and not (module.groups == module.out_channels == - module.in_channels): - global GroupWiseConvWarned - if GroupWiseConvWarned is False: - from mmengine import MMLogger - logger = MMLogger.get_current_instance() - logger.warning( - ('Group-wise convolutional layers are not supported to be' - 'pruned now, so they are not converted to new' - 'DynamicConvs.')) - GroupWiseConvWarned = True - - return module - else: - return cls( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - bias=True if module.bias is not None else False, - padding_mode=module.padding_mode) + + return cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) @property def conv_func(self) -> Callable: @@ -188,3 +176,26 @@ def static_op_factory(self): def forward(self, x: Tensor) -> Tensor: """Forward of OFA's conv2d.""" return self.forward_mixin(x) + + +@MODELS.register_module() +class DynamicConv2dAdaptivePadding(DynamicConv2d): + """Dynamic version of mmcv.cnn.bricks.Conv2dAdaptivePadding.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + img_h, img_w = x.size()[-2:] + kernel_h, kernel_w = 
self.weight.size()[-2:] + stride_h, stride_w = self.stride + output_h = math.ceil(img_h / stride_h) + output_w = math.ceil(img_w / stride_w) + pad_h = ( + max((output_h - 1) * self.stride[0] + + (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0)) + pad_w = ( + max((output_w - 1) * self.stride[1] + + (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0)) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return super().forward(x) diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py index e3e795fa4..d0f9cf15b 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional +import torch import torch.nn as nn import torch.nn.functional as F +from mmengine.model.utils import _BatchNormXd from torch import Tensor +from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm from torch.nn.modules.batchnorm import _BatchNorm from mmrazor.models.mutables.base_mutable import BaseMutable @@ -190,3 +193,135 @@ def _check_candidates(self, candidates: List): def static_op_factory(self): """Return initializer of static op.""" return nn.BatchNorm2d + + +class DynamicSyncBatchNorm(nn.SyncBatchNorm, DynamicBatchNormMixin): + + def __init__(self, + num_features: int, + eps: float = 0.00001, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + process_group: Optional[Any] = None, + device=None, + dtype=None) -> None: + super().__init__(num_features, eps, momentum, affine, + track_running_stats, process_group, device, dtype) + self.mutable_attrs: Dict[str, Optional[BaseMutable]] = nn.ModuleDict() + + @classmethod + def convert_from(cls, module): + return cls(module.num_features, module.eps, module.momentum, + module.affine, module.track_running_stats, + module.process_group) + + @property + def static_op_factory(self): + return nn.SyncBatchNorm + + def forward(self, input: Tensor) -> Tensor: + # currently only GPU input is supported + if not input.is_cuda: + raise ValueError( + 'SyncBatchNorm expected input tensor to be on GPU') + + self._check_input_dim(input) + self._check_non_zero_input_channels(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None + self.num_batches_tracked.add_(1) + if self.momentum is None: # use cumulative moving average + exponential_average_factor = (1.0 / + self.num_batches_tracked.item()) + else: # use exponential moving average + exponential_average_factor = self.momentum + r""" + Decide whether the mini-batch stats should be used for normalization + rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when + buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is + None) + r""" + Buffers are only updated if they are to be tracked and we are in + training mode. 
Thus they only need to be + passed when the update should occur (i.e. in training mode when + they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + # If buffers are not to be tracked, ensure that they won't be updated + running_mean = ( + self.running_mean + if not self.training or self.track_running_stats else None) + running_var = ( + self.running_var + if not self.training or self.track_running_stats else None) + + # Don't sync batchnorm stats in inference mode (model.eval()). + need_sync = (bn_training and self.training) + if need_sync: + process_group = torch.distributed.group.WORLD + if self.process_group: + process_group = self.process_group + world_size = torch.distributed.get_world_size(process_group) + need_sync = world_size > 1 + + running_mean, running_var, weight, bias = self.get_dynamic_params() + + # fallback to framework BN when synchronization is not necessary + if not need_sync: + out = F.batch_norm( + input, + running_mean, + running_var, + weight, + bias, + bn_training, + exponential_average_factor, + self.eps, + ) + else: + assert bn_training + out = sync_batch_norm.apply( + input, + weight, + bias, + running_mean, + running_var, + self.eps, + exponential_average_factor, + process_group, + world_size, + ) + + # copy changed running statistics + if self.training and self.track_running_stats: + out_mask = self._get_num_features_mask() + self.running_mean.masked_scatter_(out_mask, running_mean) + self.running_var.masked_scatter_(out_mask, running_var) + + return out + + +class DynamicBatchNormXd(_DynamicBatchNorm): + + @property + def static_op_factory(self): + return _BatchNormXd + + def _check_input_dim(self, input: torch.Tensor): + return diff --git a/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py index e3ed46ded..eed6430c2 100644 --- a/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py +++ b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py @@ -172,10 +172,22 @@ def _get_dynamic_params_by_mutable_channels( # depth-wise conv weight = weight[out_mask] else: - raise NotImplementedError( - 'Current `ChannelMutator` only support pruning the depth-wise ' - '`nn.Conv2d` or `nn.Conv2d` module whose group number equals ' - f'to one, but got {self.groups}.') + # group-wise conv + in_mask_ = in_mask.reshape([self.groups, -1]) # G in/G + in_per_group = in_mask_.sum(dim=-1)[0].item() + assert (in_mask_.sum(dim=-1) == in_per_group).all() + out_mask_ = out_mask.reshape([self.groups, -1]) # G out/G + out_per_group = out_mask_.sum(dim=-1)[0].item() + assert (out_mask_.sum(dim=-1) == out_per_group).all() + + mask = out_mask_.unsqueeze(-1) * in_mask_.unsqueeze( + -2) # G out/G in/G + mask = mask.flatten() + weight = weight.flatten(0, 1) + weight = weight[mask] + weight = weight.reshape( + [self.groups * out_per_group, in_per_group, *self.kernel_size]) + bias = self.bias[out_mask] if self.bias is not None else None return weight, bias diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 98f680ee9..3cf2ce6e5 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -259,10 +259,10 @@ def dump_chosen(self) -> CHOICE_TYPE: Returns: Dict: Dumped information. 
""" - print_log( - 'Trying to dump chosen for derived mutable, ' - 'but its value depend on the source mutables.', - level=logging.WARNING) + # print_log( + # 'Trying to dump chosen for derived mutable, ' + # 'but its value depend on the source mutables.', + # level=logging.WARNING) return self.current_choice @property diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index 28f1e4854..e5e708b7e 100644 --- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -75,7 +75,10 @@ def fix_chosen(self, chosen=None): def dump_chosen(self): """dump current choice to a dict.""" - raise NotImplementedError() + mask = self.current_mask + mask = mask.bool() + mask = mask.tolist() + return mask def num_choices(self) -> int: """Number of available choices.""" diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py index 9292d64c8..d67e89a9b 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py @@ -6,6 +6,7 @@ from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin from mmrazor.registry import MODELS from mmrazor.utils import IndexDict +from ..derived_mutable import DerivedMutable from .base_mutable_channel import BaseMutableChannel from .simple_mutable_channel import SimpleMutableChannel @@ -48,14 +49,26 @@ def current_choice(self) -> torch.Tensor: return mask.bool() @current_choice.setter - def current_choice(self, choice): + def current_choice(self, choices): """Set current choices. However, MutableChannelContainer doesn't support directly set mask. You can change the mask of MutableChannelContainer by changing its stored BaseMutableChannel. """ - raise NotImplementedError() + if isinstance(choices, list): + for choice, mutable in zip(choices, + self.mutable_channels.values()): + if isinstance(mutable, DerivedMutable): + continue + else: + mutable.current_choice = choice + + def dump_chosen(self): + chosen = [] + for mutable in self.mutable_channels.values(): + chosen.append(mutable.dump_chosen()) + return chosen @property def current_mask(self) -> torch.Tensor: diff --git a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py index 7f949890c..8507d59a0 100644 --- a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py @@ -30,6 +30,8 @@ def current_choice(self) -> torch.Tensor: @current_choice.setter def current_choice(self, choice: torch.Tensor): """Set current choice.""" + if isinstance(choice, list): + choice = torch.Tensor(choice).bool() self.mask = choice.to(self.mask.device).bool() @property diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py index 576412ec0..633014975 100644 --- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py @@ -1,16 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 import copy
-from typing import Dict, List
+from typing import Dict
 
 import torch.nn as nn
 from mmengine.model import BaseModule
 
-from mmrazor.structures.graph import ModuleGraph
-from mmrazor.structures.graph.channel_graph import ChannelGraph
-from mmrazor.structures.graph.channel_modules import (BaseChannel,
-                                                      BaseChannelUnit)
-from mmrazor.structures.graph.channel_nodes import \
-    default_channel_node_converter
+from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
+from mmrazor.registry import TASK_UTILS
 
 
 class Channel(BaseModule):
@@ -25,7 +21,6 @@ class Channel(BaseModule):
             Channel. Defaults to None.
         is_output_channel (bool, optional): Is the channel output channel.
             Defaults to True.
-        expand_ratio (int, optional): Expand ratio of the mask. Defaults to 1.
     """
 
     # init
@@ -35,11 +30,10 @@ def __init__(self,
                  module,
                  index,
                  node=None,
-                 is_output_channel=True,
-                 expand_ratio=1) -> None:
+                 is_output_channel=True) -> None:
         super().__init__()
         self.name = name
-        self.module = module
+        self.module: nn.Module = module
         self.index = index
         self.start = index[0]
         self.end = index[1]
@@ -47,7 +41,6 @@
 
         self.node = node
         self.is_output_channel = is_output_channel
-        self.expand_ratio = expand_ratio
 
     @classmethod
     def init_from_cfg(cls, model: nn.Module, config: Dict):
@@ -56,29 +49,13 @@ def init_from_cfg(cls, model: nn.Module, config: Dict):
         name = config['name']
         start = config['start']
         end = config['end']
-        expand_ratio = config['expand_ratio'] \
-            if 'expand_ratio' in config else 1
         is_output_channel = config['is_output_channel']
 
         name2module = dict(model.named_modules())
         name2module.pop('')
         module = name2module[name] if name in name2module else None
         return Channel(
-            name,
-            module, (start, end),
-            is_output_channel=is_output_channel,
-            expand_ratio=expand_ratio)
-
-    @classmethod
-    def init_from_base_channel(cls, base_channel: BaseChannel):
-        """Init from a BaseChannel object."""
-        return cls(
-            base_channel.name,
-            base_channel.module,
-            base_channel.index,
-            node=None,
-            is_output_channel=base_channel.is_output_channel,
-            expand_ratio=base_channel.expand_ratio)
+            name, module, (start, end), is_output_channel=is_output_channel)
 
     # config template
 
@@ -89,7 +66,6 @@ def config_template(self):
             'name': self.name,
             'start': self.start,
             'end': self.end,
-            'expand_ratio': self.expand_ratio,
             'is_output_channel': self.is_output_channel
         }
 
@@ -103,29 +79,29 @@ def num_channels(self) -> int:
     @property
     def is_mutable(self) -> bool:
        """If the channel is prunable."""
-        if isinstance(self.module, nn.Conv2d):
-            # group-wise conv
-            if self.module.groups != 1 and not (self.module.groups ==
-                                                self.module.in_channels ==
-                                                self.module.out_channels):
-                return False
-        return True
+        if self.module is not None:
+            # A module is prunable if it is parameter-free (e.g. an
+            # activation) or has been converted to a dynamic op.
+            has_param = len(list(self.module.parameters())) != 0
+            is_dynamic_op = isinstance(self.module, DynamicChannelMixin)
+            return (not has_param) or is_dynamic_op
+        else:
+            # Only the input/output placeholders are unprunable.
+            is_placeholder = self.name in [
+                'input_placeholder', 'output_placeholder'
+            ]
+            return not is_placeholder
 
     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}('
                 f'{self.name}, index={self.index}, '
                 f'is_output_channel='
                 f'{"true" if self.is_output_channel else "false"}, '
-                f'expand_ratio={self.expand_ratio}'
                 ')')
 
     def __eq__(self, obj: object) -> bool:
-        if isinstance(obj, BaseChannel):
+        if isinstance(obj, Channel):
             return self.name == obj.name \
                 and self.module == obj.module \
                 and self.index == obj.index \
                 and self.is_output_channel == obj.is_output_channel \
-                and self.expand_ratio == obj.expand_ratio \
and self.node == obj.node else: return False @@ -200,30 +176,14 @@ def init_from_channel_unit(cls, return mutable_unit @classmethod - def init_from_graph(cls, - graph: ModuleGraph, - unit_args={}, - num_input_channel=3) -> List['ChannelUnit']: - """Parse a module-graph and get ChannelUnits.""" - - def init_from_base_channel_unit(base_channel_unit: BaseChannelUnit): - unit = cls(len(base_channel_unit.channel_elems), **unit_args) - unit.input_related = nn.ModuleList([ - Channel.init_from_base_channel(channel) - for channel in base_channel_unit.input_related - ]) - unit.output_related = nn.ModuleList([ - Channel.init_from_base_channel(channel) - for channel in base_channel_unit.output_related - ]) - return unit - - unit_graph = ChannelGraph.copy_from(graph, - default_channel_node_converter) - unit_graph.forward(num_input_channel) - units = unit_graph.collect_units() - units = [init_from_base_channel_unit(unit) for unit in units] - return units + def init_from_prune_tracer(cls, model, tracer=None): + if tracer is None: + from mmrazor.models.task_modules.tracer import PruneTracer + tracer = PruneTracer() + if isinstance(tracer, dict): + tracer = TASK_UTILS.build(tracer) + unit_config = tracer.trace(model) + return [cls.init_from_cfg(model, cfg) for cfg in unit_config.values()] # tools @@ -256,14 +216,18 @@ def config_template(self, def add_ouptut_related(self, channel: Channel): """Add a Channel which is output related.""" assert channel.is_output_channel - assert self.num_channels == channel.num_channels + assert self.num_channels == channel.num_channels or ( + channel.num_channels > self.num_channels + and channel.num_channels % self.num_channels == 0) if channel not in self.output_related: self.output_related.append(channel) def add_input_related(self, channel: Channel): """Add a Channel which is input related.""" assert channel.is_output_channel is False - assert self.num_channels == channel.num_channels + assert self.num_channels == channel.num_channels or ( + channel.num_channels > self.num_channels + and channel.num_channels % self.num_channels == 0) if channel not in self.input_related: self.input_related.append(channel) diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb index 5af2d496b..ad1bb77ce 100644 --- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb +++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb @@ -36,20 +36,19 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define a model\n", "from mmengine.model import BaseModel\n", "from torch import nn\n", - "import torch\n", "from collections import OrderedDict\n", "\n", - "class MyModel(BaseModel):\n", + "class MyModel(nn.Module):\n", "\n", " def __init__(self):\n", - " super().__init__(None, None)\n", + " super().__init__()\n", " self.net = nn.Sequential(\n", " OrderedDict([('conv0', nn.Conv2d(3, 8, 3, 1, 1)),\n", " ('relu', nn.ReLU()),\n", @@ -65,17 +64,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "This model has 4 MutableChannelUnit(SequentialMutableChannelUnit).\n" - ] - } - ], + "outputs": [], "source": [ "# There are multiple types of MutableChannelUnits. 
Here, We take SequentialMutableChannelUnit as the example.\n", "from mmrazor.models.mutables.mutable_channel.units import SequentialMutableChannelUnit\n", @@ -83,9 +74,8 @@ "from typing import List\n", "\n", "model = MyModel()\n", - "graph = ModuleGraph.init_from_backward_tracer(model)\n", "units: List[\n", - " SequentialMutableChannelUnit] = SequentialMutableChannelUnit.init_from_graph(graph) # type: ignore\n", + " SequentialMutableChannelUnit] = SequentialMutableChannelUnit.init_from_prune_tracer(model) # type: ignore\n", "print(\n", " f'This model has {len(units)} MutableChannelUnit(SequentialMutableChannelUnit).'\n", ")\n" @@ -93,26 +83,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SequentialMutableChannelUnit(\n", - " name=net.conv0_(0, 8)_8\n", - " (output_related): ModuleList(\n", - " (0): Channel(net.conv0, index=(0, 8), is_output_channel=true, expand_ratio=1)\n", - " )\n", - " (input_related): ModuleList(\n", - " (0): Channel(net.conv1, index=(0, 8), is_output_channel=false, expand_ratio=1)\n", - " )\n", - " (mutable_channel): SquentialMutableChannel(num_channels=8, activated_channels=8)\n", - ")\n" - ] - } - ], + "outputs": [], "source": [ "unit1=units[1]\n", "print(unit1)" @@ -158,31 +131,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The current choice of unit1 is 8.\n", - "DynamicConv2d(\n", - " 3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n", - " (mutable_attrs): ModuleDict(\n", - " (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)\n", - " (out_channels): MutableChannelContainer(num_channels=8, activated_channels=8)\n", - " )\n", - ")\n", - "DynamicConv2d(\n", - " 8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n", - " (mutable_attrs): ModuleDict(\n", - " (in_channels): MutableChannelContainer(num_channels=8, activated_channels=8)\n", - " (out_channels): MutableChannelContainer(num_channels=16, activated_channels=16)\n", - " )\n", - ")\n" - ] - } - ], + "outputs": [], "source": [ "# We run \"prepare_for_pruning\" once before pruning to run step 1 and 2 above.\n", "unit1.prepare_for_pruning(model)\n", @@ -200,31 +151,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "We get a sampled choice 2.\n", - "DynamicConv2d(\n", - " 3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n", - " (mutable_attrs): ModuleDict(\n", - " (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)\n", - " (out_channels): MutableChannelContainer(num_channels=8, activated_channels=2)\n", - " )\n", - ")\n", - "DynamicConv2d(\n", - " 8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n", - " (mutable_attrs): ModuleDict(\n", - " (in_channels): MutableChannelContainer(num_channels=8, activated_channels=2)\n", - " (out_channels): MutableChannelContainer(num_channels=16, activated_channels=16)\n", - " )\n", - ")\n" - ] - } - ], + "outputs": [], "source": [ "sampled_choice=unit1.sample_choice()\n", "print(f'We get a sampled choice {sampled_choice}.')\n", @@ -264,22 +193,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
"The model has 4 MutableChannelUnits.\n" - ] - } - ], + "outputs": [], "source": [ "# 1. using tracer\n", "def get_mutable_channel_units_using_tracer(model):\n", - " graph = ModuleGraph.init_from_backward_tracer(model)\n", - " units = SequentialMutableChannelUnit.init_from_graph(graph)\n", + " units = SequentialMutableChannelUnit.init_from_prune_tracer(model)\n", " return units\n", "\n", "\n", @@ -290,26 +210,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SequentialMutableChannelUnit(\n", - " name=net.conv0_(0, 8)_8\n", - " (output_related): ModuleList(\n", - " (0): Channel(net.conv0, index=(0, 8), is_output_channel=true, expand_ratio=1)\n", - " )\n", - " (input_related): ModuleList(\n", - " (0): Channel(net.conv1, index=(0, 8), is_output_channel=false, expand_ratio=1)\n", - " )\n", - " (mutable_channel): SquentialMutableChannel(num_channels=8, activated_channels=8)\n", - ")\n" - ] - } - ], + "outputs": [], "source": [ "# 2. using config\n", "config = {\n", @@ -332,17 +235,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The model has 2 MutableChannelUnits.\n" - ] - } - ], + "outputs": [], "source": [ "# 3. using predefined model\n", "\n", @@ -391,7 +286,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.12 ('mmlab')", + "display_name": "Python 3.9.13 ('lab2max')", "language": "python", "name": "python3" }, @@ -405,12 +300,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.9.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "feec882ee78c63cb8d4b485f1b52bbb873bb9a7b094435863200c7afba202382" + "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875" } } }, diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py index 59039cd83..8099305de 100644 --- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py @@ -291,7 +291,7 @@ def _register_mutable_channel(self, mutable_channel: BaseMutableChannel): mutable_channel is mutable for mutable in source_mutables ] - assert any(is_same) + assert any(is_same), 'existed a mutable channel.' 
        else:
            container.register_mutable(mutable_channel_, start, end)
diff --git a/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py
index 89a25d236..89dc785ed 100644
--- a/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py
+++ b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py
@@ -3,7 +3,11 @@
 from typing import Dict, Union

 import torch.nn as nn
+from mmcv.cnn.bricks import Conv2dAdaptivePadding
 from mmengine import MMLogger
+from mmengine.model.utils import _BatchNormXd
+from mmengine.utils.dl_utils.parrots_wrapper import \
+    SyncBatchNorm as EngineSyncBatchNorm

 from mmrazor.models.architectures import dynamic_ops
 from mmrazor.models.utils import make_divisible
@@ -60,9 +64,14 @@ def prepare_for_pruning(self, model: nn.Module):
         # register MutableMask
         self._replace_with_dynamic_ops(
             model, {
+                Conv2dAdaptivePadding:
+                dynamic_ops.DynamicConv2dAdaptivePadding,
                 nn.Conv2d: dynamic_ops.DynamicConv2d,
                 nn.BatchNorm2d: dynamic_ops.DynamicBatchNorm2d,
-                nn.Linear: dynamic_ops.DynamicLinear
+                nn.Linear: dynamic_ops.DynamicLinear,
+                nn.SyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm,
+                EngineSyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm,
+                _BatchNormXd: dynamic_ops.DynamicBatchNormXd,
             })
         self._register_channel_container(model, MutableChannelContainer)
         self._register_mutable_channel(self.mutable_channel)
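For context, a minimal sketch of what the extended mapping above enables (illustrative only; the two-conv toy model and its module names '0'/'3' are assumptions, not part of this PR):

import torch.nn as nn
from mmrazor.models.mutables.mutable_channel.units import \
    SequentialMutableChannelUnit

# A toy model: conv '0' produces the 8 channels that conv '3' consumes.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 16, 3, 1, 1))
unit = SequentialMutableChannelUnit.init_from_cfg(
    model, {
        'init_args': {'num_channels': 8},
        'channels': {
            'output_related': [{'name': '0', 'start': 0, 'end': 8,
                                'is_output_channel': True}],
            'input_related': [{'name': '3', 'start': 0, 'end': 8,
                               'is_output_channel': False}]
        }
    })
# Static ops are replaced in place by their dynamic counterparts,
# e.g. nn.Conv2d -> DynamicConv2d, nn.BatchNorm2d -> DynamicBatchNorm2d.
unit.prepare_for_pruning(model)
print(type(model[0]).__name__)  # expected: DynamicConv2d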
diff --git a/mmrazor/models/mutables/mutable_channel/units/utils.py b/mmrazor/models/mutables/mutable_channel/units/utils.py
new file mode 100644
index 000000000..04982cfda
--- /dev/null
+++ b/mmrazor/models/mutables/mutable_channel/units/utils.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import List
+
+import torch
+
+from mmrazor.models.mutables.mutable_channel.units import \
+    SequentialMutableChannelUnit
+from mmrazor.utils import demo_inputs
+
+
+def assert_model_is_changed(tensors1, tensors2):
+    """Assert the outputs keep the same nested structure and tensor ranks."""
+    shape1 = get_shape(tensors1, only_length=True)
+    shape2 = get_shape(tensors2, only_length=True)
+    assert shape1 == shape2, f'{shape1}!={shape2}'
+
+
+def get_shape(tensor, only_length=False):
+    """Recursively collect the shapes (or ranks) of nested tensors."""
+    if isinstance(tensor, torch.Tensor):
+        if only_length:
+            return len(tensor.shape)
+        else:
+            return tensor.shape
+    elif isinstance(tensor, list) or isinstance(tensor, tuple):
+        shapes = []
+        for x in tensor:
+            shapes.append(get_shape(x, only_length))
+        return shapes
+    elif isinstance(tensor, dict):
+        shapes = {}
+        for key in tensor:
+            shapes[key] = get_shape(tensor[key], only_length)
+        return shapes
+    else:
+        raise NotImplementedError(
+            f'unsupported type {type(tensor)} to get shape of tensors.')
+
+
+# Prune try_units with random but clamped ratios and check that the
+# model still runs end to end.
+def forward_units(model, try_units: List[SequentialMutableChannelUnit],
+                  units: List[SequentialMutableChannelUnit], template_output):
+    model.eval()
+    for unit in units:
+        unit.current_choice = 1.0
+    for unit in try_units:
+        unit.current_choice = min(max(0.1, unit.sample_choice()), 0.9)
+    inputs = demo_inputs(model, [1, 3, 224, 224])
+    if isinstance(inputs, dict):
+        inputs['mode'] = 'loss'
+        tensors = model(**inputs)
+    else:
+        tensors = model(inputs)
+    assert_model_is_changed(template_output, tensors)
+
+
+# Bisect over try_units: keep the largest subset of units that can be
+# pruned without breaking a forward pass.
+def find_mutable(model, try_units, units, template_tensors):
+    if len(try_units) == 0:
+        return []
+    try:
+        forward_units(model, try_units, units, template_tensors)
+        return try_units
+    except Exception:
+        if len(try_units) == 1:
+            return []
+        else:
+            num = len(try_units)
+            return find_mutable(model, try_units[:num // 2], units,
+                                template_tensors) + find_mutable(
+                                    model, try_units[num // 2:], units,
+                                    template_tensors)
diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb b/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb
index 1d7aad669..fb7086368 100644
--- a/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb
+++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb
@@ -25,9 +25,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "# define a model\n", "from mmengine.model import BaseModel\n", @@ -35,10 +44,10 @@ "import torch\n", "from collections import OrderedDict\n", "\n", - "class MyModel(BaseModel):\n", + "class MyModel(nn.Module):\n", "\n", " def __init__(self):\n", - " super().__init__(None, None)\n", + " super().__init__()\n", " self.net = nn.Sequential(\n", " OrderedDict([('conv0', nn.Conv2d(3, 8, 3, 1, 1)),\n", " ('relu', nn.ReLU()),\n", @@ -63,13 +72,15 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "11/14 14:24:13 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a input before net.conv0(net.conv0), error: net.conv0(net.conv0)\n", + "11/14 14:24:13 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a output after head(head), error: head(head)\n", "The mutator has 2 mutable channel units.\n" ] } @@ -86,8 +97,7 @@ " units={},\n", " ),\n", " parse_cfg=dict(\n", - " type='BackwardTracer',\n", - " loss_calculator=dict(type='ImageClassifierPseudoLoss')))\n", + " type='PruneTracer'))\n", "# init the ChannelMutator object with a model\n", "mutator.prepare_from_supernet(model)\n", "print(f'The mutator has {len(mutator.mutable_units)} mutable channel units.')" @@ -116,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -175,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -258,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -282,7 +292,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -337,7 +347,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.12 ('mmlab')", + "display_name": "Python 3.9.13 ('lab2max')", "language": "python", "name": "python3" }, @@ -351,12 +361,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.9.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "feec882ee78c63cb8d4b485f1b52bbb873bb9a7b094435863200c7afba202382" + "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875" } } }, diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index c04aa0204..35b3a5d82 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -5,20 +5,15 @@ from mmengine import fileio from torch.nn import Module -from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin from mmrazor.models.mutables import (ChannelUnitType, MutableChannelUnit, SequentialMutableChannelUnit) from mmrazor.models.mutables.mutable_channel.units.channel_unit import \ ChannelUnit -from mmrazor.registry import MODELS -from mmrazor.structures.graph import ModuleGraph +from mmrazor.models.task_modules.tracer.prune_tracer import PruneTracer +from mmrazor.registry import MODELS, TASK_UTILS from ..base_mutator import BaseMutator -def is_dynamic_op_for_fx_tracer(module, name): - return isinstance(module, DynamicChannelMixin) - - @MODELS.register_module() class ChannelMutator(BaseMutator, 
Generic[ChannelUnitType]): """ChannelMutator manages the pruning structure of a model. @@ -68,17 +63,16 @@ def __init__(self, dict, Type[MutableChannelUnit]] = SequentialMutableChannelUnit, parse_cfg: Dict = dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss')), + type='PruneTracer', + input_shape=(1, 3, 224, 224), + tracer_type='BackwardTracer'), init_cfg: Optional[Dict] = None) -> None: super().__init__(init_cfg) # tracer if isinstance(parse_cfg, dict): - assert parse_cfg['type'] in [ - 'RazorFxTracer', 'BackwardTracer', 'Config', 'Predefined' - ] + assert parse_cfg['type'] in ['PruneTracer', 'Config', 'Predefined'] self.parse_cfg = parse_cfg # units @@ -98,7 +92,6 @@ def prepare_from_supernet(self, supernet: Module) -> None: 1. parse the model and get MutableChannelUnits. 2. call unit.prepare_for_pruning for each unit. """ - self._name2module = dict(supernet.named_modules()) if 'Tracer' in self.parse_cfg['type']: @@ -273,20 +266,18 @@ def _parse_channel_unit_cfg( def _prepare_from_tracer(self, model: Module, parse_cfg: Dict): """Initialize units using a tracer.""" - if 'num_input_channel' in parse_cfg: - num_input_channel = parse_cfg.pop('num_input_channel') - else: - num_input_channel = 3 - if self.parse_cfg['type'] == 'BackwardTracer': - graph = ModuleGraph.init_from_backward_tracer(model, parse_cfg) - elif self.parse_cfg['type'] == 'RazorFxTracer': - graph = ModuleGraph.init_from_fx_tracer(model, fx_tracer=parse_cfg) + + if isinstance(parse_cfg, Dict): + tracer: PruneTracer = TASK_UTILS.build(parse_cfg) else: - raise NotImplementedError() - self._graph = graph + tracer = parse_cfg + unit_configs = tracer.trace(model) + # get ChannelUnits - units = ChannelUnit.init_from_graph( - graph, num_input_channel=num_input_channel) + units = [ + ChannelUnit.init_from_cfg(model, cfg) + for cfg in unit_configs.values() + ] # convert to MutableChannelUnits units = self._convert_channel_unit_to_mutable(units) return units diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py index 959d88fa4..879fb456f 100644 --- a/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 import numpy as np
+import torch.nn as nn

 from mmrazor.registry import TASK_UTILS
 from .base_counter import BaseCounter
@@ -59,3 +60,42 @@ class Conv2dCounter(ConvCounter):
 class Conv3dCounter(ConvCounter):
     """FLOPs/params counter for Conv3d module."""
     pass
+
+
+@TASK_UTILS.register_module()
+class DynamicConv2dCounter(ConvCounter):
+    """FLOPs/params counter for DynamicConv2d, based on activated channels."""
+
+    @staticmethod
+    def add_count_hook(module: nn.Conv2d, input, output):
+
+        input = input[0]
+
+        batch_size = input.shape[0]
+        output_dims = list(output.shape[2:])
+
+        kernel_dims = list(module.kernel_size)
+
+        out_channels = module.mutable_attrs['out_channels'].activated_channels
+        in_channels = module.mutable_attrs['in_channels'].activated_channels
+
+        groups = module.groups
+
+        filters_per_channel = out_channels / groups
+        conv_per_position_flops = int(
+            np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+        active_elements_count = batch_size * int(np.prod(output_dims))
+
+        overall_conv_flops = conv_per_position_flops * active_elements_count
+        overall_params = conv_per_position_flops
+
+        bias_flops = 0
+        if module.bias is not None:
+            bias_flops = out_channels * active_elements_count
+            overall_params += out_channels
+
+        overall_flops = overall_conv_flops + bias_flops
+
+        module.__flops__ += overall_flops
+        module.__params__ += int(overall_params)
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
index f8e9ea8fb..80c024c09 100644
--- a/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
@@ -18,3 +18,8 @@ def add_count_hook(module, input, output):
         -1]  # pytorch checks dimensions, so here we don't care much
     module.__flops__ += int(np.prod(input.shape) * output_last_dim)
     module.__params__ += get_model_parameters_number(module)
+
+
+@TASK_UTILS.register_module()
+class DynamicLinearCounter(LinearCounter):
+    pass
diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py
index a9a6fde52..28b476a69 100644
--- a/mmrazor/models/task_modules/tracer/__init__.py
+++ b/mmrazor/models/task_modules/tracer/__init__.py
@@ -1,11 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .backward_tracer import BackwardTracer
+# from .razor_tracer import RazorFxTracer
 from .loss_calculator import *  # noqa: F401,F403
 from .parsers import *  # noqa: F401,F403
 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode,
                    PathLinearNode, PathList, PathNode, PathNormNode)
+from .prune_tracer import PruneTracer

 __all__ = [
     'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode',
-    'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode'
+    'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode',
+    'PruneTracer'
 ]
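As a rough sanity check of the counting rule above (hypothetical numbers, not part of the PR): a 3x3 conv with 8 activated input channels, 16 activated output channels, groups=1 and a 1x32x32 output gives

import numpy as np

# Mirrors DynamicConv2dCounter.add_count_hook with hypothetical numbers.
batch_size, output_dims = 1, [32, 32]
kernel_dims, in_channels, out_channels, groups = [3, 3], 8, 16, 1

filters_per_channel = out_channels / groups
conv_per_position_flops = int(
    np.prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * int(np.prod(output_dims))

print(conv_per_position_flops * active_elements_count)  # 1179648.0 FLOPs
print(conv_per_position_flops)  # 1152.0 weights (= params without bias)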
+"""This module define FxTracer and related classes.""" + +import copy +import functools +from types import FunctionType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +from torch._C import ScriptObject # type: ignore[attr-defined] +from torch.fx._symbolic_trace import (Tracer, _autowrap_check, + _orig_module_call, _orig_module_getattr, + _patch_wrapped_functions, _Patcher) +from torch.fx.graph import Graph +from torch.fx.node import Argument +from torch.fx.proxy import Proxy + + +class FxTracer(Tracer): + """CostumFxTracer allow user to indicate leaf module.""" + + def __init__(self, + autowrap_modules: Tuple = (), + autowrap_functions: Tuple[Callable, ...] = (), + param_shapes_constant: bool = False) -> None: + super().__init__(autowrap_modules, autowrap_functions, + param_shapes_constant) + + from mmdet.models.dense_heads.base_dense_head import BaseDenseHead + from mmdet.models.dense_heads.rpn_head import RPNHead + from mmdet.models.roi_heads import StandardRoIHead + self.warp_method = { + RPNHead: RPNHead.predict_by_feat, + BaseDenseHead: BaseDenseHead.predict_by_feat, + StandardRoIHead: StandardRoIHead.forward, + } + self.warp_fn = { + torch: torch.arange, + } + + def trace(self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + if concrete_args is None: + concrete_args = {} + concrete_args = copy.copy(concrete_args) + return self._trace(root, concrete_args) + + def _trace(self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + if isinstance(root, torch.nn.Module): + self.root = root + + assert hasattr(type(root), self.traced_func_name), ( + f"traced_func_name={self.traced_func_name} doesn't exist in" + ' {type(root).__name__}') + + fn = getattr(type(root), self.traced_func_name) + self.submodule_paths = { + mod: name + for name, mod in root.named_modules() + } + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) + self.graph = Graph(tracer_cls=tracer_cls) + + # When we encounter a Tensor value that's not a parameter, we look + # if it + # is some other attribute on the model. Construct a dict mapping + # Tensor + # values to the qualified name here for efficiency. This is used + # downstream + # in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root(fn, + isinstance(root, torch.nn.Module), + concrete_args) + + parameter_proxy_cache: Dict[str, Proxy] = { + } # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly + # used. + # Thus, we need to insert a proxy when __getattr__ requests a + # parameter. 
+        @functools.wraps(_orig_module_getattr)
+        def module_getattr_wrapper(mod, attr):
+            attr_val = _orig_module_getattr(mod, attr)
+            return self._module_getattr(attr, attr_val, parameter_proxy_cache)
+
+        @functools.wraps(_orig_module_call)
+        def module_call_wrapper(mod, *args, **kwargs):
+
+            def forward(*args, **kwargs):
+                return _orig_module_call(mod, *args, **kwargs)
+
+            _autowrap_check(
+                patcher,
+                getattr(getattr(mod, 'forward', mod), '__globals__', {}),
+                self._autowrap_function_ids)
+            return self.call_module(mod, forward, args, kwargs)
+
+        with _Patcher() as patcher:
+            # allow duplicate patches to support the case of nested calls
+            patcher.patch_method(
+                torch.nn.Module,
+                '__getattr__',
+                module_getattr_wrapper,
+                deduplicate=False)
+            patcher.patch_method(
+                torch.nn.Module,
+                '__call__',
+                module_call_wrapper,
+                deduplicate=False)
+            for obj, mth in self.warp_method.items():
+                patcher.patch_method(
+                    obj,
+                    mth.__name__,
+                    self.warp_a_method(obj, mth),
+                    deduplicate=False)
+            for obj, mth in self.warp_fn.items():
+                patcher.patch_method(
+                    obj,
+                    mth.__name__,
+                    self.warp_a_function(obj, mth),
+                    deduplicate=False)
+            _patch_wrapped_functions(patcher)
+            _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
+            for module in self._autowrap_search:
+                _autowrap_check(patcher, module.__dict__,
+                                self._autowrap_function_ids)
+            self.create_node(
+                'output',
+                'output', (self.create_arg(fn(*args)), ), {},
+                type_expr=fn.__annotations__.get('return', None))
+
+        self.submodule_paths = None  # type: ignore
+
+        return self.graph
+
+    def call_method(self, origin_fn, name, args: tuple, kwargs):
+        args = args[1:]
+        return self.create_proxy('call_function', origin_fn, args, kwargs,
+                                 name)
+
+    def call_function(self, origin_fn, name, args, kwargs):
+        return self.create_proxy('call_function', origin_fn, args, kwargs,
+                                 name)
+
+    def warp_a_method(self, obj, origin_fn):
+
+        @functools.wraps(origin_fn)
+        def fn_wrapper(*args, **kwargs):
+            return self.call_method(origin_fn, origin_fn.__name__, args,
+                                    kwargs)
+
+        return fn_wrapper
+
+    def warp_a_function(self, obj, origin_fn):
+
+        @functools.wraps(origin_fn)
+        def fn_wrapper(*args, **kwargs):
+            return self.call_function(origin_fn, origin_fn.__name__, args,
+                                      kwargs)
+
+        return fn_wrapper
+
+    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
+                    args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+        module_qualified_name = self.path_of_module(m)
+        try:
+            proxy = super().call_module(m, forward, args, kwargs)
+            return proxy
+        except Exception as e:
+            from mmengine import MMLogger
+            MMLogger.get_current_instance().warning(
+                f'{module_qualified_name}({type(m)}) encountered an error'
+                ' when being traced. '
+                f'It will be treated as a leaf module.\n {e}')
+            return self.create_proxy('call_module', module_qualified_name,
+                                     args, kwargs)
+
+    def create_arg(self, a: Any) -> 'Argument':
+        try:
+            arg = super().create_arg(a)
+            return arg
+        except Exception:
+            return a
+
+
+class CustomFxTracer(FxTracer):
+
+    def __init__(
+        self,
+        autowrap_modules: Tuple = (),
+        autowrap_functions: Tuple[Callable, ...] = (),
+        param_shapes_constant: bool = False,
+        leaf_module: Tuple = (),
+    ) -> None:
+        super().__init__(autowrap_modules, autowrap_functions,
+                         param_shapes_constant)
+
+        self.leaf_module = leaf_module
+
+    def is_leaf_module(self, m: torch.nn.Module,
+                       module_qualified_name: str) -> bool:
+        is_torch_module = super().is_leaf_module(m, module_qualified_name)
+
+        is_leaf = False
+        for module_type in self.leaf_module:
+            if isinstance(m, module_type):
+                is_leaf = True
+                break
+
+        return is_leaf or is_torch_module
diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py b/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py
index 0371a713a..91a004f2c 100644
--- a/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py
+++ b/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py
@@ -2,5 +2,9 @@
 from .image_classifier_loss_calculator import ImageClassifierPseudoLoss
 from .single_stage_detector_loss_calculator import \
     SingleStageDetectorPseudoLoss
+from .sum_loss_calculator import SumPseudoLoss

-__all__ = ['ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss']
+__all__ = [
+    'ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss',
+    'SumPseudoLoss'
+]
diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/sum_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/sum_loss_calculator.py
new file mode 100644
index 000000000..00a1103ad
--- /dev/null
+++ b/mmrazor/models/task_modules/tracer/loss_calculator/sum_loss_calculator.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmrazor.registry import TASK_UTILS
+
+
+@TASK_UTILS.register_module()
+class SumPseudoLoss:
+    """Calculate a pseudo loss by summing all outputs of a model, used to
+    trace the topology of any model with `BackwardTracer`.
+
+    Args:
+        input_shape (Tuple): The shape of the pseudo input. Defaults to
+            (2, 3, 224, 224).
+    """
+
+    def __init__(self, input_shape=(2, 3, 224, 224)):
+        self.input_shape = input_shape
+
+    def __call__(self, model) -> torch.Tensor:
+        pseudo_img = torch.rand(self.input_shape)
+        model.eval()
+        pseudo_output = model(pseudo_img)
+        return self._sum_of_output(pseudo_output)
+
+    def _sum_of_output(self, tensor):
+
+        if isinstance(tensor, torch.Tensor):
+            return tensor.sum()
+        elif isinstance(tensor, list) or isinstance(tensor, tuple):
+            loss = 0
+            for t in tensor:
+                loss = loss + self._sum_of_output(t)
+            return loss
+        elif isinstance(tensor, dict):
+            loss = 0
+            for t in tensor.values():
+                loss = loss + self._sum_of_output(t)
+            return loss
+        else:
+            raise NotImplementedError(
+                f'unsupported type {type(tensor)} to sum as a pseudo loss.')
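A small sketch of how `SumPseudoLoss` is consumed (the toy model is an assumption; in this PR the class is instantiated by `PruneTracer` below):

from torch import nn
from mmrazor.models.task_modules.tracer.loss_calculator import SumPseudoLoss

model = nn.Sequential(nn.Conv2d(3, 8, 3, 1, 1), nn.AdaptiveAvgPool2d(1))
loss = SumPseudoLoss(input_shape=(2, 3, 224, 224))(model)
loss.backward()    # BackwardTracer follows this backward graph
print(loss.shape)  # torch.Size([]) -- one scalar covering every output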
diff --git a/mmrazor/models/task_modules/tracer/prune_tracer.py b/mmrazor/models/task_modules/tracer/prune_tracer.py
new file mode 100644
index 000000000..6d8200fcd
--- /dev/null
+++ b/mmrazor/models/task_modules/tracer/prune_tracer.py
@@ -0,0 +1,141 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import copy
+from typing import Dict, List
+
+import torch.nn as nn
+from mmcv.cnn.bricks import Scale
+from mmengine.model.utils import revert_sync_batchnorm
+
+from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin
+from mmrazor.models.mutables.mutable_channel import (
+    MutableChannelUnit, SequentialMutableChannelUnit)
+from mmrazor.models.mutables.mutable_channel.units.utils import find_mutable
+from mmrazor.registry import TASK_UTILS
+from mmrazor.structures.graph import BaseGraph, ModuleGraph
+from mmrazor.structures.graph.channel_graph import (
+    ChannelGraph, default_channel_node_converter)
+from mmrazor.structures.graph.module_graph import (FxTracerToGraphConverter,
+                                                   PathToGraphConverter)
+from mmrazor.utils import demo_inputs
+from .backward_tracer import BackwardTracer
+from .fx_tracer import CustomFxTracer
+from .loss_calculator.sum_loss_calculator import SumPseudoLoss
+from .razor_tracer import FxBaseNode, RazorFxTracer
+
+# Where PruneTracer behaviour is configured in code:
+"""
+- How to configure PruneTracer in code
+    - fx tracer
+        - concrete args
+            - demo_inputs
+        - leaf modules
+            - PruneTracer.default_leaf_modules
+        - wrapped methods
+            - None
+    - ChannelNode
+        - channel_nodes.py
+    - DynamicOp
+        - ChannelUnits
+"""

+# Concrete args for fx tracing are built by demo_inputs below.
+
+
+@TASK_UTILS.register_module()
+class PruneTracer:
+
+    default_leaf_modules = (
+        # dynamic op
+        DynamicChannelMixin,
+        # torch
+        nn.Conv2d,
+        nn.Linear,
+        nn.modules.batchnorm._BatchNorm,
+        # mmcv
+        Scale,
+    )
+
+    def __init__(self,
+                 input_shape=(1, 3, 224, 224),
+                 tracer_type='BackwardTracer') -> None:
+
+        self.input_shape = input_shape
+
+        assert tracer_type in ['BackwardTracer', 'FxTracer']
+        self.tracer_type = tracer_type
+        if tracer_type == 'BackwardTracer':
+            self.tracer = BackwardTracer(
+                loss_calculator=SumPseudoLoss(input_shape=input_shape))
+        elif tracer_type == 'FxTracer':
+            self.tracer = CustomFxTracer(leaf_module=self.default_leaf_modules)
+        else:
+            raise NotImplementedError()
+
+    def trace(self, model):
+        model = copy.deepcopy(model)
+        model = revert_sync_batchnorm(model)
+        model.eval()
+        if self.tracer_type == 'BackwardTracer':
+            path_list = self.tracer.trace(model)
+            module_graph: ModuleGraph = PathToGraphConverter(path_list,
+                                                             model).graph
+        elif self.tracer_type == 'FxTracer':
+            fx_graph = self._fx_trace(model)
+            fx_graph.owning_module = model
+            fx_graph.graph = BaseGraph[FxBaseNode]()
+            base_graph = RazorFxTracer().parse_torch_graph(fx_graph)
+
+            module_graph = FxTracerToGraphConverter(base_graph, model).graph
+            module_graph._model = model
+        else:
+            raise NotImplementedError()
+
+        module_graph.refresh_module_name()
+        module_graph.check(fix=True)
+        module_graph.check()
+
+        channel_graph = ChannelGraph.copy_from(module_graph,
+                                               default_channel_node_converter)
+        channel_graph.check(fix=True)
+        channel_graph.check()
+
+        channel_graph.forward(self.input_shape[1])
+        unit_configs = channel_graph.generate_units_config()
+
+        return self._find_mutable_units(model, unit_configs)
+
+    def _fx_trace(self, model):
+        args = demo_inputs(model, self.input_shape)
+        if isinstance(args, dict):
+            args.pop('inputs')
+            return self.tracer.trace(model, concrete_args=args)
+        else:
+            return self.tracer.trace(model)
+
+    def _find_mutable_units(self, model, units_config: Dict):
+        model = copy.deepcopy(model)
+        units: List[SequentialMutableChannelUnit] = [
+            SequentialMutableChannelUnit.init_from_cfg(model, cfg)
+            for cfg in units_config.values()
+        ]
+        for unit in units:
+            unit.prepare_for_pruning(model)
+        mutable_units = [unit for unit in units if unit.is_mutable]
+        inputs = demo_inputs(model, self.input_shape)
+        model.eval()
+
+        if isinstance(inputs, dict):
+            inputs['mode'] = 'loss'
+            template_output = model(**inputs)
+        else:
+            template_output = model(inputs)
+
+        mutable_units = find_mutable(model, mutable_units, units,
+                                     template_output)
+        mutable_unit_config = {}
+        for unit in mutable_units:
+            mutable_unit_config[
+                unit.name] = MutableChannelUnit.config_template(
+                    unit, with_channels=True, with_init_args=True)
+        return mutable_unit_config
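Putting the pieces together, the intended entry point looks roughly like this (the torchvision model is an assumption; any traceable classifier should do):

import torchvision
from mmrazor.models.task_modules.tracer import PruneTracer

model = torchvision.models.resnet34()
unit_configs = PruneTracer(
    input_shape=(1, 3, 224, 224), tracer_type='BackwardTracer').trace(model)
print(len(unit_configs))  # one config dict per prunable channel unit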
self.graph""" + + graph = BaseGraph[FxBaseNode]() + # copy_nodes + for fxnode in torch_graph.nodes: + self.add_node(graph, fxnode) + + # connect nodes + for fxnode in torch_graph.nodes: + for pre_node in fxnode.all_input_nodes: + graph.connect( + self.add_node(graph, pre_node), + self.add_node(graph, fxnode)) + + return graph + + def trace(self, model) -> BaseGraph[FxBaseNode]: # type: ignore + torch_graph = super().trace(model) + torch_graph.owning_module = model + + self.graph = BaseGraph[FxBaseNode]() + return self.parse_torch_graph(torch_graph) diff --git a/mmrazor/models/utils/make_divisible.py b/mmrazor/models/utils/make_divisible.py index 5fda15591..08b05e90b 100644 --- a/mmrazor/models/utils/make_divisible.py +++ b/mmrazor/models/utils/make_divisible.py @@ -23,7 +23,6 @@ def make_divisible(value: int, Returns: int: The modified output channel number """ - if min_value is None: min_value = divisor if min_value < divisor: diff --git a/mmrazor/structures/graph/base_graph.py b/mmrazor/structures/graph/base_graph.py index a7dba7e4f..4c40e0359 100644 --- a/mmrazor/structures/graph/base_graph.py +++ b/mmrazor/structures/graph/base_graph.py @@ -121,17 +121,27 @@ def add_node(self, node: BASENODE): if node.name not in self.nodes: self.nodes[node.name] = node else: - raise BaseException(f'{node.name} already exists in graph') + raise Exception(f'{node.name} already exists in graph') def connect(self, pre_node: BASENODE, next_node: BASENODE): """Add an edge from pre_node to next_node.""" - assert pre_node in self and next_node in self + pre_node_ = self.find_node(pre_node) + next_node_ = self.find_node(next_node) + assert pre_node_ is not None and next_node_ is not None, \ + f"{pre_node},{next_node} don't exist in the graph." + pre_node = pre_node_ + next_node = next_node_ pre_node.add_next_node(next_node) next_node.add_prev_node(pre_node) def disconnect(self, pre_node: BASENODE, next_node: BASENODE): """Remove the edge form pre_node to next_node.""" - assert pre_node in self and next_node in self + pre_node_ = self.find_node(pre_node) + next_node_ = self.find_node(next_node) + assert pre_node_ is not None and next_node_ is not None, \ + f"{pre_node},{next_node} don't exist in the graph." + pre_node = pre_node_ + next_node = next_node_ if next_node in pre_node.next_nodes: pre_node.next_nodes.remove(next_node) if pre_node in next_node.prev_nodes: @@ -185,7 +195,7 @@ def __len__(self) -> int: def __repr__(self): res = f'Graph with {len(self)} nodes:\n' for node in self: - res += '{0:<40} -> {1:^40} -> {2:<40}\n'.format( + res += '{0:<80} -> {1:^80} -> {2:<80}\n'.format( str(node.prev_nodes), node.__repr__(), str(node.next_nodes)) return res @@ -204,7 +214,7 @@ def find_zero_degree_node(in_degree): for node_name in in_degree: if in_degree[node_name] == 0: return node_name - return None + raise Exception(f'no zero degree node\n{in_degree}') in_degree = _in_degree(self) diff --git a/mmrazor/structures/graph/channel_flow.py b/mmrazor/structures/graph/channel_flow.py new file mode 100644 index 000000000..1ea0e548e --- /dev/null +++ b/mmrazor/structures/graph/channel_flow.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/mmrazor/structures/graph/channel_flow.py b/mmrazor/structures/graph/channel_flow.py
new file mode 100644
index 000000000..1ea0e548e
--- /dev/null
+++ b/mmrazor/structures/graph/channel_flow.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import itertools
+import sys
+from typing import List, Set, Union
+
+from mmrazor.utils import IndexDict
+
+sys.setrecursionlimit(int(pow(2, 20)))
+
+
+class ChannelElem:
+    """One channel element; a node in a disjoint set (union-find)."""
+
+    def __init__(self, owning_tensor, index_in_tensor) -> None:
+        self._parent: Union[None, 'ChannelElem'] = None
+        self._subs: Set[ChannelElem] = set()
+        self.owning_tensor = owning_tensor
+        self.index_in_tensor = index_in_tensor
+        self._hash_cache = None
+        self._min_elem_set_index_cache = None
+
+    # channel elem operations
+
+    @classmethod
+    def union_two(cls, elem1: 'ChannelElem', elem2: 'ChannelElem'):
+        root1 = elem1.root
+        root2 = elem2.root
+        if root1 is not root2:
+            root2._set_parent(root1)
+
+    def union(self, elem: 'ChannelElem'):
+        ChannelElem.union_two(self, elem)
+
+    # unit related
+
+    @property
+    def owning_elem_set(self):
+        root = self.root
+        return root.subs
+
+    def reset_cache(self):
+        self._hash_cache = None
+        self._min_elem_set_index_cache = None
+
+    @property
+    def elem_set_hash(self):
+        if self._hash_cache is not None:
+            return self._hash_cache
+        else:
+            tensor_list = list(self.owning_elem_set)
+            tensor_set = set([elem.owning_tensor for elem in tensor_list])
+            frozen_set = frozenset(tensor_set)
+            hash = frozen_set.__hash__()
+            for elem in self.owning_elem_set:
+                assert elem._hash_cache is None
+                elem._hash_cache = hash
+            return hash
+
+    @property
+    def min_elem_set_index(self):
+        if self._min_elem_set_index_cache is not None:
+            return self._min_elem_set_index_cache
+        else:
+            elem_set = self.owning_elem_set
+            min_index = int(pow(2, 32))
+            for elem in elem_set:
+                min_index = min(min_index, elem.index_in_tensor)
+            for elem in elem_set:
+                assert elem._min_elem_set_index_cache is None
+                elem._min_elem_set_index_cache = min_index
+            return min_index
+
+    # work as a disjoint set
+
+    @property
+    def root(self) -> 'ChannelElem':
+        if self._parent is None:
+            return self
+        else:
+            root = self._parent.root
+            self._unset_parent()
+            self._set_parent(root)
+            return root
+
+    @property
+    def subs(self):
+        subs = copy.copy(self._subs)
+        subs.add(self)
+        for elem in self._subs:
+            subs = subs.union(elem.subs)
+        return subs
+
+    def _set_parent(self, parent: 'ChannelElem'):
+        assert self._parent is None
+        assert parent.root is not self
+        self._parent = parent
+        parent._subs.add(self)
+
+    def _unset_parent(self):
+        assert self._parent is not None
+        old_parent = self._parent
+        old_parent._subs.remove(self)
+        self._parent = None
+
+
+class ChannelTensor:
+    """A 1-D sequence of ChannelElems, one per channel of a feature tensor."""
+
+    def __init__(self, num_channel_elem: int) -> None:
+        self.elems = [ChannelElem(self, i) for i in range(num_channel_elem)]
+
+    # tensor operations
+
+    def union(self, tensor: 'ChannelTensor'):
+        return self.__class__.union_two(self, tensor)
+
+    @classmethod
+    def union_two(cls, tensor1: 'ChannelTensor', tensor2: 'ChannelTensor'):
+        assert len(tensor1) == len(tensor2), f'{len(tensor1)}!={len(tensor2)}'
+        for e1, e2 in zip(tensor1, tensor2):
+            ChannelElem.union_two(e1, e2)
+
+    @classmethod
+    def cat(cls, tensors: List['ChannelTensor']):
+        elems = list(itertools.chain(*[t.elems for t in tensors]))
+        new_tensor = ChannelTensor(len(elems))
+        new_tensor.elems = elems
+        return new_tensor
+
+    def expand(self, expand_ratio: int):
+        new_tensor = ChannelTensor(expand_ratio * len(self))
+
+        for i in range(len(self)):
+            for j in range(expand_ratio):
+                self[i].union(new_tensor[i * expand_ratio + j])
+        return new_tensor
+
+    # unit operation
+
+    @property
+    def elems_hash_with_index(self):
+        elem_hashes = [(elem.elem_set_hash, elem.min_elem_set_index)
+                       for elem in self.elems]
+        return elem_hashes
+
+    @property
+    def elems_hash_dict(self):
+        elem_hash_with_index = self.elems_hash_with_index
+        unit_dict = IndexDict()
+        start = 0
+        for e in range(1, len(self)):
+            if (elem_hash_with_index[e][0] != elem_hash_with_index[e - 1][0]
+                    or elem_hash_with_index[e][1] <
+                    elem_hash_with_index[e - 1][1]):
+
+                unit_dict[(start, e)] = elem_hash_with_index[start][0]
+                start = e
+        unit_dict[(start, len(self))] = elem_hash_with_index[start][0]
+        return unit_dict
+
+    # work as a tensor
+
+    def __getitem__(self, key: Union[int, slice]):
+        if isinstance(key, int):
+            return self.elems[key]
+        elif isinstance(key, slice):
+            elems = self.elems[key]
+            tensor = ChannelTensor(len(elems))
+            tensor.elems = elems
+            return tensor
+        else:
+            raise NotImplementedError()
+
+    def __len__(self):
+        return len(self.elems)
+
+    def __iter__(self):
+        for e in self.elems:
+            yield e
+
+    def __add__(self, tensor: 'ChannelTensor'):
+        return ChannelTensor.cat([self, tensor])
+
+    def _reset_channel_elem_cache(self):
+        for elem in self.elems:
+            elem.reset_cache()
diff --git a/mmrazor/structures/graph/channel_graph.py b/mmrazor/structures/graph/channel_graph.py
index a1629c587..00146ec74 100644
--- a/mmrazor/structures/graph/channel_graph.py
+++ b/mmrazor/structures/graph/channel_graph.py
@@ -1,12 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 from typing import Callable, Dict, List

+from mmengine import MMLogger
 from torch.nn import Module

 from .base_graph import BaseGraph
-from .channel_modules import BaseChannelUnit, ChannelTensor
-from .channel_nodes import ChannelNode, default_channel_node_converter
-from .module_graph import ModuleGraph
+from .channel_flow import ChannelTensor
+from .channel_nodes import (ChannelDismatchError, ChannelNode, EndNode,
+                            ExpandChannelNode, InputChannelNode,
+                            default_channel_node_converter)
+from .module_graph import ModuleGraph, NoInputError, NoOutputError


 class ChannelGraph(ModuleGraph[ChannelNode]):
@@ -22,38 +26,160 @@ def copy_from(cls,
                   node_converter: Callable = default_channel_node_converter):
         """Copy from a ModuleGraph."""
         assert isinstance(graph, ModuleGraph)
-        return super().copy_from(graph, node_converter)
+        channel_graph: ChannelGraph = super().copy_from(graph, node_converter)
+        channel_graph._insert_expand_node()
+        return channel_graph

-    def collect_units(self) -> List[BaseChannelUnit]:
-        """Collect channel units in the graph."""
-        units = list()
-        for node in self.topo_traverse():
-            node.register_channel_to_units()
+    def generate_units_config(self) -> Dict:
+        """Collect channel units in the graph and return their config dict:
+        "<unit hash>": {
+            'init_args': {
+                'num_channels': 10
+            },
+            'channels': {
+                'input_related': [
+                    {
+                        "name": "backbone.bn1",
+                        "start": 0,
+                        "end": 64,
+                        "expand_ratio": 1,
+                        "is_output_channel": false
+                    }
+                ],
+                'output_related': [
+                    ...
+                ]
+            }
+        }"""
+
+        channel_config_template: Dict = {
+            'init_args': {
+                'num_channels': 1
+            },
+            'channels': {
+                'input_related': [],
+                'output_related': []
+            }
+        }
+
+        def process_tensor(node: ChannelNode, is_output_tensor,
+                           unit_hash_dict: Dict):
+            if is_output_tensor:
+                tensor = node.out_channel_tensor
+            else:
+                tensor = node.in_channel_tensor
+            assert tensor is not None
+            for (start, end), hash in tensor.elems_hash_dict.items():
+                channel_config = {
+                    'name': node.module_name if node.is_module else node.val,
+                    'start': start,
+                    'end': end,
+                    'is_output_channel': is_output_tensor
+                }
+                if hash not in unit_hash_dict:
+                    unit_hash_dict[hash] = copy.deepcopy(
+                        channel_config_template)
+                related_dict = unit_hash_dict[hash]['channels'][
+                    'output_related' if is_output_tensor else 'input_related']
+                if channel_config not in related_dict:
+                    related_dict.append(channel_config)
+
+        def fill_num_channels(units_config: Dict):
+
+            def min_num_channels(channel_configs: List[Dict]):
+                min_num_channels = int(pow(2, 32))
+                for channel in channel_configs:
+                    min_num_channels = min(min_num_channels,
+                                           channel['end'] - channel['start'])
+                return min_num_channels
+
+            for name in units_config:
+                units_config[name]['init_args'][
+                    'num_channels'] = min_num_channels(
+                        units_config[name]['channels']['input_related'] +
+                        units_config[name]['channels']['output_related'])
+
+        unit_hash_dict: Dict = {}
+        self._reset_channel_elem_cache()
         for node in self.topo_traverse():
-            for unit in node.in_channel_tensor.unit_list + \
-                    node.out_channel_tensor.unit_list:
-                if unit not in units:
-                    units.append(unit)
-        return units
+            process_tensor(node, True, unit_hash_dict)
+            process_tensor(node, False, unit_hash_dict)
+        fill_num_channels(unit_hash_dict)
+        return unit_hash_dict

     def forward(self, num_input_channel=3):
         """Generate a ChannelTensor and let it forward through the graph."""
         for node in self.topo_traverse():
             node.reset_channel_tensors()
-        self._merge_same_module()
         for i, node in enumerate(self.topo_traverse()):
             node: ChannelNode
             if len(node.prev_nodes) == 0:
-                channel_list = ChannelTensor(num_input_channel)
-                node.forward([channel_list])
+                tensor = ChannelTensor(num_input_channel)
+                node.forward([tensor])
             else:
                 node.forward()
+        self._merge_same_module()
+
+    def _check(self, node: ChannelNode, fix=False):
+
+        try:
+            node.check_channel()
+            node.check()
+        except Exception as e:
+            if not fix:
+                raise e
+            else:
+                try:
+                    raise e
+                except NoOutputError as e:
+                    MMLogger.get_current_instance().warning(
+                        f'add a output after {node}, error: {e}')
+                    self._add_output_after(node)
+                except NoInputError as e:
+                    MMLogger.get_current_instance().warning(
+                        f'add a input before {node}, error: {e}')
+                    self._add_input_before(node)
+                except ChannelDismatchError as e:
+                    MMLogger.get_current_instance().warning(
+                        (f'{node} has a channel error, so '
+                         f'we convert it to an EndNode. error: {e}'))
+                    self._convert_a_node_to_end_node(node)
+
+            self._check(node, fix=True)
+
+    def _add_input_before(self, node: ChannelNode):
+        try:
+            in_channels = node.in_channels
+        except Exception:
+            in_channels = 3
+        input_node = InputChannelNode(
+            f'auto_input_{in_channels}',
+            'input_placeholder',
+            input_channels=in_channels)  # type: ignore
+        input_node = self.add_or_find_node(input_node)
+        self.connect(input_node, node)
+
+    def _add_output_after(self, node: ChannelNode):
+        output_node = EndNode('auto_output',
+                              'output_placeholder')  # type: ignore
+        output_node = self.add_or_find_node(output_node)
+        self.connect(node, output_node)
+
+    def _convert_a_node_to_end_node(self, node: ChannelNode):
+
+        end_node = EndNode('auto_end', 'output_placeholder')
+        end_node = self.add_or_find_node(end_node)
+        for prev in copy.copy(node.prev_nodes):
+            self.disconnect(prev, node)
+            self.connect(prev, end_node)
+        self._add_input_before(node)

     def _merge_same_module(self):
         """Union all nodes with the same module to the same unit."""
         module2node: Dict[Module, List[ChannelNode]] = dict()
         for node in self:
-            if isinstance(node.val, Module):
+            if isinstance(node.val,
+                          Module) and len(list(node.val.parameters())) > 0:
                 if node.val not in module2node:
                     module2node[node.val] = []
                 if node not in module2node[node.val]:
@@ -62,10 +188,44 @@ def _merge_same_module(self):
         for module in module2node:
             if len(module2node[module]) > 1:
                 nodes = module2node[module]
-                input_channel_tensor = ChannelTensor(nodes[0].in_channels)
-                out_channel_tensor = ChannelTensor(nodes[0].out_channels)
-                for node in nodes:
-                    ChannelTensor.union(input_channel_tensor,
-                                        node.in_channel_tensor)
-                    ChannelTensor.union(out_channel_tensor,
-                                        node.out_channel_tensor)
+                assert nodes[0].in_channel_tensor is not None and \
+                    nodes[0].out_channel_tensor is not None
+                for node in nodes[1:]:
+                    nodes[0].in_channel_tensor.union(node.in_channel_tensor)
+                    nodes[0].out_channel_tensor.union(node.out_channel_tensor)
+
+    def _insert_expand_node(self):
+        num_expand_nodes = 0
+        nodes: List[ChannelNode] = copy.copy(list(self.topo_traverse()))
+        for node in nodes:
+            try:
+                node.check_channel()
+            except Exception:
+                for pre_node in node.prev_nodes:
+                    pre_node: ChannelNode
+                    if (pre_node.out_channels < node.in_channels
+                            and node.in_channels % pre_node.out_channels == 0):
+                        from mmengine import MMLogger
+                        MMLogger.get_current_instance().warning(
+                            (f'As the channels of {pre_node} and {node} '
+                             'do not match, we add an ExpandNode between them.'))
+                        expand_ratio = (
+                            node.in_channels // pre_node.out_channels)
+                        # insert an expand node
+                        new_node = ExpandChannelNode(
+                            f'expand_{num_expand_nodes}',
+                            'expand',
+                            expand_ratio=expand_ratio)
+                        num_expand_nodes += 1
+                        self.add_node(new_node)
+                        self.connect(pre_node, new_node)
+                        self.connect(new_node, node)
+                        self.disconnect(pre_node, node)
+
+    def _reset_channel_elem_cache(self):
+        # may have a bug: some tensors are not recorded by node.xxxx_tensors
+        for node in self.topo_traverse():
+            assert (node.in_channel_tensor is not None
+                    and node.out_channel_tensor is not None), f'{node}'
+            node.in_channel_tensor._reset_channel_elem_cache()
+            node.out_channel_tensor._reset_channel_elem_cache()
diff --git a/mmrazor/structures/graph/channel_modules.py b/mmrazor/structures/graph/channel_modules.py
deleted file mode 100644
index 1cfa2d5ff..000000000
--- a/mmrazor/structures/graph/channel_modules.py
+++ /dev/null
@@ -1,372 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy -from typing import Dict, List, Tuple, Union - -# Channels - - -class BaseChannel: - """BaseChannel records information about channels for pruning. - - Args: - name (str): The name of the channel. When the channel is related with - a module, the name should be the name of the module in the model. - module (Any): Module of the channel. - index (Tuple[int,int]): Index(start,end) of the Channel in the Module - node (ChannelNode, optional): A ChannelNode corresponding to the - Channel. Defaults to None. - is_output_channel (bool, optional): Is the channel output channel. - Defaults to True. - expand_ratio (int, optional): Expand ratio of the mask. Defaults to 1. - """ - - # init - - def __init__(self, - name, - module, - index, - node=None, - is_output_channel=True, - expand_ratio=1) -> None: - self.name = name - self.module = module - self.index = index - self.start = index[0] - self.end = index[1] - - self.node = node - - self.is_output_channel = is_output_channel - self.expand_ratio = expand_ratio - - @property - def num_channels(self) -> int: - """The number of channels in the Channel.""" - return self.index[1] - self.index[0] - - # others - - def __repr__(self) -> str: - return f'{self.name}\t{self.index}\t \ - {"out" if self.is_output_channel else "in"}\t\ - expand:{self.expand_ratio}' - - def __eq__(self, obj: object) -> bool: - if isinstance(obj, BaseChannel): - return self.name == obj.name \ - and self.module == obj.module \ - and self.index == obj.index \ - and self.is_output_channel == obj.is_output_channel \ - and self.expand_ratio == obj.expand_ratio \ - and self.node == obj.node - else: - return False - - -class BaseChannelUnit: - """BaseChannelUnit is a collection of BaseChannel. - - All BaseChannels are saved in two lists: self.input_related and - self.output_related. 
- """ - - def __init__(self) -> None: - - self.channel_elems: Dict[int, List[ChannelElement]] = {} - self.input_related: List[BaseChannel] = [] - self.output_related: List[BaseChannel] = [] - - # ~ - - def add_channel_elem(self, channel_elem: 'ChannelElement', index): - """Add a ChannelElement to the BaseChannelUnit.""" - self._add_channel_info(channel_elem, index) - if channel_elem.unit is not None: - channel_elem.remove_from_unit() - channel_elem._register_unit(self, index) - - # unit operations - - @classmethod - def union_units(cls, units: List['BaseChannelUnit']): - """Union units.""" - assert len(units) > 1 - union_unit = units[0] - - for unit in units[1:]: - union_unit = BaseChannelUnit.union_two_units(union_unit, unit) - return union_unit - - @classmethod - def union_two_units(cls, unit1: 'BaseChannelUnit', - unit2: 'BaseChannelUnit'): - """Union two units.""" - if unit1 is unit2: - return unit1 - else: - assert len(unit1) == len(unit2) - for i in unit1: - for channel_elem in copy.copy(unit2[i]): - unit1.add_channel_elem(channel_elem, i) - return unit1 - - @classmethod - def split_unit(cls, unit: 'BaseChannelUnit', nums: List[int]): - """Split a unit to multiple units.""" - new_units = [] - if len(nums) == 1: - return [unit] - assert sum(nums) == len(unit) - for num in nums: - new_unit = unit._split_a_new_unit(list(range(0, num))) - new_units.append(new_unit) - return new_units - - # private methods - - def _clean_channel_info(self, channel_elem: 'ChannelElement', index: int): - """Clean the info of a ChannelElement.""" - self[index].remove(channel_elem) - - def _add_channel_info(self, channel_elem: 'ChannelElement', index): - """Add the info of a ChannelElemnt.""" - assert channel_elem.unit is not self - if index not in self.channel_elems: - self.channel_elems[index] = [] - self.channel_elems[index].append(channel_elem) - - def _split_a_new_unit(self, indexes: List[int]): - """Split a part of the unit to a new unit.""" - new_unit = BaseChannelUnit() - j = 0 - for i in indexes: - for channel_elem in copy.copy(self[i]): - new_unit.add_channel_elem(channel_elem, j) - self.channel_elems.pop(i) - j += 1 - self._reindex() - return new_unit - - def _reindex(self): - """Re-index the owning ChannelElements.""" - j = 0 - for i in copy.copy(self.channel_elems): - if len(self.channel_elems[i]) == 0: - self.channel_elems.pop(i) - else: - if j < i: - for channel_elem in copy.copy(self.channel_elems[i]): - if channel_elem.unit is not None: - channel_elem.remove_from_unit() - self.add_channel_elem(channel_elem, j) - self.channel_elems.pop(i) - j += 1 - elif j == i: - pass - else: - raise Exception() - - # others - - def __repr__(self) -> str: - - def add_prefix(string: str, prefix=' '): - str_list = string.split('\n') - str_list = [ - prefix + line if line != '' else line for line in str_list - ] - return '\n'.join(str_list) - - def list_repr(lit: List): - s = '[\n' - for item in lit: - s += add_prefix(item.__repr__(), ' ') + '\n' - s += ']\n' - return s - - s = ('xxxxx_' - f'\t{len(self.output_related)},{len(self.input_related)}\n') - s += ' output_related:\n' - s += add_prefix(list_repr(self.output_related), ' ' * 4) - s += ' input_related\n' - s += add_prefix(list_repr(self.input_related), ' ' * 4) - return s - - def __iter__(self): - for i in self.channel_elems: - yield i - - def __len__(self): - return len(self.channel_elems) - - def __getitem__(self, key): - return self.channel_elems[key] - - -class ChannelElement: - """Each ChannelElement is the basic element of a ChannelTensor. 
It records - its owing ChannelTensor and BaseChannelUnit. - - Args: - index (int): The index of the ChannelElement in the ChannelTensor. - """ - - def __init__(self, index_in_tensor: int) -> None: - - self.index_in_channel_tensor = index_in_tensor - - self.unit: Union[BaseChannelUnit, None] = None - self.index_in_unit = -1 - - def remove_from_unit(self): - """Remove the ChannelElement from its owning BaseChannelUnit.""" - self.unit._clean_channel_info(self, self.index_in_unit) - self._clean_unit_info() - - # private methods - - def _register_unit(self, unit, index): - """Register the ChannelElement to a BaseChannelUnit.""" - self.unit = unit - self.index_in_unit = index - - def _clean_unit_info(self): - """Clean the unit info in the ChannelElement.""" - self.unit = None - self.index_in_unit = -1 - - -class ChannelTensor: - """A ChannelTensor is a list of ChannelElemnts. It can forward through a - ChannelGraph. - - Args: - num_channel_elems (int): Number of ChannelElements. - """ - - def __init__(self, num_channel_elems: int) -> None: - - unit = BaseChannelUnit() - self.channel_elems: List[ChannelElement] = [ - ChannelElement(i) for i in range(num_channel_elems) - ] - for channel_elem in self.channel_elems: - unit.add_channel_elem(channel_elem, - channel_elem.index_in_channel_tensor) - - # unit operations - - def align_units_with_nums(self, nums: List[int]): - """Align owning units to certain lengths.""" - i = 0 - for start, end in self.unit_dict: - start_ = start - new_nums = [] - while start_ < end: - new_nums.append(nums[i]) - start_ += nums[i] - i += 1 - BaseChannelUnit.split_unit(self.unit_dict[(start, end)], new_nums) - - @property - def unit_dict(self) -> Dict[Tuple[int, int], BaseChannelUnit]: - """Get a dict of owning units.""" - units: Dict[Tuple[int, int], BaseChannelUnit] = {} - # current_unit = ... 
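-        # The scan below groups consecutive ChannelElements that share a
-        # unit into (start, end) ranges: a new range opens whenever the
-        # owning unit changes, or the per-element index_in_unit drops
-        # (the same unit starting over). E.g. elements owned as
-        # [A, A, B, B, B] yield {(0, 2): A, (2, 5): B}.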
- current_unit_idx = -1 - start = 0 - for i in range(len(self)): - if i == 0: - current_unit = self[i].unit - current_unit_idx = self[i].index_in_unit - start = 0 - else: - if current_unit is not self[i].unit or \ - current_unit_idx > self[i].index_in_unit: - units[(start, i)] = current_unit - current_unit = self[i].unit - current_unit_idx = self[i].index_in_unit - start = i - current_unit_idx = self[i].index_in_unit - units[(start, len(self))] = current_unit - return units - - @property - def unit_list(self) -> List[BaseChannelUnit]: - """Get a list of owning units.""" - return list(self.unit_dict.values()) - - # tensor operations - - @classmethod - def align_tensors(cls, *tensors: 'ChannelTensor'): - """Align the lengths of the units of the tensors.""" - assert len(tensors) >= 2 - for tensor in tensors: - assert len(tensor) == len( - tensors[0]), f'{len(tensor)}!={len(tensors[0])}' - aligned_index = cls._index2points( - *[list(tenser.unit_dict.keys()) for tenser in tensors]) - nums = cls._points2num(aligned_index) - if len(nums) > 1: - for tensor in tensors: - tensor.align_units_with_nums(nums) - - def union(self, tensor1: 'ChannelTensor'): - """Union the units with the tensor1.""" - # align - ChannelTensor.align_tensors(self, tensor1) - # union - for ch1, ch2 in zip(self.channel_elems, tensor1.channel_elems): - assert ch1.unit is not None and ch2.unit is not None - for ch in copy.copy(ch2.unit.channel_elems[ch2.index_in_unit]): - ch1.unit.add_channel_elem(ch, ch1.index_in_unit) - - def expand(self, ratio) -> 'ChannelTensor': - """Get a new ChannelTensor which is expanded from this - ChannelTensor.""" - expanded_tensor = ChannelTensor(len(self) * ratio) - for i, ch in enumerate(self.channel_elems): - assert ch.unit is not None - unit = ch.unit - for j in range(0, ratio): - ex_ch = expanded_tensor[i * ratio + j] - unit.add_channel_elem(ex_ch, ch.index_in_unit) - return expanded_tensor - - # others - - def __getitem__(self, i: int): - """Get ith ChannelElement in the ChannelTensor.""" - return self.channel_elems[i] - - def __len__(self): - """Get length of the ChannelTensor.""" - return len(self.channel_elems) - - @classmethod - def _index2points(cls, *indexes: List[Tuple[int, int]]): - """Convert indexes to points.""" - new_index = [] - for index in indexes: - new_index.extend(index) - points = set() - for start, end in new_index: - points.add(start) - points.add(end) - points_list = list(points) - points_list.sort() - return points_list - - @classmethod - def _points2num(cls, indexes: List[int]): - """Convert a list of sorted points to the length of each block.""" - if len(indexes) == 0: - return [] - nums = [] - start = 0 - for end in indexes[1:]: - nums.append(end - start) - start = end - return nums diff --git a/mmrazor/structures/graph/channel_nodes.py b/mmrazor/structures/graph/channel_nodes.py index 1749b5875..44ec7a134 100644 --- a/mmrazor/structures/graph/channel_nodes.py +++ b/mmrazor/structures/graph/channel_nodes.py @@ -2,16 +2,26 @@ import operator from abc import abstractmethod -from typing import Union +from typing import List, Union import torch import torch.nn as nn +from mmcv.cnn.bricks import Scale from mmengine import MMLogger -from .channel_modules import BaseChannel, BaseChannelUnit, ChannelTensor +from .channel_flow import ChannelTensor from .module_graph import ModuleNode +class ChannelDismatchError(Exception): + pass + + +def assert_channel(condition, node): + if not condition: + raise ChannelDismatchError(node.name) + + class ChannelNode(ModuleNode): """A 
ChannelNode is like a torch module. It accepts a ChannelTensor and output a ChannelTensor. The difference is that the torch module transforms @@ -32,70 +42,46 @@ class ChannelNode(ModuleNode): def __init__(self, name: str, val: Union[nn.Module, str], - expand_ratio: int = 1, module_name='') -> None: - super().__init__(name, val, expand_ratio, module_name) - self.in_channel_tensor = ChannelTensor(self.in_channels) - self.out_channel_tensor = ChannelTensor(self.out_channels) + super().__init__(name, val, module_name) + self.in_channel_tensor: Union[None, ChannelTensor] = None + self.out_channel_tensor: Union[None, ChannelTensor] = None + self.return_tensor: Union[None, ChannelTensor] = None @classmethod def copy_from(cls, node): """Copy from a ModuleNode.""" assert isinstance(node, ModuleNode) - return cls(node.name, node.val, node.expand_ratio, node.module_name) + return cls(node.name, node.val, node.module_name) def reset_channel_tensors(self): """Reset the owning ChannelTensors.""" - self.in_channel_tensor = ChannelTensor(self.in_channels) - self.out_channel_tensor = ChannelTensor(self.out_channels) + self.in_channel_tensor = None + self.out_channel_tensor = None # forward - def forward(self, in_channel_tensor=None): + def forward(self, in_channel_tensors=None): """Forward with ChannelTensors.""" - assert self.in_channel_tensor is not None and \ - self.out_channel_tensor is not None - if in_channel_tensor is None: + if in_channel_tensors is None: out_channel_tensors = [ - node.out_channel_tensor for node in self.prev_nodes + node.return_tensor for node in self.prev_nodes ] - - in_channel_tensor = out_channel_tensors - self.channel_forward(*in_channel_tensor) - if self.expand_ratio > 1: - self.out_channel_tensor = self.out_channel_tensor.expand( - self.expand_ratio) + in_channel_tensors = out_channel_tensors + try: + self.return_tensor = self.channel_forward(in_channel_tensors) + except Exception as e: + raise Exception(f'{e},{self.name}') @abstractmethod - def channel_forward(self, *channel_tensors: ChannelTensor): + def channel_forward(self, channel_tensors: List[ChannelTensor]): """Forward with ChannelTensors.""" assert len(channel_tensors) == 1, f'{len(channel_tensors)}' - BaseChannelUnit.union_two_units( - list(self.in_channel_tensor.unit_dict.values())[0], - list(channel_tensors[0].unit_dict.values())[0]) - - if self.in_channels == self.out_channels: - BaseChannelUnit.union_two_units( - self.in_channel_tensor.unit_list[0], - self.out_channel_tensor.unit_list[0]) - - # register unit - - def register_channel_to_units(self): - """Register the module of this node to corresponding units.""" - name = self.module_name if isinstance(self.val, - nn.Module) else self.name - for index, unit in self.in_channel_tensor.unit_dict.items(): - channel = BaseChannel(name, self.val, index, None, False, - self.expand_ratio) - if channel not in unit.input_related: - unit.input_related.append(channel) - for index, unit in self.out_channel_tensor.unit_dict.items(): - channel = BaseChannel(name, self.val, index, None, True, - self.expand_ratio) - if channel not in unit.output_related: - unit.output_related.append(channel) + + self.in_channel_tensor = channel_tensors[0] + self.out_channel_tensor = ChannelTensor(self.out_channels) + return self.out_channel_tensor # channels @@ -103,140 +89,227 @@ def register_channel_to_units(self): @property def in_channels(self) -> int: """Get the number of input channels of the node.""" - raise NotImplementedError() + try: + return self._in_channels + except 
NotImplementedError: + return \ + self._get_in_channels_by_prev_nodes(self.prev_nodes) # @abstractmethod @property def out_channels(self) -> int: """Get the number of output channels of the node.""" - raise NotImplementedError() + try: + return self._out_channels + except NotImplementedError: + return self._get_out_channel_by_in_channels(self.in_channels) + + def check_channel(self): + for node in self.prev_nodes: + assert_channel(node.out_channels == self.in_channels, self) + + @property + def _in_channels(self) -> int: + raise NotImplementedError( + f'{self.name}({self.__class__.__name__}) has no _in_channels') + + @property + def _out_channels(self) -> int: + raise NotImplementedError( + f'{self.name}({self.__class__.__name__}) has no _out_channels') + + def _get_out_channel_by_in_channels(self, in_channels): + return in_channels + + def _get_in_channels_by_prev_nodes(self, prev_nodes): + if len(prev_nodes) == 0: + from mmengine import MMLogger + MMLogger.get_current_instance().debug( + (f'As {self.name} ' + 'has no prev nodes, so we set the in channels of it to 3.')) + return 3 + else: + return prev_nodes[0].out_channels + + def __repr__(self) -> str: + return f'{self.name}_({self.in_channels},{self.out_channels})' # basic nodes -class PassChannelNode(ChannelNode): - """A PassChannelNode has the same number of input channels and output +class PassUnionChannelNode(ChannelNode): + """A PassUnionChannelNode has the same number of input channels and output channels. Besides, the corresponding input channels and output channels belong to one channel unit. Such as BatchNorm, Relu. """ - def channel_forward(self, *in_channel_tensor: ChannelTensor): + def channel_forward(self, channel_tensors: List[ChannelTensor]): """Channel forward.""" - PassChannelNode._channel_forward(self, *in_channel_tensor) + return PassUnionChannelNode._channel_forward(self, channel_tensors[0]) - @property - def in_channels(self) -> int: - """Get the number of input channels of the node.""" - if len(self.prev_nodes) > 0: - return self.prev_nodes[0].out_channels - else: - return 0 + @staticmethod + def _channel_forward(node: ChannelNode, tensor: ChannelTensor): + """Channel forward.""" + assert node.in_channels == node.out_channels + assert isinstance(tensor, ChannelTensor) + node.in_channel_tensor = tensor + node.out_channel_tensor = tensor + return node.out_channel_tensor - @property - def out_channels(self) -> int: - """Get the number of output channels of the node.""" - return self.in_channels + def __repr__(self) -> str: + return super().__repr__() + '_uion' + + +class PassChannelNode(ChannelNode): + + def _get_in_channels_by_prev_nodes(self, prev_nodes): + assert len(self.prev_nodes) == 1 + node0: ChannelNode = self.prev_nodes[0] + return node0.out_channels + + def channel_forward(self, channel_tensors: List[ChannelTensor]): + assert len(channel_tensors) == 1 + self.in_channel_tensor = ChannelTensor(1) + self.out_channel_tensor = ChannelTensor(1) + return channel_tensors[0] def __repr__(self) -> str: return super().__repr__() + '_pass' - @staticmethod - def _channel_forward(node: ChannelNode, *in_channel_tensor: ChannelTensor): - """Channel forward.""" - assert len(in_channel_tensor) == 1 and \ - node.in_channels == node.out_channels - in_channel_tensor[0].union(node.in_channel_tensor) - node.in_channel_tensor.union(node.out_channel_tensor) - class MixChannelNode(ChannelNode): """A MixChannelNode has independent input channels and output channels.""" - def channel_forward(self, *in_channel_tensor: ChannelTensor): 
+ def channel_forward(self, channel_tensors: List[ChannelTensor]): """Channel forward.""" - assert len(in_channel_tensor) <= 1 - if len(in_channel_tensor) == 1: - in_channel_tensor[0].union(self.in_channel_tensor) - - @property - def in_channels(self) -> int: - """Get the number of input channels of the node.""" - if len(self.prev_nodes) > 0: - return self.prev_nodes[0].in_channels - else: - return 0 - - @property - def out_channels(self) -> int: - """Get the number of output channels of the node.""" - if len(self.next_nodes) > 0: - return self.next_nodes[0].in_channels + assert len(channel_tensors) <= 1 + if len(channel_tensors) == 1: + self.in_channel_tensor = channel_tensors[0] + self.out_channel_tensor = ChannelTensor(self.out_channels) else: - return 0 + raise NotImplementedError() + return self.out_channel_tensor def __repr__(self) -> str: return super().__repr__() + '_mix' -class BindChannelNode(PassChannelNode): +class BindChannelNode(ChannelNode): """A BindChannelNode has multiple inputs, and all input channels belong to the same channel unit.""" - def channel_forward(self, *in_channel_tensor: ChannelTensor): + def channel_forward(self, channel_tensors: List[ChannelTensor]): """Channel forward.""" - assert len(in_channel_tensor) > 1 + assert len(channel_tensors) > 0, f'{self}' # align channel_tensors - ChannelTensor.align_tensors(*in_channel_tensor) - - # union tensors - node_units = [ - channel_lis.unit_dict for channel_lis in in_channel_tensor - ] - for key in node_units[0]: - BaseChannelUnit.union_units([units[key] for units in node_units]) - super().channel_forward(in_channel_tensor[0]) + for tensor in channel_tensors[1:]: + channel_tensors[0].union(tensor) + self.in_channel_tensor = channel_tensors[0] + self.out_channel_tensor = channel_tensors[0] + return self.out_channel_tensor def __repr__(self) -> str: - return super(ChannelNode, self).__repr__() + '_bind' + return super().__repr__() + '_bind' + + def check_channel(self): + for node in self.prev_nodes: + assert_channel(node.out_channels == self.in_channels, self) class CatChannelNode(ChannelNode): """A CatChannelNode cat all input channels.""" - def channel_forward(self, *in_channel_tensors: ChannelTensor): - BaseChannelUnit.union_two_units(self.in_channel_tensor.unit_list[0], - self.out_channel_tensor.unit_list[0]) - num_ch = [] - for in_ch_tensor in in_channel_tensors: - for start, end in in_ch_tensor.unit_dict: - num_ch.append(end - start) + def channel_forward(self, channel_tensors: List[ChannelTensor]): + tensor_cat = ChannelTensor.cat(channel_tensors) + self.in_channel_tensor = tensor_cat + self.out_channel_tensor = tensor_cat + return self.out_channel_tensor - split_units = BaseChannelUnit.split_unit( - self.in_channel_tensor.unit_list[0], num_ch) + def check_channel(self): + in_num = [node.out_channels for node in self.prev_nodes] + assert_channel(sum(in_num) == self.in_channels, self) - i = 0 - for in_ch_tensor in in_channel_tensors: - for in_unit in in_ch_tensor.unit_dict.values(): - BaseChannelUnit.union_two_units(split_units[i], in_unit) - i += 1 + def _get_in_channels_by_prev_nodes(self, prev_nodes): + assert len(prev_nodes) > 0 + nums = [node.out_channels for node in prev_nodes] + return sum(nums) - @property - def in_channels(self) -> int: - """Get the number of input channels of the node.""" - return sum([node.out_channels for node in self.prev_nodes]) + def __repr__(self) -> str: + return super().__repr__() + '_cat' + + +class ExpandChannelNode(ChannelNode): + + def __init__(self, + name: str, + val: 
Union[nn.Module, str], + module_name='', + expand_ratio=1) -> None: + super().__init__(name, val, module_name) + self.expand_ratio = expand_ratio + + def _get_out_channel_by_in_channels(self, in_channels): + return in_channels * self.expand_ratio + + def channel_forward(self, channel_tensors: List[ChannelTensor]): + assert len(channel_tensors) == 1, f'{self}' + assert self.out_channels >= self.in_channels, f'{self}' + assert self.out_channels % self.in_channels == 0, f'{self}' + tensor0 = channel_tensors[0] + self.in_channel_tensor = tensor0 + self.out_channel_tensor = tensor0.expand(self.expand_ratio) + return self.out_channel_tensor + + def __repr__(self) -> str: + return super().__repr__() + f'_expand({self.expand_ratio})' + + +class InputChannelNode(ChannelNode): + + def __init__(self, + name: str, + val: Union[nn.Module, str], + module_name='', + input_channels=3) -> None: + super().__init__(name, val, module_name) + self._input_channels = input_channels + + def channel_forward(self, channel_tensors: List[ChannelTensor]): + input_tensor = ChannelTensor(self._input_channels) + self.in_channel_tensor = input_tensor + self.out_channel_tensor = input_tensor + return input_tensor @property - def out_channels(self) -> int: - """Get the number of output channels of the node.""" - return self.in_channels + def _in_channels(self) -> int: + return self._input_channels def __repr__(self) -> str: - return super().__repr__() + '_cat' + return super().__repr__() + '_input' +class EndNode(ChannelNode): + + def channel_forward(self, channel_tensors: List[ChannelTensor]): + tensor_end = ChannelTensor(1) + self.in_channel_tensor = tensor_end + self.out_channel_tensor = tensor_end + for channel in channel_tensors: + channel.union(tensor_end.expand(len(channel))) + return self.out_channel_tensor + + def __repr__(self) -> str: + return super().__repr__() + '_end' + + def check_channel(self): + pass + + +# class StackChannelNode(ChannelNode): + # module nodes @@ -249,33 +322,47 @@ class ConvNode(MixChannelNode): def __init__(self, name: str, val: Union[nn.Module, str], - expand_ratio: int = 1, module_name='') -> None: - super().__init__(name, val, expand_ratio, module_name) + super().__init__(name, val, module_name) assert isinstance(self.val, nn.Conv2d) + + @property + def conv_type(self): if self.val.groups == 1: - self.conv_type = 'conv' + return 'conv' elif self.val.in_channels == self.out_channels == self.val.groups: - self.conv_type = 'dwconv' + return 'dwconv' else: - self.conv_type = 'gwconv' + return 'gwconv' - def channel_forward(self, *in_channel_tensor: ChannelTensor): + def channel_forward(self, channel_tensors: List[ChannelTensor]): if self.conv_type == 'conv': - return super().channel_forward(*in_channel_tensor) + return super().channel_forward(channel_tensors) elif self.conv_type == 'dwconv': - return PassChannelNode._channel_forward(self, *in_channel_tensor) + return PassUnionChannelNode._channel_forward( + self, channel_tensors[0]) elif self.conv_type == 'gwconv': - return super().channel_forward(*in_channel_tensor) + return self._gw_conv_channel_forward(channel_tensors) else: - pass + raise NotImplementedError(f'{self}') + + def _gw_conv_channel_forward(self, channel_tensors: List[ChannelTensor]): + + assert len(channel_tensors) == 1 + tensor0 = channel_tensors[0] + conv: nn.Conv2d = self.val + group_union(tensor0, conv.groups) + self.in_channel_tensor = tensor0 + self.out_channel_tensor = ChannelTensor(self.out_channels) + group_union(self.out_channel_tensor, conv.groups) + return 
self.out_channel_tensor @property - def in_channels(self) -> int: + def _in_channels(self) -> int: return self.val.in_channels @property - def out_channels(self) -> int: + def _out_channels(self) -> int: return self.val.out_channels def __repr__(self) -> str: @@ -288,73 +375,118 @@ class LinearNode(MixChannelNode): def __init__(self, name: str, val: Union[nn.Module, str], - expand_ratio: int = 1, module_name='') -> None: - super().__init__(name, val, expand_ratio, module_name) + super().__init__(name, val, module_name) assert isinstance(self.val, nn.Linear) @property - def in_channels(self) -> int: + def _in_channels(self) -> int: return self.val.in_features @property - def out_channels(self) -> int: + def _out_channels(self) -> int: return self.val.out_features def __repr__(self) -> str: - return super().__repr__() + 'linear' + return super().__repr__() + '_linear' -class NormNode(PassChannelNode): +class BnNode(PassUnionChannelNode): """A NormNode corresponds to a BatchNorm2d module.""" def __init__(self, name: str, val: Union[nn.Module, str], - expand_ratio: int = 1, module_name='') -> None: - super().__init__(name, val, expand_ratio, module_name) + super().__init__(name, val, module_name) assert isinstance(self.val, nn.BatchNorm2d) @property - def in_channels(self) -> int: + def _in_channels(self) -> int: return self.val.num_features @property - def out_channels(self) -> int: + def _out_channels(self) -> int: return self.val.num_features def __repr__(self) -> str: return super().__repr__() + '_bn' -# converter +class GroupNormNode(PassUnionChannelNode): + + def __init__(self, + name: str, + val: Union[nn.Module, str], + module_name='') -> None: + super().__init__(name, val, module_name) + assert isinstance(self.val, nn.GroupNorm) + self.val: nn.GroupNorm + @property + def _in_channels(self) -> int: + return self.val.num_channels -def default_channel_node_converter(node: ModuleNode) -> ChannelNode: - """The default node converter for ChannelNode.""" + @property + def _out_channels(self) -> int: + return self.val.num_channels + + def channel_forward(self, channel_tensors: List[ChannelTensor]): + out_tensor = super().channel_forward(channel_tensors) + group_tensor = ChannelTensor(self.in_channels // self.val.num_groups) + group_union(out_tensor, self.val.num_groups, group_tensor) + return out_tensor - def warn(default='PassChannelNode'): - logger = MMLogger('mmrazor', 'mmrazor') - logger.warn((f"{node.name}({node.val}) node can't find match type of" - 'channel_nodes,' - f'replaced with {default} by default.')) + def __repr__(self) -> str: + return super().__repr__() + '_gn' + + +# converter - module_mapping = { +channel_nodes_mapping = { + 'module': { nn.Conv2d: ConvNode, - nn.BatchNorm2d: NormNode, + nn.modules.batchnorm._BatchNorm: BnNode, nn.Linear: LinearNode, - } - function_mapping = { + nn.modules.ReLU: PassChannelNode, + nn.modules.Hardtanh: PassChannelNode, + # pools + nn.modules.pooling._AvgPoolNd: PassChannelNode, + nn.modules.pooling._AdaptiveAvgPoolNd: PassChannelNode, + nn.modules.pooling._MaxPoolNd: PassChannelNode, + nn.modules.pooling._AdaptiveMaxPoolNd: PassChannelNode, + Scale: PassChannelNode, + nn.modules.GroupNorm: GroupNormNode, + }, + 'function': { torch.add: BindChannelNode, torch.cat: CatChannelNode, - operator.add: BindChannelNode - } - name_mapping = { + operator.add: BindChannelNode, + }, + 'str': { 'bind_placeholder': BindChannelNode, - 'pass_placeholder': PassChannelNode, + 'pass_placeholder': PassUnionChannelNode, 'cat_placeholder': CatChannelNode, - } + 
'input_placeholder': InputChannelNode, + 'output_placeholder': EndNode + }, +} + + +def default_channel_node_converter( + node: ModuleNode, + module_mapping=channel_nodes_mapping['module'], + function_mapping=channel_nodes_mapping['function'], + name_mapping=channel_nodes_mapping['str']) -> ChannelNode: + """The default node converter for ChannelNode.""" + + def warn(default='PassUnionChannelNode'): + logger = MMLogger.get_current_instance() + logger.info( + (f"{node.name}({node.module_name}) node can't find match type of" + 'channel_nodes,' + f'replaced with {default} by default.')) + if isinstance(node.val, nn.Module): # module_mapping for module_type in module_mapping: @@ -365,7 +497,6 @@ def warn(default='PassChannelNode'): for module_type in name_mapping: if node.val == module_type: return name_mapping[module_type].copy_from(node) - else: for fun_type in function_mapping: if node.val == fun_type: @@ -374,5 +505,17 @@ def warn(default='PassChannelNode'): warn('BindChannelNode') return BindChannelNode.copy_from(node) else: - warn('PassChannelNode') - return PassChannelNode.copy_from(node) + warn('PassUnionChannelNode') + return PassUnionChannelNode.copy_from(node) + + +# helper functions + + +def group_union(tensor: ChannelTensor, groups: int, group_tensor=None): + c_per_group = len(tensor) // groups + if group_tensor is None: + group_tensor = ChannelTensor(c_per_group) + assert groups * len(group_tensor) == len(tensor) + for i in range(groups): + tensor[i * c_per_group:(i + 1) * c_per_group].union(group_tensor) diff --git a/mmrazor/structures/graph/module_graph.py b/mmrazor/structures/graph/module_graph.py index bc7e90dac..b7ce64738 100644 --- a/mmrazor/structures/graph/module_graph.py +++ b/mmrazor/structures/graph/module_graph.py @@ -8,6 +8,7 @@ from typing import Dict, List, TypeVar, Union import torch.nn as nn +from mmengine import MMLogger from torch.nn import Module from mmrazor.models.task_modules.tracer.backward_tracer import BackwardTracer @@ -15,10 +16,32 @@ ImageClassifierPseudoLoss from mmrazor.models.task_modules.tracer.path import (Path, PathConcatNode, PathList, PathNode) +from mmrazor.models.task_modules.tracer.razor_tracer import (FxBaseNode, + RazorFxTracer) from mmrazor.registry import TASK_UTILS from .base_graph import BaseGraph, BaseNode + # ModuleNode && ModuleGraph +class NoOutputError(Exception): + + def __init__(self, node, *args: object) -> None: + super().__init__(f'{node}', *args) + self.node = node + + pass + + +class NoInputError(Exception): + + def __init__(self, node, *args: object) -> None: + super().__init__(f'{node}', *args) + self.node = node + + +def my_assert(condiion, exception): + if not condiion: + raise exception class ModuleNode(BaseNode): @@ -35,7 +58,6 @@ class ModuleNode(BaseNode): def __init__(self, name: str, val: Union[Module, str], - expand_ratio: int = 1, module_name='') -> None: """ Args: @@ -43,8 +65,6 @@ def __init__(self, val (Module | str): content of the node. It can be Module or string. If val is a string, the string can only be one of self.pre_defined_node_val_str - expand_ratio (int): expand_ratio is used in bind node, - where the out_channel is always a multiple of the in_channel. Note: Here, we give an example of expand_ratio. 
>>> class Pool(nn.Module): @@ -54,78 +74,23 @@ def forward(x): >>> assert node.out_channels == node.in_channels*4 """ - assert (isinstance(val, Module) - or val in self.__class__.pre_defined_node_val_str - ), f'{val} node is not allowed' - if expand_ratio != 1: - assert val == 'pass_placeholder', \ - 'expand != 1 is only valid when val=="pass"' + # assert (isinstance(val, Module) + # or val in self.__class__.pre_defined_node_val_str + # ), f'{val} node is not allowed' super().__init__(name, val) - self.expand_ratio = expand_ratio self.module_name = module_name - # channel - - @property - def in_channels(self) -> int: - """int: the in_channels of the node.""" - if isinstance(self.val, nn.Module): - MAPPING = { - nn.Conv2d: 'in_channels', - nn.modules.batchnorm._BatchNorm: 'num_features', - nn.modules.Linear: 'in_features', - } - for basetype in MAPPING: - if isinstance(self.val, basetype): - return getattr(self.val, MAPPING[basetype]) - raise NotImplementedError(f'unsupported module: {self.val}') - elif self.is_bind_node() or self.is_pass_node(): - if len(self.prev_nodes) > 0: - return self.prev_nodes[0].out_channels - else: - return 0 - elif self.is_cat_node(): - return sum([ - node.out_channels if node.out_channels is not None else 0 - for node in self.prev_nodes - ]) - else: - raise NotImplementedError( - f'unsupported node type: {self.basic_type}') + # other @property - def out_channels(self) -> int: - """int: the out_channels of the node.""" - if isinstance(self.val, nn.Module): - MAPPING = { - nn.Conv2d: 'out_channels', - nn.modules.batchnorm._BatchNorm: 'num_features', - nn.modules.Linear: 'out_features', - } - for basetype in MAPPING: - if isinstance(self.val, basetype): - return getattr(self.val, MAPPING[basetype]) - raise NotImplementedError(f'unsupported module: {self.val}') - elif self.is_bind_node(): - if len(self.prev_nodes) > 0: - return self.prev_nodes[0].out_channels - else: - return 0 - elif self.is_pass_node(): - return self.in_channels * self.expand_ratio - elif self.is_cat_node(): - return sum([ - node.out_channels if node.out_channels is not None else 0 - for node in self.prev_nodes - ]) - else: - raise NotImplementedError( - f'unsupported node type: {self.basic_type}') - - # other + def is_module(self): + return isinstance(self.val, nn.Module) def __repr__(self) -> str: - return f'{self.name}_({self.in_channels},{self.out_channels})' + repr = f'{self.name}' + if self.module_name != '': + repr += f'({self.module_name})' + return repr # node type @@ -150,7 +115,7 @@ def basic_type(self) -> str: elif isinstance(self.val, nn.Linear): return 'linear' else: - raise NotImplementedError(f'{self}') + raise NotImplementedError(f'{self.val}') else: if self.val in [ 'cat_placeholder', 'bind_placeholder', 'pass_placeholder' @@ -178,31 +143,23 @@ def is_mix_node(self): generete new output channels, such as conv and linear.""" return self.basic_type in ['conv2d', 'linear', 'gwconv2d'] - # check + def is_input(self): + return self.val == 'input_placeholder' - def check_channel(self): - """Check if the channels of the node is matchable with previous nodes - and next nodes.""" - if self.is_cat_node(): - pass - else: - for pre in self.prev_nodes: - assert pre.out_channels == self.in_channels, \ - f'{self} has channel error' - - def check_type(self): - """Check if the node has right number of previous nodes according to - their type.""" - if self.is_pass_node(): - assert len(self.prev_nodes) <= 1, '{name} pass node error' - elif self.is_cat_node(): - pass - elif self.is_bind_node(): - 
assert len(self.prev_nodes) > 1, '{name} bind node error' - elif self.is_mix_node(): - assert len(self.prev_nodes) <= 1, '{name} mix node error' + def is_output(self): + return self.val == 'output_placeholder' + + def check(self): + + if self.is_input(): + assert len(self.prev_nodes) == 0, f'{self}' + my_assert(len(self.next_nodes) > 0, NoOutputError(self)) + elif self.is_output(): + my_assert(len(self.prev_nodes) > 0, NoInputError(self)) + assert len(self.next_nodes) == 0, f'{self}' else: - raise NotImplementedError(f'{self}') + my_assert(len(self.prev_nodes) > 0, NoInputError(self)) + my_assert(len(self.next_nodes) > 0, NoOutputError(self)) MODULENODE = TypeVar('MODULENODE', bound=ModuleNode) @@ -235,7 +192,18 @@ def init_from_backward_tracer( def init_from_fx_tracer(model: Module, fx_tracer={'type': 'RazorFxTracer'}): """init module graph using torch fx tracer.""" - pass + if isinstance(fx_tracer, dict): + tracer: RazorFxTracer = TASK_UTILS.build(fx_tracer) + else: + tracer = fx_tracer + + base_graph = tracer.trace(model) + + converter = FxTracerToGraphConverter(base_graph, model) + + converter.graph._model = model + converter.graph.refresh_module_name() + return converter.graph @staticmethod def init_from_model(model: Module): @@ -243,34 +211,6 @@ def init_from_model(model: Module): the relation among modules.""" pass - # check - - def check(self): - """Check if the graph is valid.""" - for node in self: - node.check_channel() - node.check_type() - - # static method for models that can't use tracer - - @staticmethod - def connect_module(pre: Module, next: Module): - """This function is used to write hardcode in modules to generate Graph - object using init_from_model.""" - if hasattr(pre, '_next'): - _next = getattr(pre, '_next') - assert isinstance(_next, List) - else: - pre._next = set() - pre._next.add(next) - - if hasattr(next, '_pre'): - _pre = getattr(next, '_pre') - assert isinstance(_pre, List) - else: - next._pre = set() - next._pre.add(pre) - # others def refresh_module_name(self): module2name = {} @@ -281,6 +221,42 @@ def refresh_module_name(self): if isinstance(node.val, nn.Module): node.module_name = module2name[node.val] + def check(self, fix=False): + for node in copy.copy(list(self.topo_traverse())): + self._check(node, fix=fix) + + def _check(self, node, fix=False): + try: + node.check() + except Exception as e: + if not fix: + raise e + else: + try: + raise e + except NoOutputError as e: + MMLogger.get_current_instance().warn( + f'add a output after {node}, error: {e}') + self._add_output_after(node) + except NoInputError as e: + MMLogger.get_current_instance().warn( + f'add a input before {node}, error: {e}') + self._add_input_before(node) + + self._check(node, fix=True) + + def _add_input_before(self, node: MODULENODE): + input_node: MODULENODE = ModuleNode( + 'auto_input', 'input_placeholder') # type: ignore + input_node = self.add_or_find_node(input_node) + self.connect(input_node, node) + + def _add_output_after(self, node: MODULENODE): + output_node: MODULENODE = ModuleNode( + 'auto_output', 'output_placeholder') # type: ignore + output_node = self.add_or_find_node(output_node) + self.connect(node, output_node) + # Converter @@ -314,7 +290,7 @@ def _new_placeholder_node(self, type: str, expand_ratio=1): self.bind_placeholder_num += 1 else: pass - node = ModuleNode(f'{type}_{num}', type, expand_ratio=expand_ratio) + node = ModuleNode(f'{type}_{num}', type) self.graph.add_or_find_node(node) return node @@ -349,7 +325,8 @@ def _insert_pass_nodes(self): if 
len(node.prev_nodes) == 1: pre: ModuleNode = node.prev_nodes[0] if node.in_channels != pre.out_channels: - assert node.in_channels % pre.out_channels == 0 + assert node.in_channels % pre.out_channels == 0, \ + f'{node.name} channel error' pass_node = self._new_placeholder_node( 'pass_placeholder', node.in_channels // pre.out_channels) @@ -393,9 +370,8 @@ def _topo_rename(self): # other def _post_process(self): """Some post process after init a basic module graph.""" - self._remove_redundant_pass_nodes() + # self._remove_redundant_pass_nodes() self._insert_bind_nodes() - self._insert_pass_nodes() self._topo_rename() @@ -415,7 +391,8 @@ def __init__(self, path_list: PathList, model: Module) -> None: self.name2module = dict(model.named_modules()) self._parse(self.path_list) - self._post_process() + self._insert_bind_nodes() + self._topo_rename() def _parse(self, path_list: PathList): """Parse path list.""" @@ -494,3 +471,45 @@ def _connect_nexts(self, node, nexts: List[ModuleNode]): """Connext the node and the nodes in nexts.""" for next in nexts: self.graph.connect(node, next) + + +class FxTracerToGraphConverter(GraphConverter): + """Use fx tracer to parse model, and generate module-graph.""" + + def __init__(self, base_graph, model=None) -> None: + """ + Args: + model (Module): the model which will be parsed + is_extra_leaf_module (Callable): a function used to determine, + if a module is a leaf module except torch pre-defined modules + """ + super().__init__(model) + self.base_graph = base_graph + self._convert_graph() + + def _node_converter(self, node: FxBaseNode): + """Convert a fxnode to a module-node.""" + if node.is_function(): + val = node.function() + elif node.is_input(): + val = 'input_placeholder' + elif node.is_output(): + val = 'output_placeholder' + elif node.is_method(): + val = node.method() + elif node.is_get_attr(): + val = 'get_attr' + elif node.is_module(): + val = node.module() + else: + raise NotImplementedError(f'{node} is unsupported') + + new_node = ModuleNode(node.name, val) + return new_node + + def _convert_graph(self): + """Convert a torch-graph to a module-graph.""" + base_graph = self.base_graph + # copy_nodes and connect + module_graph = ModuleGraph.copy_from(base_graph, self._node_converter) + self.graph = module_graph diff --git a/mmrazor/utils/__init__.py b/mmrazor/utils/__init__.py index 8490e8eef..d1d69c10e 100644 --- a/mmrazor/utils/__init__.py +++ b/mmrazor/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .demo_inputs import demo_inputs from .index_dict import IndexDict from .misc import find_latest_checkpoint from .placeholder import get_placeholder @@ -11,5 +12,5 @@ 'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules', 'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet', 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder', - 'IndexDict' + 'IndexDict', 'demo_inputs' ] diff --git a/mmrazor/utils/demo_inputs/__init__.py b/mmrazor/utils/demo_inputs/__init__.py new file mode 100644 index 000000000..69c45edc0 --- /dev/null +++ b/mmrazor/utils/demo_inputs/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .demo_input import demo_inputs + +__all__ = ['demo_inputs'] diff --git a/mmrazor/utils/demo_inputs/demo_input.py b/mmrazor/utils/demo_inputs/demo_input.py new file mode 100644 index 000000000..63d55b6aa --- /dev/null +++ b/mmrazor/utils/demo_inputs/demo_input.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
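`demo_inputs` below builds whatever input a model needs by dispatching on the model's type, so the tracers can run any supported model through one call. A minimal usage sketch, assuming an environment where mmrazor is installed; a plain nn.Module falls through to the default branch and simply gets a random tensor:

import torch.nn as nn

from mmrazor.utils import demo_inputs

model = nn.Conv2d(3, 8, 3, padding=1)
x = demo_inputs(model, (1, 3, 224, 224))
# the default branch returns a random tensor of the requested shape
assert tuple(x.shape) == (1, 3, 224, 224)
# an mmcls ImageClassifier would instead receive a dict with 'inputs',
# 'data_samples' and mode='tensor', suitable for model(**data)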
+import torch +import torch.nn as nn +from mmengine.model import BaseModel + +from ..placeholder import get_placeholder +from .mmcls_demo_input import mmcls_demo_input + +try: + from mmdet.models import BaseDetector +except Exception: + BaseDetector = get_placeholder('mmdet') + +try: + from mmcls.models import ImageClassifier +except Exception: + ImageClassifier = get_placeholder('mmcls') + +try: + from mmseg.models import BaseSegmentor +except Exception: + BaseSegmentor = get_placeholder('mmseg') + + +def default_mm_concrete_args(model, input_shape): + x = torch.rand(input_shape) + return {'inputs': x, 'mode': 'tensor'} + + +def default_concrete_args(model, input_shape): + x = torch.rand(input_shape) + return x + + +def seg_concrete_args(model, input_shape): + assert isinstance(model, BaseSegmentor) + from .mmseg_demo_input import demo_mmseg_inputs + data = demo_mmseg_inputs(model, input_shape) + data['mode'] = 'tensor' + return data + + +def det_concrete_args(model, input_shape): + assert isinstance(model, BaseDetector) + from mmdet.testing._utils import demo_mm_inputs + data = demo_mm_inputs(1, [input_shape[1:]]) + data = model.data_preprocessor(data, False) + data['mode'] = 'tensor' + return data + + +default_concrete_args_fun = { + BaseDetector: det_concrete_args, + ImageClassifier: mmcls_demo_input, + BaseSegmentor: seg_concrete_args, + BaseModel: default_mm_concrete_args, + nn.Module: default_concrete_args +} + + +def demo_inputs(model, input_shape): + for module_type, concrete_args_fun in default_concrete_args_fun.items( # noqa + ): # noqa + if isinstance(model, module_type): + return concrete_args_fun(model, input_shape) + # default + x = torch.rand(input_shape) + return x diff --git a/mmrazor/utils/demo_inputs/mmcls_demo_input.py b/mmrazor/utils/demo_inputs/mmcls_demo_input.py new file mode 100644 index 000000000..a828e99c0 --- /dev/null +++ b/mmrazor/utils/demo_inputs/mmcls_demo_input.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..placeholder import get_placeholder + +try: + from mmcls.models import ImageClassifier + from mmcls.structures import ClsDataSample +except ImportError: + ImageClassifier = get_placeholder('mmcls') + ClsDataSample = get_placeholder('mmcls') + + +def mmcls_demo_input(model: ImageClassifier, input_shape: tuple): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + num_classes (int): + number of semantic classes + """ + x = torch.rand(input_shape) + mm_inputs = { + 'inputs': + x, + 'data_samples': + [ClsDataSample().set_gt_label(1) for _ in range(input_shape[0])], + } + mm_inputs['mode'] = 'tensor' + return mm_inputs diff --git a/mmrazor/utils/demo_inputs/mmseg_demo_input.py b/mmrazor/utils/demo_inputs/mmseg_demo_input.py new file mode 100644 index 000000000..e7a05952a --- /dev/null +++ b/mmrazor/utils/demo_inputs/mmseg_demo_input.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
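Note that `default_concrete_args_fun` above is ordered most-specific-first: `demo_inputs` tests `isinstance` top-down, and every detector, classifier and segmentor is also a `BaseModel` (and an `nn.Module`), so the generic fallbacks must come last. A self-contained sketch of that pattern, with hypothetical classes standing in for the mm types:

class Generic:
    pass


class Special(Generic):
    pass


# most-specific first; Python dicts preserve insertion order
dispatch = {
    Special: lambda m: 'special inputs',
    Generic: lambda m: 'generic inputs',
}


def build_inputs(model):
    # the first matching isinstance wins
    for cls, fn in dispatch.items():
        if isinstance(model, cls):
            return fn(model)
    raise TypeError(type(model))


assert build_inputs(Special()) == 'special inputs'
assert build_inputs(Generic()) == 'generic inputs'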
+import torch +from mmengine.structures import PixelData +from torch import nn + +from ..placeholder import get_placeholder + +try: + from mmseg.models import SegDataPreProcessor + from mmseg.structures import SegDataSample +except ImportError: + SegDataPreProcessor = get_placeholder('mmseg') + SegDataSample = get_placeholder('mmseg') + + +def demo_mmseg_inputs(segmentor, input_shape, for_training=False): + + if isinstance(segmentor.decode_head, nn.ModuleList): + num_classes = segmentor.decode_head[-1].num_classes + else: + num_classes = segmentor.decode_head.num_classes + # batch_size=2 for BatchNorm + mm_inputs = _demo_mmseg_inputs( + num_classes=num_classes, input_shape=input_shape) + + # convert to cuda Tensor if applicabled + # if torch.cuda.is_available(): + # segmentor = segmentor.cuda() + + # check data preprocessor + if not hasattr(segmentor, + 'data_preprocessor') or segmentor.data_preprocessor is None: + segmentor.data_preprocessor = SegDataPreProcessor() + + mm_inputs = segmentor.data_preprocessor(mm_inputs, for_training) + + return mm_inputs + + +def _demo_mmseg_inputs(input_shape=(1, 3, 8, 16), num_classes=10): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + + num_classes (int): + number of semantic classes + """ + (N, C, H, W) = input_shape + + imgs = torch.randn(*input_shape) + segs = torch.randint( + low=0, high=num_classes - 1, size=(N, H, W), dtype=torch.long) + + img_metas = [{ + 'img_shape': (H, W), + 'ori_shape': (H, W), + 'pad_shape': (H, W, C), + 'filename': '.png', + 'scale_factor': 1.0, + 'flip': False, + 'flip_direction': 'horizontal' + } for _ in range(N)] + + data_samples = [ + SegDataSample( + gt_sem_seg=PixelData(data=segs[i]), metainfo=img_metas[i]) + for i in range(N) + ] + + mm_inputs = { + 'inputs': torch.FloatTensor(imgs), + 'data_samples': data_samples + } + + return mm_inputs diff --git a/tests/__init__.py b/tests/__init__.py index ef101fec6..7e93bf1a9 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. 
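One detail of `_demo_mmseg_inputs` above worth calling out: `torch.randint` treats `high` as exclusive, so `high=num_classes - 1` draws labels from `[0, num_classes - 2]`. The label construction itself needs nothing from mmseg:

import torch

N, C, H, W = 1, 3, 8, 16
num_classes = 10
imgs = torch.randn(N, C, H, W)
segs = torch.randint(
    low=0, high=num_classes - 1, size=(N, H, W), dtype=torch.long)
assert int(segs.max()) <= num_classes - 2  # `high` is exclusive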
+from .test_core.test_graph.test_graph import TestGraph  # isort:skip
+from .test_core.test_graph.test_channel_graph import TestChannelGraph
+from .test_core.test_tracer.test_backward_tracer import TestBackwardTracer
+from .test_data import TestModelLibrary
+from .test_models.test_algorithms.test_autoslim import TestAutoSlim
+from .test_models.test_algorithms.test_prune_algorithm import \
+    TestItePruneAlgorithm
+from .test_models.test_algorithms.test_slimmable_network import (
+    TestSlimmable, TestSlimmableDDP)
+from .test_models.test_mutables.test_mutable_channel.test_units.test_mutable_channel_units import \
+    TestMutableChannelUnit  # noqa: E501
+from .test_models.test_mutators.test_channel_mutator import TestChannelMutator
+
+__all__ = [
+    'TestGraph', 'TestMutableChannelUnit', 'TestChannelMutator',
+    'TestBackwardTracer', 'TestItePruneAlgorithm', 'TestAutoSlim',
+    'TestSlimmable', 'TestSlimmableDDP', 'TestChannelGraph', 'TestModelLibrary'
+]
diff --git a/tests/data/MBV2_slimmable_config.json b/tests/data/MBV2_slimmable_config.json
index f63029872..9010b83e2 100644
--- a/tests/data/MBV2_slimmable_config.json
+++ b/tests/data/MBV2_slimmable_config.json
@@ -373,20 +373,5 @@
             "choice_mode": "number"
         },
         "choice": 1920
-    },
-    "head.fc_(0, 1000)_1000": {
-        "init_args": {
-            "num_channels": 1000,
-            "divisor": 1,
-            "min_value": 1,
-            "min_ratio": 0.9,
-            "candidate_choices": [
-                1000,
-                1000,
-                1000
-            ],
-            "choice_mode": "number"
-        },
-        "choice": 1000
     }
 }
\ No newline at end of file
diff --git a/tests/data/model_library.py b/tests/data/model_library.py
new file mode 100644
index 000000000..6119f0c0f
--- /dev/null
+++ b/tests/data/model_library.py
@@ -0,0 +1,596 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import json
+import os
+from typing import Callable, Dict, List
+
+import torch
+import torch.nn as nn
+from mmengine.config import Config
+from mmengine.utils import get_installed_path
+
+from mmrazor.registry import MODELS
+from .models import (AddCatModel, ConcatModel, ConvAttnModel, DwConvModel,
+                     ExpandLineModel, GroupWiseConvModel, SingleLineModel,
+                     MultiBindModel, MultiConcatModel, MultiConcatModel2,
+                     ResBlock, Xmodel, MultipleUseModel, Icep, SelfAttention)
+# model generator
+from mmdet.testing._utils import demo_mm_inputs
+
+# helper functions
+
+
+def get_shape(tensor, only_length=False):
+    if isinstance(tensor, torch.Tensor):
+        if only_length:
+            return len(tensor.shape)
+        else:
+            return tensor.shape
+    elif isinstance(tensor, (list, tuple)):
+        shapes = []
+        for x in tensor:
+            shapes.append(get_shape(x, only_length))
+        return shapes
+    elif isinstance(tensor, dict):
+        shapes = {}
+        for key in tensor:
+            shapes[key] = get_shape(tensor[key], only_length)
+        return shapes
+    else:
+        raise NotImplementedError(
+            f'unsupported type {type(tensor)} to get shape of tensors.')
+
+
+# generators
+
+
+class ModelGenerator(nn.Module):
+
+    def __init__(self, name: str, model_src) -> None:
+        super().__init__()
+        self.name = name
+        self.model_src = model_src
+        self._model = None
+
+    def __call__(self, *args, **kwargs):
+        if len(args) == 0 and len(kwargs) == 0:
+            self.init_model()
+            return self
+        else:
+            return self.forward(*args, **kwargs)
+
+    def init_model(self):
+        self._model = self.model_src()
+
+    def forward(self, x):
+        assert self._model is not None
+        return self._model(x, *self.input())
+
+    def input(self):
+        return []
+
+    def assert_model_is_changed(self, tensors_org, tensors_new):
+        shape1 =
get_shape(tensors_org) + shape2 = get_shape(tensors_new) + assert shape1 == shape2, f'{shape1}!={shape2}' + + def __repr__(self) -> str: + return self.name + + @classmethod + def get_base_name(cls, name: str): + names = name.split('.') + return '.'.join(names[1:]) + + @classmethod + def get_short_name(cls, name: str): + base_name = cls.get_base_name(name) + names = base_name.replace('-', '.').replace('_', '.').split('.') + return names[0] + + @property + def short_name(self): + return self.__class__.get_short_name(self.name) + + +class MMModelGenerator(ModelGenerator): + + def __init__(self, name, cfg) -> None: + self.cfg = cfg + super().__init__(name, self.get_model_src) + + def get_model_src(self): + model = MODELS.build(self.cfg) + model = revert_sync_batchnorm(model) + return model + + def __repr__(self) -> str: + return self.name + + +class MMDetModelGenerator(MMModelGenerator): + + def forward(self, x): + assert self._model is not None + self._model.eval() + return self._model(x, **self.input(), mode='tensor') + + def input(self): + data = demo_mm_inputs(1, [[3, 224, 224]]) + data = self._model.data_preprocessor(data, False) + data.pop('inputs') + return data + + def assert_model_is_changed(self, tensors_org, tensors_new): + assert get_shape(tensors_org, True) == get_shape(tensors_new, True) + + +# model library + + +class ModelLibrary: + default_includes: List = [] + _models = None + + def __init__(self, include=default_includes, exclude=[]) -> None: + self.include_key = include + self.exclude_key = exclude + self._include_models, self._uninclude_models, self.exclude_models =\ + self._classify_models(self.models) + + @property + def models(self): + if self.__class__._models is None: + self.__class__._models: Dict[ + str, Callable] = self.__class__.get_models() + return self.__class__._models + + @classmethod + def get_models(cls): + raise NotImplementedError() + + def include_models(self): + return self._include_models + + def uninclude_models(self): + return self._uninclude_models + + def is_include(self, name: str, includes: List[str], start_with=True): + for key in includes: + if start_with: + if name.startswith(key): + return True + else: + if key in name: + return True + return False + + def is_default_includes_cover_all_models(self): + models = copy.copy(self._models) + is_covered = True + for name in models: + if self.is_include(name, self.__class__.default_includes): + pass + else: + is_covered = False + print(name, '\tnot include') + return is_covered + + def short_names(self): + + short_names = set() + for name in self.models: + short_names.add(self.models[name].short_name) + return short_names + + def _classify_models(self, models: Dict): + include = [] + uninclude = [] + exclude = [] + for name in models: + if self.is_include(name, self.exclude_key, start_with=False): + exclude.append(models[name]) + elif self.is_include(name, self.include_key, start_with=True): + include.append(models[name]) + else: + uninclude.append(models[name]) + return include, uninclude, exclude + + def get_short_name_of_model(self, name: str): + names = name.replace('-', '.').replace('_', '.').split('.') + return names[0] + + +class DefaultModelLibrary(ModelLibrary): + + default_includes: List = [ + 'SingleLineModel', + 'ResBlock', + 'AddCatModel', + 'ConcatModel', + 'MultiConcatModel', + 'MultiConcatModel2', + 'GroupWiseConvModel', + 'Xmodel', + 'MultipleUseModel', + 'Icep', + 'ExpandLineModel', + 'MultiBindModel', + 'DwConvModel', + 'ConvAttnModel', + 'SelfAttention', + ] + + def __init__(self, 
include=default_includes, exclude=[]) -> None:
+        super().__init__(include, exclude)
+
+    @classmethod
+    def get_models(cls):
+        models = [
+            SingleLineModel,
+            ResBlock,
+            AddCatModel,
+            ConcatModel,
+            MultiConcatModel,
+            MultiConcatModel2,
+            GroupWiseConvModel,
+            Xmodel,
+            MultipleUseModel,
+            Icep,
+            ExpandLineModel,
+            MultiBindModel,
+            DwConvModel,
+            ConvAttnModel,
+            SelfAttention,
+        ]
+        model_dict = {}
+        for model in models:
+            model_dict[model.__name__] = ModelGenerator(
+                'default.' + model.__name__, model)
+        return model_dict
+
+
+class TorchModelLibrary(ModelLibrary):
+
+    default_includes = [
+        'alexnet', 'densenet', 'efficientnet', 'googlenet', 'inception',
+        'mnasnet', 'mobilenet', 'regnet', 'resnet', 'resnext', 'shufflenet',
+        'squeezenet', 'vgg', 'wide_resnet', 'vit', 'swin', 'convnext'
+    ]
+
+    def __init__(self, include=default_includes, exclude=[]) -> None:
+        super().__init__(include, exclude)
+
+    @classmethod
+    def get_models(cls):
+        from inspect import isfunction
+
+        import torchvision
+
+        attrs = dir(torchvision.models)
+        models = {}
+        for name in attrs:
+            module = getattr(torchvision.models, name)
+            # skip helpers such as torchvision.models.get_weight
+            if isfunction(module) and name != 'get_weight':
+                models[name] = ModelGenerator('torch.' + name, module)
+        return models
+
+
+class MMModelLibrary(ModelLibrary):
+    default_includes = []
+    base_config_path = '/'
+    repo = 'mmxx'
+
+    def __init__(self, include=default_includes, exclude=[]) -> None:
+        super().__init__(include, exclude)
+
+    @classmethod
+    def config_path(cls):
+        repo_path = get_installed_path(cls.repo)
+        path = repo_path + '/.mim/configs/' + cls.base_config_path
+        return path
+
+    @classmethod
+    def get_models(cls):
+        models = {}
+        added_models = set()
+        for dirpath, dirnames, filenames in os.walk(cls.config_path()):
+            for filename in filenames:
+                if filename.endswith('.py'):
+
+                    cfg_path = dirpath + '/' + filename
+                    try:
+                        config = Config.fromfile(cfg_path)
+                    except Exception:
+                        # some configs in a repo are not loadable; skip them
+                        continue
+                    if 'model' in config:
+
+                        # get model_name
+                        model_type_name = '_'.join(
+                            dirpath.replace(cls.config_path(), '').split('/'))
+                        model_type_name = model_type_name \
+                            if model_type_name == '' else model_type_name + '_'
+                        model_name = model_type_name + \
+                            os.path.basename(filename).split('.')[0]
+
+                        model_cfg = config['model']
+                        model_cfg = cls._config_process(model_cfg)
+                        if json.dumps(model_cfg) not in added_models:
+                            models[model_name] = cls.generator_type()(
+                                cls.repo + '.'
+ model_name, model_cfg) + added_models.add(json.dumps(model_cfg)) + return models + + @classmethod + def generator_type(cls): + return MMModelGenerator + + @classmethod + def _config_process(cls, config: Dict): + config['_scope_'] = cls.repo + config = cls._remove_certain_key(config, 'init_cfg') + config = cls._remove_certain_key(config, 'pretrained') + config = cls._remove_certain_key(config, 'Pretrained') + return config + + @classmethod + def _remove_certain_key(cls, config: Dict, key: str = 'init_cfg'): + if isinstance(config, dict): + if key in config: + config.pop(key) + for keyx in config: + config[keyx] = cls._remove_certain_key(config[keyx], key) + return config + + +class MMClsModelLibrary(MMModelLibrary): + + default_includes = [ + 'vgg', + 'efficientnet', + 'resnet', + 'mobilenet', + 'resnext', + 'wide-resnet', + 'shufflenet', + 'hrnet', + 'resnest', + 'inception', + 'res2net', + 'densenet', + 'convnext', + 'regnet', + 'van', + 'swin_transformer', + 'convmixer', + 't2t', + 'twins', + 'repmlp', + 'tnt', + 't2t', + 'mlp_mixer', + 'conformer', + 'poolformer', + 'vit', + 'efficientformer', + 'mobileone', + 'edgenext', + 'mvit', + 'seresnet', + 'repvgg', + 'seresnext', + 'deit' + ] + base_config_path = '_base_/models/' + repo = 'mmcls' + + def __init__(self, + include=default_includes, + exclude=['cutmix', 'cifar', 'gem']) -> None: + super().__init__(include=include, exclude=exclude) + + +class MMDetModelLibrary(MMModelLibrary): + + default_includes = [ + '_base', + 'gfl', + 'sparse', + 'simple', + 'pisa', + 'lvis', + 'carafe', + 'selfsup', + 'solo', + 'ssd', + 'res2net', + 'yolof', + 'reppoints', + 'htc', + 'groie', + 'dyhead', + 'grid', + 'soft', + 'swin', + 'regnet', + 'gcnet', + 'ddod', + 'instaboost', + 'point', + 'vfnet', + 'pafpn', + 'ghm', + 'mask', + 'resnest', + 'tood', + 'detectors', + 'cornernet', + 'convnext', + 'cascade', + 'paa', + 'detr', + 'rpn', + 'ld', + 'lad', + 'ms', + 'faster', + 'centripetalnet', + 'gn', + 'dcnv2', + 'legacy', + 'panoptic', + 'strong', + 'fpg', + 'deformable', + 'free', + 'scratch', + 'openimages', + 'fsaf', + 'rtmdet', + 'solov2', + 'yolact', + 'empirical', + 'centernet', + 'hrnet', + 'guided', + 'deepfashion', + 'fast', + 'mask2former', + 'retinanet', + 'autoassign', + 'gn+ws', + 'dcn', + 'yolo', + 'foveabox', + 'libra', + 'double', + 'queryinst', + 'resnet', + 'nas', + 'sabl', + 'fcos', + 'scnet', + 'maskformer', + 'pascal', + 'cityscapes', + 'timm', + 'seesaw', + 'pvt', + 'atss', + 'efficientnet', + 'wider', + 'tridentnet', + 'dynamic', + 'yolox', + 'albu', + ] + base_config_path = '/' + repo = 'mmdet' + + def __init__( + self, + include=default_includes, + exclude=[ + 'lad', + 'ld', + 'faster_rcnn_faster-rcnn_r50-caffe-c4_ms-1x_coco', + ] + ) -> None: + super().__init__(include=include, exclude=exclude) + + @classmethod + def _config_process(cls, config: Dict): + config = super()._config_process(config) + if 'preprocess_cfg' in config: + config.pop('preprocess_cfg') + return config + + @classmethod + def generator_type(cls): + return MMDetModelGenerator + + +class MMSegModelLibrary(MMModelLibrary): + default_includes: List = [ + '_base_', + 'knet', + 'sem', + 'dnlnet', + 'dmnet', + 'icnet', + 'apcnet', + 'swin', + 'isanet', + 'fastfcn', + 'poolformer', + 'mae', + 'segformer', + 'ccnet', + 'twins', + 'emanet', + 'upernet', + 'beit', + 'hrnet', + 'bisenetv2', + 'vit', + 'setr', + 'cgnet', + 'ocrnet', + 'ann', + 'erfnet', + 'point', + 'bisenetv1', + 'nonlocal', + 'unet', + 'danet', + 'stdc', + 'fcn', + 'encnet', + 'resnest', + 'mobilenet', 
+        'convnext',
+        'deeplabv3',
+        'pspnet',
+        'gcnet',
+        'fastscnn',
+        'segmenter',
+        'dpt',
+        'deeplabv3plus',
+        'psanet',
+    ]
+    base_config_path = '/'
+    repo = 'mmsegmentation'
+
+    def __init__(self, include=default_includes, exclude=['_base_']) -> None:
+        super().__init__(include, exclude)
+
+    @classmethod
+    def _config_process(cls, config: Dict):
+        config['_scope_'] = 'mmseg'
+        return config
+
+
+# tools
+
+
+def revert_sync_batchnorm(module):
+    # this is very similar to the function that it is trying to revert:
+    # https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
+        module_output = nn.BatchNorm2d(module.num_features, module.eps,
+                                       module.momentum, module.affine,
+                                       module.track_running_stats)
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, 'qconfig'):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, revert_sync_batchnorm(child))
+    del module
+    return module_output
diff --git a/tests/data/models.py b/tests/data/models.py
index 60c8a7058..39b6febda 100644
--- a/tests/data/models.py
+++ b/tests/data/models.py
@@ -1,19 +1,67 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+# this file includes models for testing.
+import math
 from torch.nn import Module
 from torch import Tensor
 import torch.nn as nn
+import torch.nn.functional as F
 import torch
 from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin
 from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
 from mmrazor.models.mutables import MutableChannelUnit
 from mmrazor.models.mutables import DerivedMutable
 from mmrazor.models.mutables import BaseMutable
-from mmrazor.models.mutables import OneShotMutableChannelUnit, SquentialMutableChannel, OneShotMutableChannel
-from mmrazor.registry import MODELS
-from mmengine.model import BaseModel
+from mmrazor.models.mutables import OneShotMutableChannelUnit, OneShotMutableChannel
 
 # this file includes models for testing.
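Among the new test models that follow, ConvAttnModel exercises a squeeze-excite-style gate: the pooled, sigmoid-activated branch multiplies the conv output elementwise, which couples the two tensors channel-for-channel and forces a pruning tracer to put them in the same channel unit. A standalone sketch of the pattern (shapes are illustrative; torch.sigmoid stands in for the deprecated F.sigmoid):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, 1, 1)
pool = nn.AdaptiveAvgPool2d(1)

x = torch.rand(1, 3, 16, 16)
x1 = conv(x)                    # (1, 8, 16, 16)
attn = torch.sigmoid(pool(x1))  # (1, 8, 1, 1): one gate per channel
x_attn = x1 * attn              # broadcasts over H and W
assert x_attn.shape == x1.shape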
+class subnet(Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + add = torch.arange(x.shape[-1]).unsqueeze(0) + # add = torch.arange(1000).unsqueeze(0) + end = torch.add(x, add) + return end + + +class UnTracableModel(Module): + + def __init__(self) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(), + nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16), + nn.AdaptiveAvgPool2d(1)) + self.linear = nn.Linear(16, 1000) + self.end = subnet() + + def forward(self, x): + x1 = self.net(x) + x1 = x1.reshape([x1.shape[0], -1]) + logit = self.linear(x1) + return self.end(logit) + + +class ConvAttnModel(Module): + + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(3, 8, 3, 1, 1) + self.pool = nn.AdaptiveAvgPool2d(1) + self.conv2 = nn.Conv2d(8, 16, 3, 1, 1) + self.head = LinearHead(16, 1000) + + def forward(self, x): + x1 = self.conv(x) + attn = F.sigmoid(self.pool(x1)) + x_attn = x1 * attn + x_last = self.conv2(x_attn) + return self.head(x_last) + + class LinearHead(Module): def __init__(self, in_channel, num_class=1000) -> None: @@ -185,7 +233,7 @@ def forward(self, x: Tensor) -> Tensor: return output -class LineModel(BaseModel): +class SingleLineModel(nn.Module): """ x |net0,net1 @@ -471,6 +519,48 @@ def forward(self, x): return self.head(self.net(x)) +class SelfAttention(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.stem = nn.Conv2d(3, 32, 4, 4, 4) + + self.num_head = 4 + self.qkv = nn.Linear(32, 32 * 3) + self.proj = nn.Linear(32, 32) + + self.head = LinearHead(32, 1000) + + def forward(self, x: torch.Tensor): + x = self.stem(x) + h, w = x.shape[-2:] + x = self._to_token(x) + x = x + self._forward_attention(x) + x = self._to_img(x, h, w) + return self.head(x) + + def _to_img(self, x, h, w): + x = x.reshape([x.shape[0], h, w, x.shape[2]]) + x = x.permute(0, 3, 1, 2) + return x + + def _to_token(self, x): + x = x.flatten(2).transpose(-1, -2) + return x + + def _forward_attention(self, x: torch.Tensor): + qkv = self.qkv(x) + qkv = qkv.reshape([ + x.shape[0], x.shape[1], 3, self.num_head, + x.shape[2] // self.num_head + ]).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv + attn = q @ k.transpose(-1, -2) / math.sqrt(32 // self.num_head) + y = attn @ v # B H N h + y = y.permute(0, 2, 1, 3).flatten(-2) + return self.proj(y) + + # models with dynamicop @@ -566,95 +656,3 @@ def _register_mutable(self): self.net[4], mutable2, True) MutableChannelContainer.register_mutable_channel_to_module( self.linear, mutable2, False) - - -default_models = [ - LineModel, - ResBlock, - AddCatModel, - ConcatModel, - MultiConcatModel, - MultiConcatModel2, - GroupWiseConvModel, - Xmodel, - MultipleUseModel, - Icep, - ExpandLineModel, - DwConvModel, -] - - -class ModelLibrary: - - # includes = [ - # 'alexnet', # pass - # 'densenet', # pass - # # 'efficientnet', # pass - # # 'googlenet', # pass. 
- # # googlenet return a tuple when training, - # # so it should trace in eval mode - # # 'inception', # failed - # # 'mnasnet', # pass - # # 'mobilenet', # pass - # # 'regnet', # failed - # # 'resnet', # pass - # # 'resnext', # failed - # # 'shufflenet', # failed - # # 'squeezenet', # pass - # # 'vgg', # pass - # # 'wide_resnet', # pass - # ] - - def __init__(self, include=[]) -> None: - - self.include_key = include - - self.model_creator = self.get_torch_models() - - def __repr__(self) -> str: - s = f'model: {len(self.model_creator)}\n' - for creator in self.model_creator: - s += creator.__name__ + '\n' - return s - - def get_torch_models(self): - from inspect import isfunction - - import torchvision - - attrs = dir(torchvision.models) - models = [] - for name in attrs: - module = getattr(torchvision.models, name) - if isfunction(module): - models.append(module) - return models - - def export_models(self): - models = [] - for creator in self.model_creator: - if self.is_include(creator.__name__): - models.append(creator) - return models - - def is_include(self, name): - for key in self.include_key: - if key in name: - return True - return False - - def include(self): - include = [] - for creator in self.model_creator: - for key in self.include_key: - if key in creator.__name__: - include.append(creator) - return include - - def uninclude(self): - include = self.include() - uninclude = [] - for creator in self.model_creator: - if creator not in include: - uninclude.append(creator) - return uninclude diff --git a/tests/data/tracer_passed_models.py b/tests/data/tracer_passed_models.py new file mode 100644 index 000000000..c07584d19 --- /dev/null +++ b/tests/data/tracer_passed_models.py @@ -0,0 +1,475 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
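+# This file records, per tracer, the model libraries and concrete models
+# that are currently known to pass tracing; the entries commented out with
+# `# bug` / `# error` below are known failures and are skipped.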
+from .model_library import (MMClsModelLibrary, MMDetModelLibrary, ModelLibrary,
+                            DefaultModelLibrary, TorchModelLibrary,
+                            MMSegModelLibrary)
+
+
+class PassedModelManager:
+
+    def __init__(self) -> None:
+        pass
+
+    def include_models(self, full_test=False):
+        models = []
+        for library in self.libraries(full_test):
+            models.extend(library.include_models())
+        return models
+
+    def uninclude_models(self, full_test=False):
+        models = []
+        for library in self.libraries(full_test):
+            models.extend(library.uninclude_models())
+        return models
+
+    def libraries(self, full=False):
+        return []
+
+
+class FxPassedModelManager(PassedModelManager):
+
+    _default_library = None
+    _torch_library = None
+    _mmcls_library = None
+    _mmseg_library = None
+    _mmdet_library = None
+
+    def libraries(self, full=False):
+        if full:
+            return [
+                self.__class__.default_library(),
+                self.__class__.torch_library(),
+                self.__class__.mmcls_library(),
+                self.__class__.mmseg_library(),
+                self.__class__.mmdet_library(),
+            ]
+        else:
+            return [self.__class__.default_library()]
+
+    @classmethod
+    def default_library(cls):
+        if cls._default_library is None:
+            cls._default_library = DefaultModelLibrary(include=[
+                'SingleLineModel',
+                'ResBlock',
+                'AddCatModel',
+                'ConcatModel',
+                'MultiConcatModel',
+                'MultiConcatModel2',
+                'GroupWiseConvModel',
+                'Xmodel',
+                'MultipleUseModel',
+                'Icep',
+                'ExpandLineModel',
+                'MultiBindModel',
+                'DwConvModel',
+                'ConvAttnModel',
+            ])
+
+        return cls._default_library
+
+    @classmethod
+    def torch_library(cls):
+        """
+        googlenet: returns a tuple when training, so it should be
+        traced in eval mode
+        """
+        torch_includes = [
+            'alexnet',
+            'densenet',
+            'efficientnet',
+            'googlenet',
+            'inception',
+            'mnasnet',
+            'mobilenet',
+            'regnet',
+            'resnet',
+            'resnext',
+            'squeezenet',
+            'vgg',
+            'wide_resnet',
+            "vit",
+            "swin",
+            "convnext",
+            # error
+            # 'shufflenet', # bug
+        ]
+        if cls._torch_library is None:
+            cls._torch_library = TorchModelLibrary(include=torch_includes)
+        return cls._torch_library
+
+    @classmethod
+    def mmcls_library(cls):
+        """
+        shufflenet: consists of chunk operations.
+        resnest: resnest has two problems. First, it uses `*x.shape`, which
+        is not traceable by the fx tracer. Second, it uses channel folding.
+        res2net: res2net consists of split operations.
+        convnext: consists of layernorm.
+ """ + mmcls_include = [ + 'vgg', + 'efficientnet', + 'resnet', + 'mobilenet', + 'resnext', + 'wide-resnet', + 'hrnet', + 'inception', + 'densenet', + 'regnet', + 'convmixer', + 'efficientformer', + 'mobileone', + 'edgenext', + 'seresnet', + 'repvgg', + 'seresnext', + 'conformer', + 'poolformer', + 'res2net', + 'resnest', + 'convnext', + # errors + # 'mvit', # error + # 'van', # bug + # 'twins', # bug + # 'tnt', # bug + # 'repmlp', # bug + # 't2t', # bug + # 'swin', # bug + # 'shufflenet', # bug + # 'vit', # bug + # 'mlp', # bug + ] + if cls._mmcls_library is None: + cls._mmcls_library = MMClsModelLibrary(include=mmcls_include) + return cls._mmcls_library + + @classmethod + def mmdet_library(cls): + mmdet_include = [ + 'retinanet', + 'faster_rcnn', + 'mask_rcnn', + 'fcos', + 'yolo', + 'gfl', + 'simple', + 'lvis', + 'selfsup', + 'solo', + 'soft', + 'instaboost', + 'point', + 'pafpn', + 'ghm', + 'paa', + 'rpn', + 'faster', + 'centripetalnet', + 'gn', + 'free', + 'scratch', + 'centernet', + 'deepfashion', + 'autoassign', + 'gn+ws', + 'foveabox', + 'resnet', + 'cityscapes', + 'timm', + 'atss', + 'dynamic', + 'panoptic', + 'solov2', + 'fsaf', + 'double', + 'cornernet', + # 'vfnet', # error + # 'carafe', # error + # 'sparse', # error + # '_base', # error + # 'ssd', # error + # 'res2net', # error + # 'reppoints', # error + # 'groie', # error + # 'dyhead', # error + # 'ms', # error + # 'detr', # error + # 'swin', # error + # 'regnet', # error + # 'gcnet', # error + # 'ddod', # error + # 'resnest', # error + # 'tood', # error + # 'cascade', # error + # 'dcnv2', # error + # 'strong', # error + # 'fpg', # error + # 'deformable', # error + # 'mask2former', # error + # 'hrnet', # error + # 'guided', # error + # 'nas', # error + # 'yolact', # error + # 'empirical', # error + # 'dcn', # error + # 'fast', # error + # 'queryinst', # error + # 'pascal', # error + # 'efficientnet', # error + # 'tridentnet', # error + # 'rtmdet', # error + # 'seesaw', # error + # 'pvt',# error + # 'detectors',# error + # 'htc',# error + # 'wider',# error + # 'maskformer',# error + # 'grid',# error + # 'openimages',# error + # 'legacy',# error + # 'pisa',# error + # 'libra',# error + # 'convnext',# error + # 'scnet',# error + # 'sabl',# error + ] + if cls._mmdet_library is None: + cls._mmdet_library = MMDetModelLibrary(mmdet_include) + return cls._mmdet_library + + @classmethod + def mmseg_library(cls): + # a common error: unet related models + include = [ + 'deeplabv3plus', + # '_base_', + # 'knet', + # 'sem', + # 'dnlnet', + # 'dmnet', + # 'icnet', + # 'apcnet', + # 'swin', + # 'isanet', + # 'fastfcn', + # 'poolformer', + # 'mae', + # 'segformer', + # 'ccnet', + # 'twins', + # 'emanet', + # 'upernet', + # 'beit', + # 'hrnet', + # 'bisenetv2', + # 'vit', + # 'setr', + # 'cgnet', + # 'ocrnet', + # 'ann', + # 'erfnet', + # 'point', + # 'bisenetv1', + # 'nonlocal', + # 'unet', + # 'danet', + # 'stdc', + # 'fcn', + # 'encnet', + # 'resnest', + # 'mobilenet', + # 'convnext', + # 'deeplabv3', + # 'pspnet', + # 'gcnet', + # 'fastscnn', + # 'segmenter', + # 'dpt', + # 'psanet', + ] + if cls._mmseg_library is None: + cls._mmseg_library = MMSegModelLibrary(include=include) + return cls._mmseg_library + + # for backward tracer + + +class BackwardPassedModelManager(PassedModelManager): + + _default_library = None + _torch_library = None + _mmcls_library = None + _mmseg_library = None + _mmdet_library = None + + def libraries(self, full=False): + if full: + return [ + self.__class__.default_library(), + self.__class__.torch_library(), + 
self.__class__.mmcls_library(),
+                self.__class__.mmseg_library(),
+                self.__class__.mmdet_library(),
+            ]
+        else:
+            return [self.__class__.default_library()]
+
+    @classmethod
+    def default_library(cls):
+        if cls._default_library is None:
+            cls._default_library = DefaultModelLibrary(include=[
+                'SingleLineModel',
+                'ResBlock',
+                'AddCatModel',
+                'ConcatModel',
+                'MultiConcatModel',
+                'MultiConcatModel2',
+                'GroupWiseConvModel',
+                'Xmodel',
+                # 'MultipleUseModel', # bug
+                'Icep',
+                'ExpandLineModel',
+                'MultiBindModel',
+                'DwConvModel',
+                'ConvAttnModel',
+            ])
+        return cls._default_library
+
+    @classmethod
+    def torch_library(cls):
+        """
+        googlenet returns a tuple when training, so it
+        should be traced in eval mode
+        """
+
+        torch_includes = [
+            'alexnet',
+            'densenet',
+            'efficientnet',
+            'googlenet',
+            'inception',
+            'mnasnet',
+            'mobilenet',
+            'regnet',
+            'resnet',
+            'resnext',
+            # 'shufflenet', # bug
+            'squeezenet',
+            'vgg',
+            'wide_resnet',
+            # "vit",
+            # "swin",
+            # "convnext"
+        ]
+        if cls._torch_library is None:
+            cls._torch_library = TorchModelLibrary(include=torch_includes)
+        return cls._torch_library
+
+    @classmethod
+    def mmcls_library(cls):
+        """
+        shufflenet: consists of chunk operations.
+        resnest: resnest has two problems. First, it uses `*x.shape`, which
+        is not traceable by the fx tracer. Second, it uses channel folding.
+        res2net: res2net consists of split operations.
+        convnext: consists of layernorm.
+        """
+        mmcls_model_include = [
+            'vgg',
+            'efficientnet',
+            'resnet',
+            'mobilenet',
+            'resnext',
+            'wide-resnet',
+            # 'shufflenet', # bug
+            'hrnet',
+            # 'resnest', # bug
+            'inception',
+            # 'res2net', # bug
+            'densenet',
+            # 'convnext', # bug
+            'regnet',
+            # 'van', # bug
+            # 'swin_transformer', # bug
+            # 'convmixer', # bug
+            # 't2t', # bug
+            # 'twins', # bug
+            # 'repmlp', # bug
+            # 'tnt', # bug
+            # 't2t', # bug
+            # 'mlp_mixer', # bug
+            # 'conformer', # bug
+            # 'poolformer', # bug
+            # 'vit', # bug
+            # 'efficientformer',
+            # 'mobileone',
+            # 'edgenext'
+        ]
+        mmcls_exclude = ['cutmix', 'cifar', 'gem']
+        if cls._mmcls_library is None:
+            cls._mmcls_library = MMClsModelLibrary(
+                include=mmcls_model_include, exclude=mmcls_exclude)
+        return cls._mmcls_library
+
+    @classmethod
+    def mmdet_library(cls):
+        mmdet_include = [
+            # 'rpn', #
+            # 'faster-rcnn',
+            # 'cascade-rcnn',
+            # 'fast-rcnn', # mmdet has bug.
+            # 'retinanet',
+            # 'mask-rcnn',
+            # 'ssd300'
+        ]
+        if cls._mmdet_library is None:
+            cls._mmdet_library = MMDetModelLibrary(mmdet_include)
+        return cls._mmdet_library
+
+    @classmethod
+    def mmseg_library(cls):
+        include = [
+            # 'cgnet',
+            # 'gcnet',
+            # 'setr',
+            # 'deeplabv3',
+            # 'twins',
+            # 'fastfcn',
+            # 'fpn',
+            # 'upernet',
+            # 'dnl',
+            # 'icnet',
+            # 'segmenter',
+            # 'encnet',
+            # 'erfnet',
+            # 'segformer',
+            # 'apcnet',
+            # 'fast',
+            # 'ocrnet',
+            # 'lraspp',
+            # 'dpt',
+            # 'fcn',
+            # 'psanet',
+            # 'bisenetv2',
+            # 'pointrend',
+            # 'ccnet',
+            'pspnet',
+            # 'dmnet',
+            # 'stdc',
+            # 'ann',
+            # 'nonlocal',
+            # 'isanet',
+            # 'danet',
+            # 'emanet',
+            # 'deeplabv3plus',
+            # 'bisenetv1',
+        ]
+        if cls._mmseg_library is None:
+            cls._mmseg_library = MMSegModelLibrary(include=include)
+        return cls._mmseg_library
+
+
+fx_passed_library = FxPassedModelManager()
+backward_passed_library = BackwardPassedModelManager()
diff --git a/tests/test_core/test_graph/test_channel_flow.py b/tests/test_core/test_graph/test_channel_flow.py
new file mode 100644
index 000000000..87dee6747
--- /dev/null
+++ b/tests/test_core/test_graph/test_channel_flow.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+from mmrazor.structures.graph.channel_flow import ChannelElem, ChannelTensor
+
+
+class TestChannelTensor(unittest.TestCase):
+
+    def test_union(self):
+        tensor1 = ChannelTensor(8)
+        tensor2 = ChannelTensor(8)
+        tensor3 = ChannelTensor(8)
+        tensor4 = ChannelTensor(8)
+
+        ChannelTensor.union_two(tensor1, tensor2)
+        ChannelTensor.union_two(tensor3, tensor4)
+        self.assertUnionedTensor(tensor1, tensor2)
+        self.assertUnionedTensor(tensor3, tensor4)
+
+        ChannelTensor.union_two(tensor1, tensor4)
+
+        self.assertUnionedTensor(tensor1, tensor2)
+        self.assertUnionedTensor(tensor2, tensor3)
+        self.assertUnionedTensor(tensor3, tensor4)
+        self.assertUnionedTensor(tensor1, tensor4)
+
+    def test_cat(self):
+        tensor1 = ChannelTensor(8)
+        tensor2 = ChannelTensor(8)
+        tensor3 = ChannelTensor(16)
+
+        tensor_cat = ChannelTensor.cat([tensor1, tensor2])
+        self.assertEqual(len(tensor_cat), 16)
+        ChannelTensor.union_two(tensor_cat, tensor3)
+
+        tensor31 = tensor3[:8]
+        tensor32 = tensor3[8:]
+        self.assertUnionedTensor(tensor1, tensor31)
+        self.assertUnionedTensor(tensor2, tensor32)
+
+    def test_add_cat(self):
+        """8+8 && 4+12 -> 4+4+8."""
+        tensor1 = ChannelTensor(8)
+        tensor2 = ChannelTensor(8)
+        tensor_cat1 = ChannelTensor.cat([tensor1, tensor2])
+
+        tensor3 = ChannelTensor(4)
+        tensor4 = ChannelTensor(12)
+        tensor_cat2 = ChannelTensor.cat([tensor3, tensor4])
+
+        ChannelTensor.union_two(tensor_cat1, tensor_cat2)
+        self.assertUnionedTensor(tensor_cat1, tensor_cat2)
+
+        self.assertUnionedTensor(tensor_cat1[0:4], tensor3[0:4])
+        self.assertUnionedTensor(tensor_cat1[4:8], tensor4[0:4])
+        self.assertUnionedTensor(tensor_cat1[8:16], tensor4[4:12])
+
+        self.assertUnionedTensor(tensor_cat2[0:4], tensor1[0:4])
+        self.assertUnionedTensor(tensor_cat2[4:8], tensor1[4:8])
+        self.assertUnionedTensor(tensor_cat2[8:], tensor2)
+
+    def assertUnionedTensor(self, tensor1: ChannelTensor,
+                            tensor2: ChannelTensor):
+        assert len(tensor1) == len(tensor2)
+        for e1, e2 in zip(tensor1, tensor2):
+            self.assertEqual(e1.root, e2.root)
+
+
+class TestChannelElem(unittest.TestCase):
+
+    def test_union(self):
+        tensor = ChannelTensor(10)
+        elem1 = tensor[1]
+        elem2 = tensor[2]
+        ChannelElem.union_two(elem1, elem2)
+        self.assertEqual(elem1.root, elem2.root)
+
+        elem3 = tensor[3]
+        ChannelElem.union_two(elem2, elem3)
+        self.assertEqual(elem1.root, elem3.root)
diff --git a/tests/test_core/test_graph/test_channel_graph.py b/tests/test_core/test_graph/test_channel_graph.py
index 6eb3e1454..d6f1c3ffa 100644
--- a/tests/test_core/test_graph/test_channel_graph.py
+++ b/tests/test_core/test_graph/test_channel_graph.py
@@ -8,12 +8,9 @@
 from mmrazor.registry import TASK_UTILS
 from mmrazor.structures.graph import ModuleGraph
 from mmrazor.structures.graph.channel_graph import ChannelGraph
-from mmrazor.structures.graph.channel_modules import (BaseChannelUnit,
-                                                      ChannelTensor)
 from mmrazor.structures.graph.channel_nodes import \
     default_channel_node_converter
-from ...data.models import LineModel
-from .test_graph import TestGraph
+from ...data.models import SingleLineModel
 
 NodeMap = {}
 
@@ -32,23 +29,25 @@ def __call__(self, model) -> torch.Tensor:
 
 class TestChannelGraph(unittest.TestCase):
 
     def test_init(self):
-        model = LineModel()
+        model = SingleLineModel()
         module_graph = ModuleGraph.init_from_backward_tracer(model)
         _ = ChannelGraph.copy_from(module_graph,
                                    default_channel_node_converter)
 
-    def test_forward(self):
-        for model_data in TestGraph.backward_tracer_passed_models():
-            with self.subTest(model=model_data):
-                model = 
model_data() - module_graph = ModuleGraph.init_from_backward_tracer(model) + # def test_forward(self): + # for model_data in BackwardPassedModelManager.include_models( # noqa + # ): # noqa + # with self.subTest(model=model_data): + # model = model_data() + # module_graph = ModuleGraph.init_from_backward_tracer(model) - channel_graph = ChannelGraph.copy_from( - module_graph, default_channel_node_converter) - channel_graph.forward() + # channel_graph = ChannelGraph.copy_from( + # module_graph, default_channel_node_converter) + # channel_graph.forward() - _ = channel_graph.collect_units + # # units = channel_graph.collect_units() + # _ = channel_graph.generate_units_config() def test_forward_with_config_num_in_channel(self): @@ -57,7 +56,7 @@ class MyModel(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(6, 3, 3, 1, 1) - self.net = LineModel() + self.net = SingleLineModel() def forward(self, x): return self.net(self.conv1(x)) @@ -72,107 +71,4 @@ def forward(self, x): default_channel_node_converter) channel_graph.forward(num_input_channel=6) - _ = channel_graph.collect_units - - -class TestChannelUnit(unittest.TestCase): - - def test_union(self): - channel_tensor1 = ChannelTensor(8) - channel_tensor2 = ChannelTensor(8) - channel_tensor3 = ChannelTensor(8) - channel_tensor4 = ChannelTensor(8) - unit1 = channel_tensor1.unit_dict[(0, 8)] - unit2 = channel_tensor2.unit_dict[(0, 8)] - unit3 = channel_tensor3.unit_dict[(0, 8)] - unit4 = channel_tensor4.unit_dict[(0, 8)] - - unit12 = BaseChannelUnit.union_two_units(unit1, unit2) - self.assertDictEqual(channel_tensor1.unit_dict, - channel_tensor2.unit_dict) - - unit34 = BaseChannelUnit.union_two_units(unit3, unit4) - BaseChannelUnit.union_two_units(unit12, unit34) - self.assertDictEqual(channel_tensor1.unit_dict, - channel_tensor4.unit_dict) - - def test_split(self): - channel_tensor1 = ChannelTensor(8) - channel_tensor2 = ChannelTensor(8) - BaseChannelUnit.union_two_units(channel_tensor1.unit_dict[(0, 8)], - channel_tensor2.unit_dict[(0, 8)]) - unit1 = channel_tensor1.unit_dict[(0, 8)] - BaseChannelUnit.split_unit(unit1, [2, 6]) - - self.assertDictEqual(channel_tensor1.unit_dict, - channel_tensor2.unit_dict) - - -class TestChannelTensor(unittest.TestCase): - - def test_init(self): - channel_tensor = ChannelTensor(8) - self.assertIn((0, 8), channel_tensor.unit_dict) - - def test_align_with_nums(self): - channel_tensor = ChannelTensor(8) - channel_tensor.align_units_with_nums([2, 6]) - self.assertSequenceEqual( - list(channel_tensor.unit_dict.keys()), [(0, 2), (2, 8)]) - channel_tensor.align_units_with_nums([2, 2, 4]) - self.assertSequenceEqual( - list(channel_tensor.unit_dict.keys()), [(0, 2), (2, 4), (4, 8)]) - - def test_align_units(self): - channel_tensor1 = ChannelTensor(8) - channel_tensor2 = ChannelTensor(8) - channel_tensor3 = ChannelTensor(8) - - BaseChannelUnit.split_unit(channel_tensor1.unit_list[0], [2, 6]) - BaseChannelUnit.split_unit(channel_tensor2.unit_list[0], [4, 4]) - BaseChannelUnit.split_unit(channel_tensor3.unit_list[0], [6, 2]) - """ - xxoooooo - xxxxoooo - xxxxxxoo - """ - - ChannelTensor.align_tensors(channel_tensor1, channel_tensor2, - channel_tensor3) - for lst in [channel_tensor1, channel_tensor2, channel_tensor3]: - self.assertSequenceEqual( - list(lst.unit_dict.keys()), [ - (0, 2), - (2, 4), - (4, 6), - (6, 8), - ]) - - def test_expand(self): - channel_tensor = ChannelTensor(8) - expanded = channel_tensor.expand(4) - self.assertIn((0, 32), expanded.unit_dict) - - def test_union(self): - 
channel_tensor1 = ChannelTensor(8) - channel_tensor2 = ChannelTensor(8) - channel_tensor3 = ChannelTensor(8) - channel_tensor4 = ChannelTensor(8) - channel_tensor3.union(channel_tensor4) - - self.assertEqual( - id(channel_tensor3.unit_dict[(0, 8)]), - id(channel_tensor4.unit_dict[(0, 8)])) - - channel_tensor2.union(channel_tensor3) - channel_tensor1.union(channel_tensor2) - - self.assertEqual( - id(channel_tensor1.unit_dict[(0, 8)]), - id(channel_tensor2.unit_dict[(0, 8)])) - self.assertEqual( - id(channel_tensor2.unit_dict[(0, 8)]), - id(channel_tensor3.unit_dict[(0, 8)])) - self.assertEqual( - id(channel_tensor3.unit_dict[(0, 8)]), - id(channel_tensor4.unit_dict[(0, 8)])) + _ = channel_graph.generate_units_config diff --git a/tests/test_core/test_graph/test_graph.py b/tests/test_core/test_graph/test_graph.py index 1383dccd8..14df464c8 100644 --- a/tests/test_core/test_graph/test_graph.py +++ b/tests/test_core/test_graph/test_graph.py @@ -1,99 +1,31 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os import sys from unittest import TestCase import torch -from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin -from mmrazor.structures.graph import ModuleGraph -from ...data.models import Icep # noqa -from ...data.models import MultipleUseModel # noqa -from ...data.models import Xmodel # noqa -from ...data.models import (AddCatModel, ConcatModel, DwConvModel, - ExpandLineModel, GroupWiseConvModel, LineModel, - ModelLibrary, MultiBindModel, MultiConcatModel, - MultiConcatModel2, ResBlock) - -FULL_TEST = os.getenv('FULL_TEST') == 'true' - sys.setrecursionlimit(int(1e8)) - -def is_dynamic_op_fx(module, name): - return isinstance(module, DynamicChannelMixin) - - -class ToyCNNPseudoLoss: - - def __call__(self, model): - pseudo_img = torch.rand(2, 3, 16, 16) - pseudo_output = model(pseudo_img) - return pseudo_output.sum() +DEVICE = torch.device('cpu') class TestGraph(TestCase): - - @classmethod - def backward_tracer_passed_models(cls): - '''MultipleUseModel: backward tracer can't distinguish multiple use and - first bind then use.''' - default_models = [ - LineModel, - ResBlock, - AddCatModel, - ConcatModel, - MultiConcatModel, - MultiConcatModel2, - GroupWiseConvModel, - Xmodel, - # MultipleUseModel, # bug - # Icep, bug - ExpandLineModel, - MultiBindModel, - DwConvModel - ] - """ - googlenet return a tuple when training, so it - should trace in eval mode - """ - - torch_models_includes = [ - 'alexnet', - 'densenet', - 'efficientnet', - 'googlenet', - # 'inception', bug - 'mnasnet', - 'mobilenet', - 'regnet', - 'resnet', - 'resnext', - # 'shufflenet', # bug - 'squeezenet', - 'vgg', - 'wide_resnet', - ] - model_library = ModelLibrary(torch_models_includes) - - models = default_models + model_library.export_models( - ) if FULL_TEST else default_models - return models - - def test_init_from_backward_tracer(self) -> None: - TestData = self.backward_tracer_passed_models() - - for data in TestData: - with self.subTest(data=data): - model = data() - model.eval() - graph = ModuleGraph.init_from_backward_tracer(model) - - # check channels - self._valid_graph(graph) - - def _valid_graph(self, graph: ModuleGraph): - try: - graph.check() - except Exception as e: - self.fail(str(e) + '\n' + str(graph)) + pass + # def test_init_from_fx_tracer(self) -> None: + # TestData = BackwardPassedModelManager.include_models() + # with SetTorchThread(1): + # with mp.Pool() as p: + # result = p.map(_test_init_from_fx_tracer, TestData) + # for res, model in zip(result, TestData): + # with 
self.subTest(model=model):
+    #                 self.assertTrue(res[0], res[1])
+
+    # def test_init_from_backward_tracer(self) -> None:
+    #     TestData = FxPassedModelManager.include_models()
+    #     with SetTorchThread(1) as _:
+    #         with mp.Pool() as p:
+    #             result = p.map(_test_init_from_backward_tracer, TestData)
+    #         for res, model in zip(result, TestData):
+    #             # test_init_from_backward_tracer(model)
+    #             with self.subTest(model=model):
+    #                 self.assertTrue(res[0], res[1])
diff --git a/tests/test_core/test_graph/test_tracer_model.py b/tests/test_core/test_graph/test_tracer_model.py
new file mode 100644
index 000000000..3eef9befe
--- /dev/null
+++ b/tests/test_core/test_graph/test_tracer_model.py
@@ -0,0 +1,333 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import multiprocessing as mp
+import os
+import signal
+import sys
+import time
+from contextlib import contextmanager
+from functools import partial
+from typing import List
+from unittest import TestCase
+
+import torch
+import torch.nn as nn
+
+from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
+from mmrazor.models.mutables.mutable_channel.units import \
+    SequentialMutableChannelUnit
+from mmrazor.models.task_modules.tracer.backward_tracer import BackwardTracer
+from mmrazor.models.task_modules.tracer.fx_tracer import CustomFxTracer
+from mmrazor.models.task_modules.tracer.prune_tracer import PruneTracer
+from mmrazor.models.task_modules.tracer.razor_tracer import (FxBaseNode,
+                                                             RazorFxTracer)
+from mmrazor.structures.graph import BaseGraph, ModuleGraph
+from mmrazor.structures.graph.channel_graph import (
+    ChannelGraph, default_channel_node_converter)
+from mmrazor.structures.graph.module_graph import (FxTracerToGraphConverter,
+                                                   PathToGraphConverter)
+from ...data.model_library import ModelGenerator
+from ...data.tracer_passed_models import (PassedModelManager,
+                                          backward_passed_library,
+                                          fx_passed_library)
+from ...utils import SetTorchThread
+
+sys.setrecursionlimit(int(pow(2, 20)))
+# test config
+
+DEVICE = torch.device('cpu')
+FULL_TEST = os.getenv('FULL_TEST') == 'true'
+MP = os.getenv('MP') == 'true'
+
+DEBUG = os.getenv('DEBUG') == 'true'
+
+if MP:
+    POOL_SIZE = mp.cpu_count()
+    TORCH_THREAD_SIZE = 1
+    torch.set_num_interop_threads(1)
+else:
+    POOL_SIZE = 1
+    TORCH_THREAD_SIZE = -1
+
+print(f'DEBUG: {DEBUG}')
+print(f'FULL_TEST: {FULL_TEST}')
+print(f'POOL_SIZE: {POOL_SIZE}')
+print(f'TORCH_THREAD_SIZE: {TORCH_THREAD_SIZE}')
+
+# tools for testing
+
+
+@contextmanager
+def time_limit(seconds, msg='', activated=(not DEBUG)):
+
+    class TimeoutException(Exception):
+        pass
+
+    def signal_handler(signum, frame):
+        if activated:
+            raise TimeoutException(f'{msg} ran over {seconds} s!')
+
+    signal.signal(signal.SIGALRM, signal_handler)
+    signal.alarm(seconds)
+    try:
+        yield
+    finally:
+        signal.alarm(0)
+
+
+# functional helpers (to be moved into the main code)
+
+
+def forward_units(model: ModelGenerator,
+                  try_units: List[SequentialMutableChannelUnit],
+                  units: List[SequentialMutableChannelUnit], template_output):
+    model.eval()
+    for unit in units:
+        unit.current_choice = 1.0
+    for unit in try_units:
+        unit.current_choice = min(max(0.1, unit.sample_choice()), 0.9)
+    x = torch.rand([1, 3, 224, 224]).to(DEVICE)
+    tensors = model(x)
+    model.assert_model_is_changed(template_output, tensors)
+
+
+def find_mutable(model, try_units, units, template_tensors):
+    if len(try_units) == 0:
+        return []
+    try:
+        forward_units(model, try_units, units, template_tensors)
+        return try_units
+    except Exception as e:
+        if len(try_units) == 1:
+            print(f'{model} found an unmutable unit.')
+            print(f'{e}')
+            print(try_units[0])
+            return []
+        else:
+            num = len(try_units)
+            return find_mutable(model, try_units[:num // 2], units,
+                                template_tensors) + find_mutable(
+                                    model, try_units[num // 2:], units,
+                                    template_tensors)
+
+
+class SumLoss:
+
+    def __call__(self, model):
+        img = torch.zeros([2, 3, 224, 224])
+        y = model(img)
+        return self.get_loss(y)
+
+    def get_loss(self, output):
+        if isinstance(output, torch.Tensor):
+            return output.sum()
+        elif isinstance(output, list) or isinstance(output, tuple):
+            loss = 0
+            for out in output:
+                loss += self.get_loss(out)
+            return loss
+        elif isinstance(output, dict):
+            return self.get_loss(list(output.values()))
+        else:
+            raise NotImplementedError(f'{type(output)}')
+
+
+def is_dynamic_op_fx(module, name):
+    from mmcv.cnn.bricks import Scale
+
+    is_leaf = (
+        isinstance(module, DynamicChannelMixin)
+        or isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)
+        or isinstance(module, nn.modules.batchnorm._BatchNorm)
+        or isinstance(module, Scale))
+
+    return is_leaf
+
+
+# test functions for mp
+
+
+def _test_tracer(model, tracer_type='fx'):
+
+    def _test_fx_tracer(model):
+        tracer = CustomFxTracer(leaf_module=PruneTracer.default_leaf_modules)
+        return tracer.trace(model)
+
+    def _test_backward_tracer(model):
+        model.eval()
+        tracer = BackwardTracer(loss_calculator=SumLoss())
+        return tracer.trace(model)
+
+    if tracer_type == 'fx':
+        graph = _test_fx_tracer(model)
+    else:
+        graph = _test_backward_tracer(model)
+    return graph
+
+
+def _test_tracer_result_2_module_graph(model, tracer_res, tracer_type='fx'):
+
+    def _fx_graph_2_module_graph(model, fx_graph):
+        fx_graph.owning_module = model
+        fx_graph.graph = BaseGraph[FxBaseNode]()
+        base_graph = RazorFxTracer().parse_torch_graph(fx_graph)
+
+        module_graph = FxTracerToGraphConverter(base_graph, model).graph
+        module_graph._model = model
+        module_graph.refresh_module_name()
+        return module_graph
+
+    def _path_2_module_graph(model, path_list):
+        module_graph = PathToGraphConverter(path_list, model).graph
+        module_graph.refresh_module_name()
+        return module_graph
+
+    if tracer_type == 'fx':
+        graph = _fx_graph_2_module_graph(model, tracer_res)
+    else:
+        graph = _path_2_module_graph(model, tracer_res)
+    return graph
+
+
+def _test_units(units: List[SequentialMutableChannelUnit], model):
+    x = torch.rand([1, 3, 224, 224]).to(DEVICE)
+    model.eval()
+    tensors_org = model(x)
+    # prune
+    for unit in units:
+        unit.prepare_for_pruning(model)
+    mutable_units = [unit for unit in units if unit.is_mutable]
+    found_mutable_units = mutable_units
+    # found_mutable_units = find_mutable(model, mutable_units, units,
+    #                                    tensors_org)
+    assert len(found_mutable_units) >= 1, \
+        'len of mutable units should be greater than or equal to 1.'
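+    # Forward once more with the sampled pruning ratios applied;
+    # model.assert_model_is_changed should verify that pruning actually
+    # altered the output relative to the unpruned template.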
+    forward_units(model, found_mutable_units, units, tensors_org)
+    return found_mutable_units
+
+
+def _test_a_model(Model, tracer_type='fx'):
+    start = time.time()
+
+    try:
+        Model.init_model()
+        model = Model
+        model.eval()
+        print(f'test {Model} using {tracer_type} tracer.')
+        """
+        model
+        -> fx_graph/path_list
+        -> module_graph
+        -> channel_graph
+        -> units
+        """
+        with time_limit(10, 'trace'):
+            tracer_result = _test_tracer(model, tracer_type)
+            out = len(tracer_result.nodes if tracer_type ==
+                      'fx' else tracer_result)
+
+        with time_limit(10, 'to_module_graph'):
+            module_graph: ModuleGraph = _test_tracer_result_2_module_graph(
+                model, tracer_result, tracer_type)
+            module_graph.check(fix=True)
+            module_graph.check()
+            out = len(module_graph)
+
+        with time_limit(10, 'to channel graph'):
+            channel_graph = ChannelGraph.copy_from(
+                module_graph, default_channel_node_converter)
+            channel_graph.check(fix=True)
+            channel_graph.check()
+
+        with time_limit(80, 'to units'):
+            channel_graph.forward(3)
+            units_config = channel_graph.generate_units_config()
+            units = [
+                SequentialMutableChannelUnit.init_from_cfg(model, cfg)
+                for cfg in units_config.values()
+            ]
+
+        with time_limit(80, 'test units'):
+            # get unit
+            mutable_units = _test_units(units, model)
+            out = len(mutable_units)
+
+        print(f'test {Model} successful.')
+        return Model.name, True, '', time.time() - start, out
+    except Exception as e:
+        if DEBUG:
+            raise e
+        else:
+            print(f'test {Model} failed.')
+            return Model.name, False, f'{e}', time.time() - start, -1
+
+
+# TestCase
+
+
+class TestTraceModel(TestCase):
+
+    def test_init_from_fx_tracer(self) -> None:
+        TestData = fx_passed_library.include_models(FULL_TEST)
+        with SetTorchThread(TORCH_THREAD_SIZE):
+            if POOL_SIZE != 1:
+                with mp.Pool(POOL_SIZE) as p:
+                    result = p.map(
+                        partial(_test_a_model, tracer_type='fx'), TestData)
+            else:
+                result = map(
+                    partial(_test_a_model, tracer_type='fx'), TestData)
+        self.report(result, fx_passed_library, 'fx')
+
+    def test_init_from_backward_tracer(self) -> None:
+        TestData = backward_passed_library.include_models(FULL_TEST)
+        with SetTorchThread(TORCH_THREAD_SIZE):
+            if POOL_SIZE != 1:
+                with mp.Pool(POOL_SIZE) as p:
+                    result = p.map(
+                        partial(_test_a_model, tracer_type='backward'),
+                        TestData)
+            else:
+                result = map(
+                    partial(_test_a_model, tracer_type='backward'), TestData)
+        self.report(result, backward_passed_library, 'backward')
+
+    def report(self, result, model_manager: PassedModelManager, fx_type='fx'):
+        print()
+        print(f'Trace model summary using {fx_type} tracer.')
+
+        passed_test = [res for res in result if res[1] is True]
+        unpassed_test = [res for res in result if res[1] is False]
+
+        # long summary
+
+        print(f'{len(passed_test)},{len(unpassed_test)},'
+              f'{len(model_manager.uninclude_models(full_test=FULL_TEST))}')
+
+        print('Passed:')
+        print('\tmodel\ttime\tlen(mutable)')
+        for model, passed, msg, used_time, out in passed_test:
+            with self.subTest(model=model):
+                print(f'\t{model}\t{int(used_time)}s\t{out}')
+                self.assertTrue(passed, msg)
+
+        print('Unpassed:')
+        for model, passed, msg, used_time, out in unpassed_test:
+            with self.subTest(model=model):
+                print(f'\t{model}\t{int(used_time)}s\t{out}')
+                print(f'\t\t{msg}')
+                self.assertTrue(passed, msg)
+
+        print('Untested:')
+        for model in model_manager.uninclude_models(full_test=FULL_TEST):
+            print(f'\t{model}')
+
+        # short summary
+        print('Short Summary:')
+        short_passed = set(
+            [ModelGenerator.get_short_name(res[0]) for res in passed_test])
+
+        print('Passed\n', short_passed)
+
+        short_unpassed = set(
+            [ModelGenerator.get_short_name(res[0]) for res in unpassed_test])
+        print('Unpassed\n', short_unpassed)
diff --git a/tests/test_core/test_tracer/__init__.py b/tests/test_core/test_tracer/__init__.py
new file mode 100644
index 000000000..ef101fec6
--- /dev/null
+++ b/tests/test_core/test_tracer/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/tests/test_core/test_tracer/test_backward_tracer.py b/tests/test_core/test_tracer/test_backward_tracer.py
index 33f5eff78..55ddaccc0 100644
--- a/tests/test_core/test_tracer/test_backward_tracer.py
+++ b/tests/test_core/test_tracer/test_backward_tracer.py
@@ -259,6 +259,12 @@ def test_path_list(self):
         with self.assertRaisesRegex(AssertionError, ''):
             _ = PathList({})
 
+    def test_sum_pseudo_loss(self):
+        model = ResBlock()
+        tracer = BackwardTracer(loss_calculator={'type': 'SumPseudoLoss'})
+        path = tracer.trace(model)
+        print(path)
+
 
 def _test_reset_bn_running_stats(should_fail):
     import os
diff --git a/tests/test_core/test_tracer/test_fx_tracer.py b/tests/test_core/test_tracer/test_fx_tracer.py
new file mode 100644
index 000000000..238e51cd3
--- /dev/null
+++ b/tests/test_core/test_tracer/test_fx_tracer.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+# import torch
+
+# from mmrazor.models.task_modules.tracer.razor_tracer import FxTracer
+# from ...data.models import UnTracableModel
+
+
+class TestFxTracer(unittest.TestCase):
+
+    pass
diff --git a/tests/test_core/test_tracer/test_prune_tracer.py b/tests/test_core/test_tracer/test_prune_tracer.py
new file mode 100644
index 000000000..b5ee5993b
--- /dev/null
+++ b/tests/test_core/test_tracer/test_prune_tracer.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+from mmrazor.models.task_modules.tracer import PruneTracer
+from ...data.models import SingleLineModel
+
+
+class TestPruneTracer(TestCase):
+
+    def test_backward_tracer(self):
+        model = SingleLineModel()
+        tracer = PruneTracer(tracer_type='BackwardTracer')
+        unit_configs = tracer.trace(model)
+        print(unit_configs)
+
+    def test_fx_tracer(self):
+        model = SingleLineModel()
+        tracer = PruneTracer(tracer_type='FxTracer')
+        unit_configs = tracer.trace(model)
+        print(unit_configs)
diff --git a/tests/test_data.py b/tests/test_data.py
new file mode 100644
index 000000000..9bbbb71f7
--- /dev/null
+++ b/tests/test_data.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
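+# Sanity checks for the test data itself: each library's default include
+# list should cover all of its models, and every default model should build
+# and produce [N, 1000] logits.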
+import unittest
+
+import torch
+
+from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary,
+                                 MMDetModelLibrary, MMSegModelLibrary,
+                                 ModelGenerator, TorchModelLibrary)
+from .data.models import SingleLineModel
+from .data.tracer_passed_models import (BackwardPassedModelManager,
+                                        FxPassedModelManager)
+
+
+class TestModelLibrary(unittest.TestCase):
+
+    def test_mmcls(self):
+        library = MMClsModelLibrary(exclude=['cutmix', 'cifar'])
+        self.assertTrue(library.is_default_includes_cover_all_models())
+
+    def test_default_library(self):
+        library = DefaultModelLibrary()
+        self.assertTrue(library.is_default_includes_cover_all_models())
+
+    def test_torch_library(self):
+        library = TorchModelLibrary()
+        self.assertTrue(library.is_default_includes_cover_all_models())
+
+    def test_mmdet(self):
+        library = MMDetModelLibrary()
+        self.assertTrue(library.is_default_includes_cover_all_models())
+
+    def test_mmseg(self):
+        library = MMSegModelLibrary()
+        self.assertTrue(library.is_default_includes_cover_all_models())
+
+    def test_passed_models(self):
+        try:
+            print(FxPassedModelManager().include_models())
+            print(BackwardPassedModelManager().include_models())
+        except Exception:
+            self.fail()
+
+
+class TestModels(unittest.TestCase):
+
+    def _test_a_model(self, Model):
+        model = Model()
+        x = torch.rand(2, 3, 224, 224)
+        y = model(x)
+        self.assertSequenceEqual(y.shape, [2, 1000])
+
+    def test_models(self):
+        library = DefaultModelLibrary()
+        for Model in library.include_models():
+            with self.subTest(model=Model):
+                self._test_a_model(Model)
+
+    def test_generator(self):
+        Model = ModelGenerator('model', SingleLineModel)
+        model = Model()
+        model.eval()
+        self.assertEqual(model.training, False)
+        model.train()
+        self.assertEqual(model.training, True)
diff --git a/tests/test_doc.py b/tests/test_doc.py
new file mode 100644
index 000000000..facf8ad47
--- /dev/null
+++ b/tests/test_doc.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
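+# Executes the tutorial notebooks listed below with nbconvert's
+# ExecutePreprocessor so the documented workflows stay runnable; any raised
+# exception fails the test.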
+from unittest import TestCase + +import nbformat +from nbconvert.preprocessors import ExecutePreprocessor + +notebook_paths = [ + './mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb', + './mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb' # noqa +] + + +class TestDocs(TestCase): + + def test_notebooks(self): + for path in notebook_paths: + with self.subTest(path=path): + with open(path) as file: + nb_in = nbformat.read(file, nbformat.NO_CONVERT) + ep = ExecutePreprocessor( + timeout=600, kernel_name='python3') + try: + _ = ep.preprocess(nb_in) + except Exception: + self.fail() diff --git a/tests/test_models/test_algorithms/test_autoslim.py b/tests/test_models/test_algorithms/test_autoslim.py index 79169b3cf..228050a33 100644 --- a/tests/test_models/test_algorithms/test_autoslim.py +++ b/tests/test_models/test_algorithms/test_autoslim.py @@ -36,9 +36,7 @@ default_args=dict( candidate_choices=list(i / 12 for i in range(2, 13)), choice_mode='ratio')), - parse_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss'))) + parse_cfg=dict(type='PruneTracer')) DISTILLER_CFG = dict( type='ConfigurableDistiller', diff --git a/tests/test_models/test_algorithms/test_slimmable_network.py b/tests/test_models/test_algorithms/test_slimmable_network.py index 13efbcf84..f2b143707 100644 --- a/tests/test_models/test_algorithms/test_slimmable_network.py +++ b/tests/test_models/test_algorithms/test_slimmable_network.py @@ -29,9 +29,7 @@ MUTATOR_CFG = dict( type='SlimmableChannelMutator', channel_unit_cfg=dict(type='SlimmableChannelUnit', units=CHANNEL_CFG_PATH), - parse_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss'))) + parse_cfg=dict(type='PruneTracer')) CHANNEL_CFG_PATHS = [ 'tests/data/MBV2_220M.yaml', diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py index ce6ae7b36..cafa1e7b9 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py @@ -1,15 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. 
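+# The TestDynamicSyncBn case added below initializes a single-process
+# distributed group (nccl on GPU, gloo otherwise) so DynamicSyncBatchNorm
+# can be exercised without a real multi-GPU launch.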
+import unittest from typing import Tuple, Type from unittest.mock import MagicMock import pytest import torch +import torch.distributed as dist from torch import nn from mmrazor.models.architectures.dynamic_ops import (DynamicBatchNorm1d, DynamicBatchNorm2d, DynamicBatchNorm3d, - DynamicMixin) + DynamicMixin, + DynamicSyncBatchNorm) from mmrazor.models.mutables import SquentialMutableChannel from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet from ..utils import fix_dynamic_op @@ -115,3 +118,37 @@ def test_bn_track_running_stats( x = torch.rand(*input_shape) assert torch.equal(d_bn(x), s_bn(x)) + + +class TestDynamicSyncBn(unittest.TestCase): + + def test_init(self): + if not torch.cuda.is_available(): + self.skipTest('no cuda') + import os + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + if torch.cuda.is_available(): + backend = 'nccl' + device = torch.device('cuda:0') + else: + backend = 'gloo' + device = torch.device('cpu') + dist.init_process_group(backend, rank=0, world_size=1) + + x = torch.rand([2, 8, 224, 224]).to(device) + norm = DynamicSyncBatchNorm(8).to(device) + _ = norm(x) + + mutable_num_features = SquentialMutableChannel(8) + mutable_num_features.current_choice = 4 + norm.register_mutable_attr('in_channels', mutable_num_features) + + with pytest.raises(Exception): + norm(x) + + x = torch.rand([2, 4, 32, 32]).to(device) + _ = norm(x) + dist.destroy_process_group() diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py index c93a43842..6330005d1 100644 --- a/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 import unittest
 
-import pytest
 import torch
 
 from mmrazor.models.mutables import (SimpleMutableChannel,
@@ -31,5 +30,4 @@ def test_SimpleMutableChannel(self):
         channel.current_choice = torch.tensor([1, 0, 0, 0]).bool()
         self.assertEqual(channel.activated_channels, 1)
         channel.fix_chosen()
-        with pytest.raises(NotImplementedError):
-            channel.dump_chosen()
+        channel.dump_chosen()
diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py
index f1a0d8529..6cf292f10 100644
--- a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py
+++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py
@@ -5,13 +5,13 @@
 from mmrazor.models.mutables import L1MutableChannelUnit
 from mmrazor.models.mutators import ChannelMutator
-from .....data.models import LineModel
+from .....data.models import SingleLineModel
 
 
 class TestL1MutableChannelUnit(TestCase):
 
     def test_init(self):
-        model = LineModel()
+        model = SingleLineModel()
         mutator = ChannelMutator(
             channel_unit_cfg={
                 'type': 'L1MutableChannelUnit',
@@ -21,9 +21,6 @@
             })
         mutator.prepare_from_supernet(model)
         mutator.set_choices(mutator.sample_choices())
-        print(mutator.units)
-        print(mutator.mutable_units)
-        print(mutator.choice_template)
 
     def test_convnd(self):
         unit = L1MutableChannelUnit(8)
diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py
index ad5b5e56b..897d0633a 100644
--- a/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py
+++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py
@@ -8,19 +8,16 @@
 from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
 from mmrazor.models.mutables.mutable_channel import (
     L1MutableChannelUnit, MutableChannelUnit, SequentialMutableChannelUnit)
-from mmrazor.models.mutables.mutable_channel.units.channel_unit import (  # noqa
-    Channel, ChannelUnit)
-from mmrazor.structures.graph import ModuleGraph as ModuleGraph
-from .....data.models import LineModel
-from .....test_core.test_graph.test_graph import TestGraph
+from mmrazor.models.mutables.mutable_channel.units.channel_unit import \
+    ChannelUnit  # noqa
+from .....data.models import SingleLineModel
+from .....data.tracer_passed_models import backward_passed_library
 
 MUTABLE_CFG = dict(type='SimpleMutablechannel')
 PARSE_CFG = dict(
     type='BackwardTracer',
     loss_calculator=dict(type='ImageClassifierPseudoLoss'))
-# DEVICE = torch.device('cuda:0') if torch.cuda.is_available() \
-#     else torch.device('cpu')
 DEVICE = torch.device('cpu')
 GROUPS: List[MutableChannelUnit] = [
     L1MutableChannelUnit, SequentialMutableChannelUnit
@@ -29,17 +26,30 @@
 DefaultChannelUnit = SequentialMutableChannelUnit
 
 
+def _test_units(units: List[MutableChannelUnit], model):
+    for unit in units:
+        unit.prepare_for_pruning(model)
+    mutable_units = [unit for unit in units if unit.is_mutable]
+    assert len(mutable_units) >= 1, \
+        'len of mutable units should be greater than or equal to 1.'
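+    # For each mutable unit: sample a choice, apply it, and check both that
+    # the stored choice roughly matches and that the model still produces
+    # [2, 1000] logits.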
+ for unit in mutable_units: + choice = unit.sample_choice() + unit.current_choice = choice + assert abs(unit.current_choice - choice) < 0.1 + x = torch.rand([2, 3, 224, 224]).to(DEVICE) + y = model(x) + assert list(y.shape) == [2, 1000] + + class TestMutableChannelUnit(TestCase): - def test_init_from_graph(self): - model = LineModel() - # init using tracer - graph = ModuleGraph.init_from_backward_tracer(model) - units = DefaultChannelUnit.init_from_graph(graph) - self._test_units(units, model) + def test_init_from_tracer(self): + model = SingleLineModel() + units = DefaultChannelUnit.init_from_prune_tracer(model) + _test_units(units, model) def test_init_from_cfg(self): - model = LineModel() + model = SingleLineModel() # init using tracer config = { @@ -76,53 +86,26 @@ def test_init_from_cfg(self): } } units = [DefaultChannelUnit.init_from_cfg(model, config)] - self._test_units(units, model) + _test_units(units, model) def test_init_from_channel_unit(self): - model = LineModel() + model = SingleLineModel() # init using tracer - graph = ModuleGraph.init_from_backward_tracer(model) - units: List[ChannelUnit] = ChannelUnit.init_from_graph(graph) + units: List[ChannelUnit] = ChannelUnit.init_from_prune_tracer(model) mutable_units = [ DefaultChannelUnit.init_from_channel_unit(unit) for unit in units ] - self._test_units(mutable_units, model) - - def _test_units(self, units: List[MutableChannelUnit], model): - for unit in units: - unit.prepare_for_pruning(model) - mutable_units = [unit for unit in units if unit.is_mutable] - self.assertGreaterEqual(len(mutable_units), 1) - for unit in mutable_units: - choice = unit.sample_choice() - unit.current_choice = choice - self.assertAlmostEqual(unit.current_choice, choice, delta=0.1) - x = torch.rand([2, 3, 224, 224]).to(DEVICE) - y = model(x) - self.assertSequenceEqual(y.shape, [2, 1000]) - - def _test_a_model_from_backward_tracer(self, model): - model.eval() - model = model.to(DEVICE) - graph = ModuleGraph.init_from_backward_tracer(model) - self._test_a_graph(model, graph) - - def test_with_backward_tracer(self): - test_models = TestGraph.backward_tracer_passed_models() - for model_data in test_models: - with self.subTest(model=model_data): - model = model_data() - self._test_a_model_from_backward_tracer(model) + _test_units(mutable_units, model) def test_replace_with_dynamic_ops(self): - model_datas = TestGraph.backward_tracer_passed_models() + model_datas = backward_passed_library.include_models() for model_data in model_datas: for unit_type in GROUPS: with self.subTest(model=model_data, unit=unit_type): model: nn.Module = model_data() - graph = ModuleGraph.init_from_backward_tracer(model) units: List[ - MutableChannelUnit] = unit_type.init_from_graph(graph) + MutableChannelUnit] = unit_type.init_from_prune_tracer( + model) for unit in units: unit.prepare_for_pruning(model) @@ -138,10 +121,3 @@ def test_replace_with_dynamic_ops(self): if isinstance(module, nn.BatchNorm2d): self.assertTrue( isinstance(module, DynamicChannelMixin)) - - def _test_a_graph(self, model, graph): - try: - units = DefaultChannelUnit.init_from_graph(graph) - self._test_units(units, model) - except Exception as e: - self.fail(f'{e}') diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py index 96908d807..f2ae076d2 100644 --- a/tests/test_models/test_mutators/test_channel_mutator.py +++ b/tests/test_models/test_mutators/test_channel_mutator.py @@ -11,7 +11,7 @@ from mmrazor.models.mutators.channel_mutator 
import ChannelMutator from mmrazor.registry import MODELS from ...data.models import DynamicLinearModel -from ...test_core.test_graph.test_graph import TestGraph +from ...data.tracer_passed_models import backward_passed_library sys.setrecursionlimit(2000) @@ -46,7 +46,7 @@ def _test_a_mutator(self, mutator: ChannelMutator, model): self.assertEqual(list(y.shape), [2, 1000]) def test_sample_subnet(self): - data_models = TestGraph.backward_tracer_passed_models() + data_models = backward_passed_library.include_models()[:2] for i, data in enumerate(data_models): with self.subTest(i=i, data=data): @@ -60,7 +60,7 @@ def test_sample_subnet(self): self._test_a_mutator(mutator, model) def test_generic_support(self): - data_models = TestGraph.backward_tracer_passed_models() + data_models = backward_passed_library.include_models() for data_model in data_models[:1]: for unit_type in DATA_UNITS: @@ -105,7 +105,7 @@ def test_init_units_from_cfg(self): self._test_a_mutator(mutator2, model2) def test_mix_config_tracer(self): - model = TestGraph.backward_tracer_passed_models()[0]() + model = backward_passed_library.include_models()[0]() model0 = copy.deepcopy(model) mutator0 = ChannelMutator() diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..f4f0562e4 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .set_torch_thread import SetTorchThread + +__all__ = ['SetTorchThread'] diff --git a/tests/utils/set_torch_thread.py b/tests/utils/set_torch_thread.py new file mode 100644 index 000000000..a3cc482e8 --- /dev/null +++ b/tests/utils/set_torch_thread.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +class SetTorchThread: + + def __init__(self, num_thread: int = -1) -> None: + self.prev_num_threads = torch.get_num_threads() + self.num_threads = num_thread + + def __enter__(self): + if self.num_threads != -1: + torch.set_num_threads(self.num_threads) + + def __exit__(self, exc_type, exc_value, tb): + if self.num_threads != -1: + torch.set_num_threads(self.prev_num_threads) diff --git a/tools/get_channel_units.py b/tools/get_channel_units.py index dc8818fbf..de95d13ca 100644 --- a/tools/get_channel_units.py +++ b/tools/get_channel_units.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse import json +import sys import torch.nn as nn from mmengine import MODELS @@ -9,6 +10,8 @@ from mmrazor.models import BaseAlgorithm from mmrazor.models.mutators import ChannelMutator +sys.setrecursionlimit(int(pow(2, 20))) + def parse_args(): parser = argparse.ArgumentParser( diff --git a/tools/get_prune_config.py b/tools/get_prune_config.py new file mode 100644 index 000000000..dbb80c729 --- /dev/null +++ b/tools/get_prune_config.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
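+# Example usage (paths are illustrative):
+#   python tools/get_prune_config.py configs/resnet/resnet34_8xb32_in1k.py \
+#       --checkpoint path/to/checkpoint.pth -o ./prune_resnet34.py
+# The dumped config wraps the original model in ItePruneAlgorithm with an
+# L1MutableChannelUnit mutator parsed by the fx-based PruneTracer.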
+import argparse +import copy +from typing import Dict + +from mmengine import Config + +from mmrazor.models.mutators import ChannelMutator +from mmrazor.registry import MODELS + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Get the config to prune a model.') + parser.add_argument('config', help='config of the model') + parser.add_argument( + '--checkpoint', + default=None, + type=str, + help='checkpoint path of the model') + parser.add_argument( + '-o', + type=str, + default='./prune.py', + help='output path to store the pruning config.') + args = parser.parse_args() + return args + + +def wrap_prune_config(config: Config, prune_target: Dict, + checkpoint_path: str): + config = copy.deepcopy(config) + + arch_config: Dict = config['model'] + + # update checkpoint_path + if checkpoint_path is not None: + arch_config.update({ + 'init_cfg': { + 'type': 'Pretrained', + 'checkpoint': checkpoint_path # noqa + }, + }) + + # deal with data_preprocessor + if 'data_preprocessor' in config: + data_preprocessor = config['data_preprocessor'] + arch_config.update({'data_preprocessor': data_preprocessor}) + config['data_preprocessor'] = None + else: + data_preprocessor = None + + # prepare algorithm + algorithm_config = dict( + _scope_='mmrazor', + type='ItePruneAlgorithm', + architecture=arch_config, + target_pruning_ratio=prune_target, + mutator_cfg=dict( + type='ChannelMutator', + channel_unit_cfg=dict( + type='L1MutableChannelUnit', + default_args=dict(choice_mode='ratio')), + parse_cfg=dict(type='PruneTracer', tracer_type='FxTracer'))) + config['model'] = algorithm_config + + return config + + +if __name__ == '__main__': + args = parse_args() + config_path = args.config + checkpoint_path = args.checkpoint + target_path = args.o + + origin_config = Config.fromfile(config_path) + + # get subnet config + model = MODELS.build(origin_config['model']) + mutator: ChannelMutator = ChannelMutator( + channel_unit_cfg=dict( + type='L1MutableChannelUnit', + default_args=dict(choice_mode='ratio'), + ), + parse_cfg={ + 'type': 'PruneTracer', + 'tracer_type': 'FxTracer' + }) + mutator.prepare_from_supernet(model) + choice_template = mutator.choice_template + + # prune and finetune + + prune_config: Config = wrap_prune_config(origin_config, choice_template, + checkpoint_path) + prune_config.dump(target_path) diff --git a/tools/get_resource.py b/tools/get_resource.py new file mode 100644 index 000000000..4e466b18f --- /dev/null +++ b/tools/get_resource.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +from mmengine import Config + +from mmrazor.models.task_modules import ResourceEstimator +from mmrazor.registry import MODELS + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('config') + parser.add_argument('-H', default=224, type=int) + parser.add_argument('-W', default=224, type=int) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + config = Config.fromfile(args.config) + H = args.H + W = args.W + + model_config = config['model'] + model = MODELS.build(model_config) + + estimator = ResourceEstimator( + flops_params_cfg={'input_shape': (1, 3, H, W)}) + result = estimator.estimate(model) + print(result) diff --git a/tools/get_search_config.py b/tools/get_search_config.py new file mode 100644 index 000000000..1957b0a0b --- /dev/null +++ b/tools/get_search_config.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
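+# Example usage (paths are illustrative):
+#   python tools/get_search_config.py configs/resnet/resnet34_8xb32_in1k.py \
+#       --checkpoint path/to/checkpoint.pth \
+#       --flops-min 0.45 --flops-max 0.55 -o ./search.py
+# The dumped config wraps the model in SearchWrapper and sets train_cfg to a
+# PruneEvolutionSearchLoop restricted to the given flops range.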
+import argparse +import copy +from typing import Dict, Tuple + +from mmengine import Config + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Get the config to search the pruning structure of a model' + ) + parser.add_argument('config', help='config of the model') + parser.add_argument( + '--checkpoint', + default=None, + type=str, + help='checkpoint path of the model') + parser.add_argument( + '--flops-min', type=float, default=0.45, help='minimal flops') + parser.add_argument( + '--flops-max', type=float, default=0.55, help='maximal flops') + parser.add_argument( + '-o', + type=str, + default='./search.py', + help='output path to store the search config.') + args = parser.parse_args() + return args + + +def wrap_search_config(config: Config, checkpoint_path: str, + flop_range: Tuple): + config = copy.deepcopy(config) + + arch_config: Dict = config['model'] + + # deal with data_preprocessor + if 'data_preprocessor' in config: + data_preprocessor = config['data_preprocessor'] + arch_config.update({'data_preprocessor': data_preprocessor}) + config['data_preprocessor'] = None + else: + data_preprocessor = None + + # deal with checkpoint + if checkpoint_path is not None: + arch_config.update({ + 'init_cfg': { + 'type': 'Pretrained', + 'checkpoint': checkpoint_path # noqa + }, + }) + + model_config = dict( + _scope_='mmrazor', + type='SearchWrapper', + architecture=arch_config, + mutator_cfg=dict( + type='ChannelMutator', + channel_unit_cfg=dict( + type='L1MutableChannelUnit', + default_args=dict(choice_mode='ratio')), + parse_cfg=dict(type='PruneTracer', tracer_type='FxTracer'))) + + config['model'] = model_config + + val_evaluator_config = config['val_evaluator'] + val_evaluator_config[ + 'type'] = config['default_scope'] + '.' + val_evaluator_config['type'] + + def prepare_dataloader(val_loader_config): + + val_loader_config['dataset']['type'] = config[ + 'default_scope'] + '.' + val_loader_config['dataset']['type'] + return val_loader_config + + val_loader_config = config['val_dataloader'] + val_loader_config = prepare_dataloader(val_loader_config) + train_loader_config = prepare_dataloader(config['train_dataloader']) + + searcher_config = dict( + type='mmrazor.PruneEvolutionSearchLoop', + dataloader=val_loader_config, + bn_dataloader=train_loader_config, + evaluator=val_evaluator_config, + max_epochs=20, + num_candidates=20, + top_k=5, + num_mutation=10, + num_crossover=10, + mutate_prob=0.2, + flops_range=flop_range, + resource_estimator_cfg=dict( + flops_params_cfg=dict(input_shape=(1, 3, 224, 224))), + score_key='accuracy/top1') + config['train_cfg'] = searcher_config + return config + + +if __name__ == '__main__': + args = parse_args() + config_path = args.config + checkpoint_path = args.checkpoint + flops_range = (args.flops_min, args.flops_max) + assert flops_range[1] > flops_range[0] + target_path = args.o + + origin_config = Config.fromfile(config_path) + + # wrap config for search + + search_config = wrap_search_config(origin_config, checkpoint_path, + flops_range) + search_config.dump(target_path)