diff --git a/.circleci/config.yml b/.circleci/config.yml index a6bd898d0..27e235306 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,7 +26,7 @@ workflows: tools/.* lint_only false configs/.* lint_only false .circleci/.* lint_only false - base-revision: dev-1.x + base-revision: main # this is the path of the configuration we should trigger once # path filtering and pipeline parameter value updates are # complete. In this case, we are using the parent dynamic diff --git a/.circleci/test.yml b/.circleci/test.yml index 25140a879..9acc7fdfc 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -67,10 +67,10 @@ jobs: pip install git+https://github.com/open-mmlab/mmengine.git@main pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x - pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x - python -m pip install git+ssh://git@github.com/open-mmlab/mmpose.git@dev-1.x + pip install git+https://github.com/open-mmlab/mmpretrain.git@mmcls-1.x + pip install git+https://github.com/open-mmlab/mmdetection.git@main + pip install git+https://github.com/open-mmlab/mmsegmentation.git@main + python -m pip install git+ssh://git@github.com/open-mmlab/mmpose.git@main pip install -r requirements.txt - run: name: Build and install @@ -103,9 +103,9 @@ jobs: name: Clone Repos command: | git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine - git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection - git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification - git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation + git clone -b main --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection + git clone -b 1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification + git clone -b main --depth 1 https://github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation - run: name: Build Docker image command: | @@ -139,7 +139,7 @@ workflows: filters: branches: ignore: - - dev-1.x + - main - 1.x pr_stage_test: when: @@ -150,18 +150,18 @@ workflows: filters: branches: ignore: - - dev-1.x + - main - build_cpu: name: minimum_version_cpu - torch: 1.6.0 - torchvision: 0.7.0 - python: 3.7.9 + torch: 1.8.1 + torchvision: 0.9.1 + python: 3.7.4 requires: - lint - build_cpu: name: maximum_version_cpu - torch: 1.12.1 - torchvision: 0.13.1 + torch: 1.13.1 + torchvision: 0.14.1 python: 3.9.0 requires: - lint @@ -183,11 +183,11 @@ workflows: jobs: - build_cuda: name: minimum_version_gpu - torch: 1.6.0 + torch: 1.8.1 # Use double quotation mark to explicitly specify its type # as string instead of number - cuda: "10.1" + cuda: "10.2" filters: branches: only: - - dev-1.x + - main diff --git a/.dev_scripts/benchmark_summary_analyse.py b/.dev_scripts/benchmark_summary_analyse.py new file mode 100644 index 000000000..372e1326c --- /dev/null +++ b/.dev_scripts/benchmark_summary_analyse.py @@ -0,0 +1,67 @@ +import argparse +import os + +import mmengine + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Analyse summary.yml generated by benchmark test') + parser.add_argument('file_path', help='Summary.yml path') + args = parser.parse_args() + return args + + 
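+# Usage sketch (illustrative path; pass whatever summary file benchmark_test
+# produced):
+#   python .dev_scripts/benchmark_summary_analyse.py <work_dir>/summary.yml
+# For every model whose actual metric deviates from the expected value, the
+# script records the error and dumps a sorted summary_error.yml next to the
+# input file.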
+metric_mapping = { + 'Top 1 Accuracy': 'accuracy/top1', + 'Top 5 Accuracy': 'accuracy/top5', + 'box AP': 'coco/bbox_mAP', + 'mIoU': 'mIoU' +} + + +def compare_metric(result, metric): + expect_val = result['expect'][metric] + actual_val = result['actual'].get(metric_mapping[metric], None) + if actual_val is None: + return None, None + if metric == 'box AP': + actual_val *= 100 + decimal_bit = len(str(expect_val).split('.')[-1]) + actual_val = round(actual_val, decimal_bit) + error = round(actual_val - expect_val, decimal_bit) + error_percent = round(abs(error) * 100 / expect_val, 3) + return error, error_percent + + +def main(): + args = parse_args() + file_path = args.file_path + results = mmengine.load(file_path, 'yml') + miss_models = dict() + sort_by_error = dict() + for k, v in results.items(): + valid_keys = v['expect'].keys() + compare_res = dict() + for m in valid_keys: + error, error_percent = compare_metric(v, m) + if error is None: + continue + compare_res[m] = {'error': error, 'error_percent': error_percent} + if error != 0: + miss_models[k] = compare_res + sort_by_error[k] = error + sort_by_error = sorted( + sort_by_error.items(), key=lambda x: abs(x[1]), reverse=True) + miss_models_sort = dict() + miss_models_sort['total error models'] = len(sort_by_error) + for k_v in sort_by_error: + index = k_v[0] + miss_models_sort[index] = miss_models[index] + save_path = os.path.join(os.path.dirname(file_path), 'summary_error.yml') + mmengine.fileio.dump(miss_models_sort, save_path, sort_keys=False) + print(f'Summary analysis result saved in {save_path}') + + +if __name__ == '__main__': + main() diff --git a/.dev_scripts/benchmark_test.py b/.dev_scripts/benchmark_test.py index a9a208dbb..1af3e4fa4 100644 --- a/.dev_scripts/benchmark_test.py +++ b/.dev_scripts/benchmark_test.py @@ -24,9 +24,9 @@ def parse_args(): parser = argparse.ArgumentParser( description="Test all models' accuracy in model-index.yml") - parser.add_argument( - 'partition', type=str, help='Cluster partition to use.') parser.add_argument('checkpoint_root', help='Checkpoint file root path.') + parser.add_argument( + '--partition', type=str, help='Cluster partition to use.') parser.add_argument( '--job-name', type=str, @@ -148,6 +148,7 @@ def create_test_job_batch(commands, model_info, args, port): if exists: print(f'{checkpoint} already exists.') else: + print(f'start downloading {fname}') wget.download(model_info.weights, str(checkpoint)) print(f'\nSaved in {checkpoint}.') diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e00ed24c8..88928727e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,22 +25,12 @@ concurrency: jobs: test_linux: - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 strategy: matrix: python-version: [3.7] - torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] + torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] include: - - torch: 1.6.0 - torch_version: 1.6 - torchvision: 0.7.0 - - torch: 1.7.0 - torch_version: 1.7 - torchvision: 0.8.1 - - torch: 1.7.0 - torch_version: 1.7 - torchvision: 0.8.1 - python-version: 3.8 - torch: 1.8.0 torch_version: 1.8 torchvision: 0.9.0 @@ -103,11 +93,11 @@ jobs: pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - name: Install MMCls - run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + run: pip install 'mmcls>=1.0.0rc0' - name: Install MMDet - run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + run: pip install 
git+https://github.com/open-mmlab/mmdetection.git@main - name: Install MMSeg - run: pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + run: pip install git+https://github.com/open-mmlab/mmsegmentation.git@main - name: Install other dependencies run: pip install -r requirements.txt - name: Build and install @@ -119,8 +109,8 @@ jobs: coverage report -m # Upload coverage report for python3.8 && pytorch1.12.0 cpu - name: Upload coverage to Codecov - if: ${{matrix.torch == '1.12.0' && matrix.python-version == '3.8'}} - uses: codecov/codecov-action@v2 + if: ${{matrix.torch == '1.13.0' && matrix.python-version == '3.8'}} + uses: codecov/codecov-action@v3 with: file: ./coverage.xml flags: unittests diff --git a/README.md b/README.md index f8778f6b6..ad92732cf 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ [![PyPI](https://img.shields.io/pypi/v/mmrazor)](https://pypi.org/project/mmrazor) -[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmrazor.readthedocs.io/en/dev-1.x/) +[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmrazor.readthedocs.io/en/main/) [![badge](https://github.com/open-mmlab/mmrazor/workflows/build/badge.svg)](https://github.com/open-mmlab/mmrazor/actions) [![codecov](https://codecov.io/gh/open-mmlab/mmrazor/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmrazor) [![license](https://img.shields.io/github/license/open-mmlab/mmrazor.svg)](https://github.com/open-mmlab/mmrazor/blob/master/LICENSE) @@ -32,9 +32,9 @@ -[📘Documentation](https://mmrazor.readthedocs.io/en/dev-1.x/) | -[🛠️Installation](https://mmrazor.readthedocs.io/en/dev-1.x/get_started/installation.html) | -[👀Model Zoo](https://mmrazor.readthedocs.io/en/dev-1.x/get_started/model_zoo.html) | +[📘Documentation](https://mmrazor.readthedocs.io/en/main/) | +[🛠️Installation](https://mmrazor.readthedocs.io/en/main/get_started/installation.html) | +[👀Model Zoo](https://mmrazor.readthedocs.io/en/main/get_started/model_zoo.html) | [🤔Reporting Issues](https://github.com/open-mmlab/mmrazor/issues/new/choose) @@ -65,12 +65,12 @@ English | [简体中文](README_zh-CN.md) ## Introduction -MMRazor is a model compression toolkit for model slimming and AutoML, which includes 3 mainstream technologies: +MMRazor is a model compression toolkit for model slimming and AutoML, which includes 4 mainstream technologies: - Neural Architecture Search (NAS) - Pruning - Knowledge Distillation (KD) -- Quantization (come soon) +- Quantization It is a part of the [OpenMMLab](https://openmmlab.com/) project. @@ -88,22 +88,23 @@ Major features: With better modular design, developers can implement new model compression algorithms with only a few codes, or even by simply modifying config files. -Below is an overview of MMRazor's design and implementation, please refer to [tutorials](https://mmrazor.readthedocs.io/en/dev-1.x/get_started/overview.html) for more details. +About MMRazor's design and implementation, please refer to [tutorials](https://mmrazor.readthedocs.io/en/main/get_started/overview.html) for more details. -
+## Latest Updates -## What's new +**The default branch is now `main` and the code on the branch has been upgraded to v1.0.0. The old `master` branch code now exists on the 0.x branch** -MMRazor v1.0.0rc0 was released in 1/9/2022. +MMRazor v1.0.0 was released in 2023-4-24, Major updates from 1.0.0rc2 include: -Please refer to [changelog.md](/docs/en/notes/changelog.md) for more details and other release history. +1. MMRazor quantization is released. +2. Add a new pruning algorithm named GroupFisher. +3. Support distilling rtmdet with MMRazor. + +To know more about the updates in MMRazor 1.0, please refer to [Changelog](https://mmrazor.readthedocs.io/en/main/notes/changelog.html) for more details! ## Benchmark and model zoo -Results and models are available in the [model zoo](/docs/en/get_started/model_zoo.md). +Results and models are available in the [model zoo](https://mmrazor.readthedocs.io/en/main/get_started/model_zoo.html). Supported algorithms: @@ -123,6 +124,12 @@ Supported algorithms: - [x] [AutoSlim(NeurIPS'2019)](/configs/pruning/mmcls/autoslim) +- [x] [L1-norm](/configs/pruning/mmcls/l1-norm) + +- [x] [Group Fisher](/configs/pruning/base/group_fisher) + +- [x] [DMCP](/configs/pruning/mmcls/dmcp) +
@@ -158,20 +165,31 @@ Supported algorithms:
+
+
+Quantization
+
+- [x] [PTQ](/configs/quantization/ptq/base)
+
+- [x] [QAT](/configs/quantization/qat/base)
+
+- [x] [LSQ](/configs/quantization/qat/lsq)
+
+ ## Installation MMRazor depends on [PyTorch](https://pytorch.org/), [MMCV](https://github.com/open-mmlab/mmcv) and [MMEngine](https://github.com/open-mmlab/mmengine). -Please refer to [installation.md](/docs/en/get_started/installation.md) for more detailed instruction. +Please refer to [installation.md](https://mmrazor.readthedocs.io/en/main/get_started/installation.html) for more detailed instruction. ## Getting Started -Please refer to [user guides](https://mmrazor.readthedocs.io/en/dev-1.x/user_guides/index.html) for the basic usage of MMRazor. There are also [advanced guides](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/index.html): +Please refer to [user guides](https://mmrazor.readthedocs.io/en/main/user_guides/index.html) for the basic usage of MMRazor. There are also [advanced guides](https://mmrazor.readthedocs.io/en/main/advanced_guides/index.html): ## Contributing We appreciate all contributions to improve MMRazor. -Please refer to [CONTRUBUTING.md](/docs/en/notes/contribution_guide.md) for the contributing guideline. +Please refer to [CONTRUBUTING.md](https://mmrazor.readthedocs.io/en/main/notes/contribution_guide.html) for the contributing guideline. ## Acknowledgement diff --git a/README_zh-CN.md b/README_zh-CN.md index 169181941..fc59086fb 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -21,7 +21,7 @@ [![PyPI](https://img.shields.io/pypi/v/mmrazor)](https://pypi.org/project/mmrazor) -[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmrazor.readthedocs.io/en/latest/) +[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmrazor.readthedocs.io/en/main/) [![badge](https://github.com/open-mmlab/mmrazor/workflows/build/badge.svg)](https://github.com/open-mmlab/mmrazor/actions) [![codecov](https://codecov.io/gh/open-mmlab/mmrazor/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmrazor) [![license](https://img.shields.io/github/license/open-mmlab/mmrazor.svg)](https://github.com/open-mmlab/mmrazor/blob/master/LICENSE) @@ -32,9 +32,9 @@ -[📘使用文档](https://mmrazor.readthedocs.io/) | -[🛠️安装教程](https://mmrazor.readthedocs.io/en/latest/get_started.html) | -[👀模型库](https://mmrazor.readthedocs.io/en/latest/model_zoo.html) | +[📘使用文档](https://mmrazor.readthedocs.io/en/main/) | +[🛠️安装教程](https://mmrazor.readthedocs.io/en/main/get_started/installation.html) | +[👀👀模型库](https://mmrazor.readthedocs.io/en/main/get_started/model_zoo.html) | [🤔报告问题](https://github.com/open-mmlab/mmrazor/issues/new/choose) @@ -49,12 +49,12 @@ ## 说明 -MMRazor是一个可用于模型瘦身和AutoML的模型压缩工具箱,包含了3种主流的技术: +MMRazor是一个可用于模型瘦身和AutoML的模型压缩工具箱,包含了4种主流的技术: - 网络结构搜索 (NAS) - 模型剪枝 - 知识蒸馏 (KD) -- 量化 (下个版本发布) +- 量化 MMRazor是[OpenMMLab](https://openmmlab.com/)项目的一部分。 @@ -72,12 +72,11 @@ MMRazor是[OpenMMLab](https://openmmlab.com/)项目的一部分。 得益于更好的模块化设计,开发者仅用修改少量代码,甚至只用修改配置文件即可实现新的轻量化算法。 -下面是MMRazor设计和实现的概括图, 如果想了解更多的细节,请参考 [tutorials](/docs/en/tutorials/Tutorial_1_overview.md)。 +关于MMRazor设计和实现的概括图, 如果想了解更多的细节,请参考 [tutorials](/docs/en/tutorials/Tutorial_1_overview.md)。 -
+## 近期更新 + +**默认分支目前为 main,且分支上的代码已经切换到 v1.0.0 版本。旧版 master 分支的代码现存在 0.x 分支上** ## 更新日志 @@ -85,7 +84,7 @@ MMRazor v0.3.1 版本已经在 2022.5.4 发布。 ## 基准测试和模型库 -测试结果可以在 [模型库](docs/en/model_zoo.md) 中找到. +测试结果可以在 [模型库](https://mmrazor.readthedocs.io/en/main/get_started/model_zoo.html) 中找到. 已经支持的算法: @@ -99,37 +98,69 @@ Neural Architecture Search Pruning -- [x] [AutoSlim(NeurIPS'2019)](/configs/pruning/autoslim) +- [x] [AutoSlim(NeurIPS'2019)](/configs/pruning/mmcls/autoslim) + +- [x] [L1-norm](/configs/pruning/mmcls/l1-norm) + +- [x] [Group Fisher](/configs/pruning/base/group_fisher) + +- [x] [DMCP](/configs/pruning/mmcls/dmcp) Knowledge Distillation -- [x] [CWD(ICCV'2021)](/configs/distill/cwd) +- [x] [CWD(ICCV'2021)](/configs/distill/mmdet/cwd) + +- [x] [WSLD(ICLR'2021)](/configs/distill/mmcls/wsld) + +- [x] [ABLoss](/configs/distill/mmcls/abloss) + +- [x] [BYOT](/configs/distill/mmcls/byot) + +- [x] [DAFL](/configs/distill/mmcls/dafl) + +- [x] [DFAD](/configs/distill/mmcls/dfad) + +- [x] [DKD](/configs/distill/mmcls/dkd) -- [x] [WSLD(ICLR'2021)](/configs/distill/wsld) +- [x] [Factor Transfer](/configs/distill/mmcls/factor_transfer) + +- [x] [FitNets](/configs/distill/mmcls/fitnets) + +- [x] [KD](/configs/distill/mmcls/kd) + +- [x] [OFD](/configs/distill/mmcls/ofd) + +- [x] [RKD](/configs/distill/mmcls/rkd) + +- [x] [ZSKT](/configs/distill/mmcls/zskt) + +- [x] [FBKD](/configs/distill/mmdet/fbkd) + +
+
+Quantization
+
+- [x] [PTQ](/configs/quantization/ptq/base)
+
+- [x] [QAT](/configs/quantization/qat/base)
+
+- [x] [LSQ](/configs/quantization/qat/lsq)
+
## 安装 MMRazor 依赖 [PyTorch](https://pytorch.org/) 和 [MMCV](https://github.com/open-mmlab/mmcv)。 -请参考[get_started.md](/docs/en/get_started.md)获取更详细的安装指南。 +请参考[安装教程](https://mmrazor.readthedocs.io/en/main/get_started/installation.html)获取更详细的安装指南。 ## 快速入门 -请参考 [get_started.md](/docs/en/get_started.md) 学习 MMRazor 的基本使用。 我们也提供了一些进阶教程: - -- [overview](/docs/en/tutorials/Tutorial_1_overview.md) -- [learn about configs](/docs/en/tutorials/Tutorial_2_learn_about_configs.md) -- [customize architectures](/docs/en/tutorials/Tutorial_3_customize_architectures.md) -- [customize nas algorithms](/docs/en/tutorials/Tutorial_4_customize_nas_algorithms.md) -- [customize pruning algorithms](/docs/en/tutorials/Tutorial_5_customize_pruning_algorithms.md) -- [customize kd algorithms](/docs/en/tutorials/Tutorial_6_customize_kd_algorithms.md) -- [customize mixed algorithms with our algorithm_components](/docs/en/tutorials/Tutorial_7_customize_mixed_algorithms_with_out_algorithms_components.md) -- [apply existing algorithms to other existing tasks](/docs/en/tutorials/Tutorial_8_apply_existing_algorithms_to_new_tasks.md) +请参考 [用户指引](https://mmrazor.readthedocs.io/en/main/user_guides/index.html) 学习 MMRazor 的基本使用。 我们也提供了一些[进阶教程](https://mmrazor.readthedocs.io/en/main/advanced_guides/index.html): ## 贡献指南 我们感谢所有的贡献者为改进和提升 MMRazor 所作出的努力。 -请参考[贡献指南](/.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。 +请参考[贡献指南](https://mmrazor.readthedocs.io/en/main/notes/contribution_guide.html)来了解参与项目贡献的相关指引。 ## 致谢 diff --git a/configs/distill/mmcls/deit/README.md b/configs/distill/mmcls/deit/README.md index 1057c81c2..4ccfa8cc9 100644 --- a/configs/distill/mmcls/deit/README.md +++ b/configs/distill/mmcls/deit/README.md @@ -17,9 +17,9 @@ Recently, neural networks purely based on attention were shown to address image ### Classification -| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download | -| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) | +| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download | +| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | 
[model](https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json) | ```{warning} Before training, please first install `timm`. diff --git a/configs/distill/mmcls/deit/metafile.yml b/configs/distill/mmcls/deit/metafile.yml index d46a91b64..6fe41c3a9 100644 --- a/configs/distill/mmcls/deit/metafile.yml +++ b/configs/distill/mmcls/deit/metafile.yml @@ -30,5 +30,5 @@ Models: Metrics: Top 1 Accuracy: 83.24 Top 5 Accuracy: 96.33 - Weights: https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz + Weights: https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py diff --git a/configs/distill/mmcls/ofd/README.md b/configs/distill/mmcls/ofd/README.md index 74a931b0d..eb789e840 100644 --- a/configs/distill/mmcls/ofd/README.md +++ b/configs/distill/mmcls/ofd/README.md @@ -22,16 +22,16 @@ We investigate the design aspects of feature distillation methods achieving netw #### Vanilla -| Dataset | Model | Top-1 (%) | Top-5 (%) | Download | -| ------- | ----------------------------------------------------------------------- | --------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| CIFAR10 | [WRN16-2](../../../vanilla/mmcls/wide-resnet/wrn16-w2_b16x8_cifar10.py) | 93.43 | 99.75 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.json) | -| CIFAR10 | [WRN28-4](../../../vanilla/mmcls/wide-resnet/wrn28-w4_b16x8_cifar10.py) | 95.49 | 99.81 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.json) | +| Dataset | Model | Top-1 (%) | Download | +| ------- | ----------------------------------------------------------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| CIFAR10 | [WRN16-2](../../../vanilla/mmcls/wide-resnet/wrn16-w2_b16x8_cifar10.py) | 93.43 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.json) | +| CIFAR10 | [WRN28-4](../../../vanilla/mmcls/wide-resnet/wrn28-w4_b16x8_cifar10.py) | 95.49 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.json) | #### Distillation -| Dataset | Model | Flops(M) | Teacher | Top-1 (%) | Top-5 (%) | Configs | 
Download | -| ------- | ------- | -------- | ------- | --------- | --------- | ----------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| CIFAR10 | WRN16-2 | 101 | WRN28-4 | 95.23 | 99.79 | [config](./ofd_backbone_resnet50_resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20220831_220553-f5d12e61.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20220831_220553-f5d12e61.json) | +| Dataset | Model | Flops(M) | Teacher | Top-1 (%) | Configs | Download | +| ------- | ------- | -------- | ------- | --------- | ----------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| CIFAR10 | WRN16-2 | 101 | WRN28-4 | 94.21 | [config](./ofd_backbone_resnet50_resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20230417_192216-ace2908f.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20230417_192216-ace2908f.log) | ## Getting Started diff --git a/configs/distill/mmcls/ofd/metafile.yml b/configs/distill/mmcls/ofd/metafile.yml index 21716fd5c..cb176b1c3 100644 --- a/configs/distill/mmcls/ofd/metafile.yml +++ b/configs/distill/mmcls/ofd/metafile.yml @@ -33,6 +33,6 @@ Models: - Task: Image Classification Dataset: CIFAR-10 Metrics: - Top 1 Accuracy: 95.4400 + Top 1 Accuracy: 94.21 Config: configs/distill/mmcls/ofd/ofd_backbone_resnet50_resnet18_8xb16_cifar10.py - Weights: https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20220831_220553-f5d12e61.pth + Weights: https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20230417_192216-ace2908f.pth diff --git a/configs/nas/mmcls/autoslim/README.md b/configs/nas/mmcls/autoslim/README.md index e7292fae3..37a651ddb 100644 --- a/configs/nas/mmcls/autoslim/README.md +++ b/configs/nas/mmcls/autoslim/README.md @@ -60,11 +60,11 @@ CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_test.sh \ ### Subnet retrain -| Supernet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Subnet | Remark | -| :----------------- | :-------: | -------: | :-------: | :-------: | :---------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: | -| MobileNet v2(x1.5) | 6.5 | 0.53 | 74.23 | 91.74 | 
[config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd_channel_cfg.yaml) | official channel cfg | -| MobileNet v2(x1.5) | 5.77 | 0.32 | 72.73 | 90.83 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c_channel_cfg.yaml) | official channel cfg | -| MobileNet v2(x1.5) | 4.13 | 0.22 | 71.39 | 90.08 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b_channel_cfg.yaml) | official channel cfg | +| Supernet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Subnet | Remark | +| :----------------- | :-------: | -------: | :-------: | :-------: | :---------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: | +| MobileNet v2(x1.5) | 6.5 | 0.53 | 74.23 | 91.74 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-530M_acc-74.23_20220715-aa8754fe.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd_channel_cfg.yaml) | official channel cfg | +| MobileNet v2(x1.5) | 5.77 | 0.32 | 72.73 | 90.83 | 
[config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-320M_acc-72.73_20220715-9aa8f8ae.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c_channel_cfg.yaml) | official channel cfg | +| MobileNet v2(x1.5) | 4.13 | 0.22 | 71.39 | 90.08 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-220M_acc-71.4_20220715-9c288f3b.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b_channel_cfg.yaml) | official channel cfg | Note that we ran the official code and the Top-1 Acc of the models with official channel cfg are 73.8%, 72.5% and 71.1%. And there are 3 differences between our diff --git a/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py b/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py index ab9ee6180..c05a3b435 100644 --- a/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py +++ b/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py @@ -37,7 +37,7 @@ init_cfg=dict( type='Pretrained', checkpoint= # noqa: E251 - 'https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600_latest.pth', # noqa: E501 + 'https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600_latest.pth', # noqa: E501 prefix='architecture.')) model_wrapper_cfg = None diff --git a/configs/nas/mmcls/darts/metafile.yml b/configs/nas/mmcls/darts/metafile.yml index b92f28dd7..b262a6960 100644 --- a/configs/nas/mmcls/darts/metafile.yml +++ b/configs/nas/mmcls/darts/metafile.yml @@ -25,4 +25,4 @@ Models: Top 1 Accuracy: 97.32 Top 5 Accuracy: 99.94 Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py - Weights: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-23ca1e10.pth + Weights: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_latest.pth diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py index 0da0388f1..e10daec7d 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py @@ -9,7 +9,7 @@ init_cfg=dict( type='Pretrained', checkpoint= # noqa: E251 - 'detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth', # noqa: E501 + 'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth', # noqa: E501 prefix='architecture.')) find_unused_parameters = False diff --git a/configs/pruning/mmcls/dmcp/README.md 
b/configs/pruning/mmcls/dmcp/README.md index d2c7aa7a4..3a96b61c3 100644 --- a/configs/pruning/mmcls/dmcp/README.md +++ b/configs/pruning/mmcls/dmcp/README.md @@ -26,25 +26,28 @@ GPUS=32 sh tools/slurm_train.sh $PARTITION $JOB_NAME \ --work-dir $WORK_DIR ``` -## Results and models + + + + + -**Note** + ## Citation diff --git a/configs/pruning/mmcls/dmcp/metafile.yml b/configs/pruning/mmcls/dmcp/metafile.yml index 4c1268093..131f5c289 100644 --- a/configs/pruning/mmcls/dmcp/metafile.yml +++ b/configs/pruning/mmcls/dmcp/metafile.yml @@ -1,19 +1,19 @@ -Models: - - Name: dmcp_resnet50_subnet_32xb64 - In Collection: DMCP - Config: configs/pruning/mmcls/dmcp/dmcp_resnet50_subnet_32xb64.py - Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/2G/DMCP_R50_2G.pth - Results: - - Task: Image Classification - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 76.11 - - Name: dmcp_mbv2_subnet_32xb64 - In Collection: DMCP - Config: configs/pruning/mmcls/dmcp/dmcp_mbv2_subnet_32xb64.py - Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/mobilenetv2/100M/DMCP_MBV2_100M.pth - Results: - - Task: Image Classification - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 67.22 +# Models: + # - Name: dmcp_resnet50_subnet_32xb64 + # In Collection: DMCP + # Config: configs/pruning/mmcls/dmcp/dmcp_resnet50_subnet_32xb64.py + # Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/2G/DMCP_R50_2G.pth + # Results: + # - Task: Image Classification + # Dataset: ImageNet-1k + # Metrics: + # Top 1 Accuracy: 76.11 + # - Name: dmcp_mbv2_subnet_32xb64 + # In Collection: DMCP + # Config: configs/pruning/mmcls/dmcp/dmcp_mbv2_subnet_32xb64.py + # Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/mobilenetv2/100M/DMCP_MBV2_100M.pth + # Results: + # - Task: Image Classification + # Dataset: ImageNet-1k + # Metrics: + # Top 1 Accuracy: 67.22 diff --git a/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py b/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py new file mode 100644 index 000000000..d1fc673c5 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py @@ -0,0 +1,30 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=None, + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'output': { + 0: 'batch' + } + }), + backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]), + codebase_config=dict(type='mmcls', task='Classification'), + function_record_to_pop=[ + 'mmcls.models.classifiers.ImageClassifier.forward', + 'mmcls.models.classifiers.BaseClassifier.forward' + ], +) diff --git a/configs/quantization/deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py b/configs/quantization/deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py new file mode 100644 index 000000000..a562c370b --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py @@ -0,0 +1,39 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=[224, 224], + 
optimize=True, + dynamic_axes=dict( + input=dict({ + 0: 'batch', + 2: 'height', + 3: 'width' + }), + output=dict({0: 'batch'}))), + codebase_config=dict(type='mmcls', task='Classification'), + backend_config=dict( + type='tensorrt', + common_config=dict( + fp16_mode=False, + max_workspace_size=1073741824, + int8_mode=True, + explicit_quant_mode=True), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 224, 224], + opt_shape=[4, 3, 224, 224], + max_shape=[8, 3, 224, 224]))) + ]), + function_record_to_pop=[ + 'mmcls.models.classifiers.ImageClassifier.forward', + 'mmcls.models.classifiers.BaseClassifier.forward', 'torch.cat' + ], +) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py new file mode 100644 index 000000000..c76898d0b --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py @@ -0,0 +1,47 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_shape=None, + input_names=['input'], + output_names=['dets', 'labels'], + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'dets': { + 0: 'batch', + 1: 'num_dets', + }, + 'labels': { + 0: 'batch', + 1: 'num_dets', + }, + }), + backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 800, 1344]))]), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + model_type='end2end', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, # for YOLOv3 + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )), + function_record_to_pop=[ + 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', + 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', + 'mmdet.models.detectors.single_stage_instance_seg.SingleStageInstanceSegmentor.forward' # noqa: E501 + ]) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py new file mode 100644 index 000000000..1061d6bd6 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py @@ -0,0 +1,58 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['dets', 'labels'], + input_shape=None, + optimize=True, + dynamic_axes=dict( + input=dict({ + 0: 'batch', + 2: 'height', + 3: 'width' + }), + dets=dict({ + 0: 'batch', + 1: 'num_dets' + }), + labels=dict({ + 0: 'batch', + 1: 'num_dets' + }))), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + model_type='end2end', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1)), + backend_config=dict( + type='tensorrt', + common_config=dict( + fp16_mode=False, + max_workspace_size=1073741824, + int8_mode=True, + explicit_quant_mode=True), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 320, 320], + opt_shape=[1, 3, 800, 1344], + 
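+                    # min/opt/max shapes form the TensorRT dynamic-shape
+                    # optimization profile: the int8 engine accepts any input
+                    # between min_shape and max_shape, with kernels tuned for
+                    # opt_shape.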
max_shape=[1, 3, 1344, 1344]))) + ]), + function_record_to_pop=[ + 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', + 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', + 'mmdet.models.detectors.single_stage_instance_seg.SingleStageInstanceSegmentor.forward', # noqa: E501 + 'torch.cat' + ]) diff --git a/configs/quantization/ptq/base/README.md b/configs/quantization/ptq/base/README.md new file mode 100644 index 000000000..1a9f53519 --- /dev/null +++ b/configs/quantization/ptq/base/README.md @@ -0,0 +1,59 @@ +# Post-Training Quantization (PTQ) + +> [A White Paper on Neural Network Quantization](https://arxiv.org/abs/2106.08295) + + + +## Abstract + +While neural networks have advanced the frontiers in many applications, they often come at a high computational cost. Reducing the power and latency of neural network inference is key if we want to integrate modern networks into edge devices with strict power and compute requirements. Neural network quantization is one of the most effective ways of achieving these savings but the additional noise it induces can lead to accuracy degradation. In this white paper, we introduce state-of-the-art algorithms for mitigating the impact of quantization noise on the network's performance while maintaining low-bit weights and activations. We start with a hardware motivated introduction to quantization and then consider two main classes of algorithms: Post-Training Quantization (PTQ) and Quantization-Aware-Training (QAT). PTQ requires no re-training or labelled data and is thus a lightweight push-button approach to quantization. In most cases, PTQ is sufficient for achieving 8-bit quantization with close to floating-point accuracy. QAT requires fine-tuning and access to labeled training data but enables lower bit quantization with competitive results. For both solutions, we provide tested pipelines based on existing literature and extensive experimentation that lead to state-of-the-art performance for common deep learning models and tasks. 
+ +## Results and models + +### Classification + +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Top 1 Acc(deployed) | Config | Download | +| ------------ | -------- | -------- | --------------- | --------------- | ------------------- | ----------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| resnet18 | ImageNet | openvino | 69.90 | 69.742 | 69.74 | [config](./ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.log) | +| resnet50 | ImageNet | openvino | 76.55 | 76.374 | 76.378 | [config](./ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.log) | +| mobilenet_v2 | ImageNet | openvino | 71.86 | 70.224 | 70.292 | [config](./ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.log) | +| resnet18 | ImageNet | tensorrt | 69.90 | 69.762 | 69.85 | [config](./ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32_20230331_144323-640b272e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32_20230331_144323-640b272e.log) | +| resnet50 | ImageNet | tensorrt | 76.55 | 76.372 | 76.374 | [config](./ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32_20230331_145011-d2da300f.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32_20230331_145011-d2da300f.log) | +| mobilenet_v2 | ImageNet | tensorrt | 71.86 | 70.324 | 70.548 | [config](./ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32_20230331_153131-335988e4.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32_20230331_153131-335988e4.log) | + +### Detection + +| Model | Dataset | Backend | box AP(fp32) | box AP(int8) | box AP(deployed) | Config | Download | +| -------------- | ------- | -------- | ------------ | ------------ | ---------------- | -------------------------------------------------------------- | ------------------------ | +| retina_r50_fpn | COCO | openvino | 36.5 | 36.3 | 36.3 | 
[config](./ptq_openvino_retina_r50_1x_coco_calib32xb32.py) | [model](<>) \| [log](<>) | +| yolox_s | COCO | openvino | 40.5 | 38.5 | 38.5 | [config](./ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py) | [model](<>) \| [log](<>) | +| retina_r50_fpn | COCO | tensorrt | 36.5 | 36.2 | 36.3 | [config](./ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py) | [model](<>) \| [log](<>) | +| yolox_s | COCO | tensorrt | 40.5 | 38.8 | 39.3 | [config](./ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py) | [model](<>) \| [log](<>) | + +## Citation + +```latex + @misc{Nagel_Fournarakis_Amjad_Bondarenko_Baalen_Blankevoort_2021, + title={A White Paper on Neural Network Quantization}, + journal={Cornell University - arXiv}, + author={Nagel, Markus and Fournarakis, Marios and Amjad, RanaAli and Bondarenko, Yelysei and Baalen, Martvan and Blankevoort, Tijmen}, + year={2021}, + month={Jun} + } +``` + +## Getting Started + +**PTQ for pretrain model** + +``` +python tools/ptq.py ${CONFIG} +``` + +**Test for quantized model** + +``` +python tools/test.py ${CONFIG} ${CKPT} +``` + +For more details, please refer to [Quantization User Guide](https://mmrazor.readthedocs.io/en/main/user_guides/quantization_user_guide.html) diff --git a/configs/quantization/ptq/base/metafile.yml b/configs/quantization/ptq/base/metafile.yml new file mode 100644 index 000000000..1ebceab4b --- /dev/null +++ b/configs/quantization/ptq/base/metafile.yml @@ -0,0 +1,164 @@ +Collections: + - Name: PTQ + README: configs/quantization/ptq/base/README.md +Models: + - Name: ptq_openvino_mbv2_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth + Metrics: + Top 1 Accuracy: 71.86 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.224 + Config: configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.pth + - Name: ptq_openvino_resnet18_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.742 + Config: configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.pth + - Name: ptq_openvino_resnet50_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth + Metrics: + Top 1 Accuracy: 76.55 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 76.374 + Config: configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.pth + - Name: 
ptq_openvino_retina_r50_1x_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmdet::retinanet/retinanet_r50_fpn_1x_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth + Metrics: + box AP: 36.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 36.3 + Config: configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_retina_r50_1x_coco_calib32xb32_20230330_172645-80eea5b6.pth + - Name: ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmdet::yolox/yolox_s_8xb8-300e_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth + Metrics: + box AP: 40.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 38.5 + Config: configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32_20230330_175747-f1a0a2f4.pth + - Name: ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth + Metrics: + Top 1 Accuracy: 71.86 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.324 + Config: configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32_20230331_153131-335988e4.pth + - Name: ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.762 + Config: configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32_20230331_144323-640b272e.pth + - Name: ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmcls::resnet/resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth + Metrics: + Top 1 Accuracy: 76.55 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 76.372 + Config: configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32_20230331_145011-d2da300f.pth + - Name: ptq_tensorrt_retina_r50_1x_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmdet::retinanet/retinanet_r50_fpn_1x_coco.py + Weights: 
https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth + Metrics: + box AP: 36.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 36.2 + Config: configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_retina_r50_1x_coco_calib32xb32_20230330_205741-4c5c10c4.pth + - Name: ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmdet::yolox/yolox_s_8xb8-300e_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth + Metrics: + box AP: 40.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 38.8 + Config: configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32_20230331_155139-f2021e57.pth diff --git a/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..efa2a75dd --- /dev/null +++ b/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -0,0 +1,54 @@ +_base_ = [ + 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py', + '../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..b548b15f5 --- /dev/null +++ b/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -0,0 +1,51 @@ +_base_ = [ + 'mmcls::resnet/resnet18_8xb32_in1k.py', + 
'../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..14802a442 --- /dev/null +++ b/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -0,0 +1,50 @@ +_base_ = [ + 'mmcls::resnet/resnet50_8xb32_in1k.py', + '../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py 
b/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py new file mode 100644 index 000000000..e35e6270e --- /dev/null +++ b/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -0,0 +1,52 @@ +_base_ = [ + 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', + '../../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.anchor_head.AnchorHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py new file mode 100644 index 000000000..bab9ed021 --- /dev/null +++ b/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py @@ -0,0 +1,57 @@ +_base_ = [ + 'mmdet::yolox/yolox_s_8xb8-300e_coco.py', + '../../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='mmdet.BatchSyncRandomResize', + random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + 
global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +custom_hooks = [] diff --git a/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..68b6d4f97 --- /dev/null +++ b/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py @@ -0,0 +1,54 @@ +_base_ = [ + 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py', + '../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..41d08812c --- /dev/null +++ b/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py @@ -0,0 +1,51 @@ +_base_ = [ + 'mmcls::resnet/resnet18_8xb32_in1k.py', + '../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..e4fa955dc --- /dev/null +++ b/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py @@ -0,0 +1,51 @@ +_base_ = [ + 'mmcls::resnet/resnet50_8xb32_in1k.py', + '../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py new file mode 100644 index 000000000..4ca81a920 --- /dev/null +++ b/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py @@ -0,0 +1,53 @@ +_base_ = [ + 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', + '../../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +float_checkpoint = 
'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.anchor_head.AnchorHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py new file mode 100644 index 000000000..51e4f8f11 --- /dev/null +++ b/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py @@ -0,0 +1,58 @@ +_base_ = [ + 'mmdet::yolox/yolox_s_8xb8-300e_coco.py', + '../../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 +] + +_base_.val_dataloader.batch_size = 32 + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.val_dataloader, + calibrate_steps=32, +) + +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='mmdet.BatchSyncRandomResize', + random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +custom_hooks = [] diff --git a/configs/quantization/qat/base/README.md 
b/configs/quantization/qat/base/README.md new file mode 100644 index 000000000..ec4541eb4 --- /dev/null +++ b/configs/quantization/qat/base/README.md @@ -0,0 +1,45 @@ +# Quantization-Aware-Training (QAT) + +> [A White Paper on Neural Network Quantization](https://arxiv.org/abs/2106.08295) + + + +## Abstract + +While neural networks have advanced the frontiers in many applications, they often come at a high computational cost. Reducing the power and latency of neural network inference is key if we want to integrate modern networks into edge devices with strict power and compute requirements. Neural network quantization is one of the most effective ways of achieving these savings but the additional noise it induces can lead to accuracy degradation. In this white paper, we introduce state-of-the-art algorithms for mitigating the impact of quantization noise on the network's performance while maintaining low-bit weights and activations. We start with a hardware motivated introduction to quantization and then consider two main classes of algorithms: Post-Training Quantization (PTQ) and Quantization-Aware-Training (QAT). PTQ requires no re-training or labelled data and is thus a lightweight push-button approach to quantization. In most cases, PTQ is sufficient for achieving 8-bit quantization with close to floating-point accuracy. QAT requires fine-tuning and access to labeled training data but enables lower bit quantization with competitive results. For both solutions, we provide tested pipelines based on existing literature and extensive experimentation that lead to state-of-the-art performance for common deep learning models and tasks. + +## Results and models + +### Classification + +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Config | Download | +| -------- | -------- | -------- | --------------- | --------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| resnet18 | ImageNet | openvino | 69.90 | 69.98 | [config](./qat_openvino_resnet18_10e_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.log) | + +## Citation + +```latex + @misc{Nagel_Fournarakis_Amjad_Bondarenko_Baalen_Blankevoort_2021, + title={A White Paper on Neural Network Quantization}, + journal={Cornell University - arXiv}, + author={Nagel, Markus and Fournarakis, Marios and Amjad, RanaAli and Bondarenko, Yelysei and Baalen, Martvan and Blankevoort, Tijmen}, + year={2021}, + month={Jun} + } +``` + +## Getting Started + +**QAT for pretrain model** + +``` +python tools/train.py ${CONFIG} +``` + +**Test for quantized model** + +``` +python tools/test.py ${CONFIG} ${CKPT} +``` + +For more details, please refer to [Quantization User Guide](https://mmrazor.readthedocs.io/en/main/user_guides/quantization_user_guide.html) diff --git a/configs/quantization/qat/base/metafile.yml b/configs/quantization/qat/base/metafile.yml new file mode 100644 index 000000000..bd4015a50 --- /dev/null +++ b/configs/quantization/qat/base/metafile.yml @@ -0,0 +1,20 @@ +Collections: + - Name: QAT 
+ README: configs/quantization/qat/base/README.md +Models: + - Name: qat_openvino_resnet18_10e_8xb32_in1k.py + In Collection: QAT + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.98 + Config: configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth diff --git a/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py b/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py new file mode 100644 index 000000000..261af7abb --- /dev/null +++ b/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py @@ -0,0 +1,62 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=False) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + max_epochs=10, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. +default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/lsq/README.md b/configs/quantization/qat/lsq/README.md new file mode 100644 index 000000000..7babfa96e --- /dev/null +++ b/configs/quantization/qat/lsq/README.md @@ -0,0 +1,46 @@ +# Learned Step Size Quantization (LSQ) + +> [Learned Step Size Quantization](https://arxiv.org/abs/1902.08153) + + + +## Abstract + +Deep networks run with low precision operations at inference time offer power and space advantages over high precision alternatives, but need to overcome the challenge of maintaining high accuracy as precision decreases. 
Here, we present a method for training such networks, Learned Step Size Quantization, that achieves the highest accuracy to date on the ImageNet dataset when using models, from a variety of architectures, with weights and activations quantized to 2-, 3- or 4-bits of precision, and that can train 3-bit models that reach full precision baseline accuracy. Our approach builds upon existing methods for learning weights in quantized networks by improving how the quantizer itself is configured. Specifically, we introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer's quantizer step size, such that it can be learned in conjunction with other network parameters. This approach works using different levels of precision as needed for a given system and requires only a simple modification of existing training code. + +## Results and models + +### Classification + +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Max Epochs | Config | Download | +| -------- | -------- | -------- | --------------- | --------------- | ---------- | ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| resnet18 | ImageNet | openvino | 69.90 | 69.418 | 10 | [config](./lsq_openvino_resnet18_8xb32_10e_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.log) | +| resnet18 | ImageNet | openvino | 69.90 | 69.992 | 100 | [config](./lsq_openvino_resnet18_8xb32_100e_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.log) | + +## Citation + +```latex + @misc{Esser_McKinstry_Bablani_Appuswamy_Modha_2019, + title={Learned Step Size Quantization}, + journal={arXiv: Learning}, + author={Esser, StevenK. and McKinstry, JeffreyL. 
and Bablani, Deepika and Appuswamy, Rathinakumar and Modha, DharmendraS.}, + year={2019}, + month={Feb} + } +``` + +## Getting Started + +**QAT for pretrain model** + +``` +python tools/train.py ${CONFIG} +``` + +**Test for quantized model** + +``` +python tools/test.py ${CONFIG} ${CKPT} +``` + +For more details, please refer to [Quantization User Guide](https://mmrazor.readthedocs.io/en/main/user_guides/quantization_user_guide.html) diff --git a/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_100e_in1k.py b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_100e_in1k.py new file mode 100644 index 000000000..00e424141 --- /dev/null +++ b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_100e_in1k.py @@ -0,0 +1,68 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.LSQEpochBasedLoop', + max_epochs=100, + val_interval=1, + freeze_bn_begin=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. 
+default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_10e_in1k.py b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_10e_in1k.py new file mode 100644 index 000000000..f931ddaf5 --- /dev/null +++ b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_10e_in1k.py @@ -0,0 +1,63 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.LSQEpochBasedLoop', + max_epochs=10, + val_interval=1, + freeze_bn_begin=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. 
+default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/lsq/metafile.yml b/configs/quantization/qat/lsq/metafile.yml new file mode 100644 index 000000000..89308d333 --- /dev/null +++ b/configs/quantization/qat/lsq/metafile.yml @@ -0,0 +1,36 @@ +Collections: + - Name: LSQ + README: configs/quantization/qat/lsq/README.md +Models: + - Name: lsq_openvino_resnet18_8xb32_10e_in1k.py + In Collection: LSQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.418 + Config: configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_10e_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.pth + - Name: lsq_openvino_resnet18_8xb32_100e_in1k.py + In Collection: LSQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.992 + Config: configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_100e_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.pth diff --git a/docs/en/advanced_guides/algorithm.md b/docs/en/advanced_guides/algorithm.md index ea2670abc..ae632db6e 100644 --- a/docs/en/advanced_guides/algorithm.md +++ b/docs/en/advanced_guides/algorithm.md @@ -108,7 +108,7 @@ architecture = _base_.model - Use your customized model as below, which is an example of defining a VGG model as our architecture. ```{note} -How to customize architectures can refer to our tutorial: [Customize Architectures](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_architectures.html). +How to customize architectures can refer to our tutorial: [Customize Architectures](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_architectures.html). ``` ```Python @@ -262,12 +262,12 @@ Please refer to our tutorials about how to customize different algorithms for mo 1. NAS -[Customize NAS algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_nas_algorithms.html) +[Customize NAS algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_nas_algorithms.html) 2. Pruning -[Customize Pruning algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_pruning_algorithms.html) +[Customize Pruning algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_pruning_algorithms.html) 3. 
Distill -[Customize KD algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_kd_algorithms.html) +[Customize KD algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_kd_algorithms.html) diff --git a/docs/en/advanced_guides/apply_existing_algorithms_to_new_tasks.md b/docs/en/advanced_guides/apply_existing_algorithms_to_new_tasks.md index 2171044fe..119d7a6c3 100644 --- a/docs/en/advanced_guides/apply_existing_algorithms_to_new_tasks.md +++ b/docs/en/advanced_guides/apply_existing_algorithms_to_new_tasks.md @@ -1,6 +1,6 @@ # Apply existing algorithms to new tasks -Here we show how to apply existing algorithms to other tasks with an example of [SPOS ](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/nas/mmcls/spos)& [DetNAS](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/nas/mmdet/detnas). +Here we show how to apply existing algorithms to other tasks with an example of [SPOS ](https://github.com/open-mmlab/mmrazor/tree/main/configs/nas/mmcls/spos)& [DetNAS](https://github.com/open-mmlab/mmrazor/tree/main/configs/nas/mmdet/detnas). > SPOS: Single Path One-Shot NAS for classification > diff --git a/docs/en/advanced_guides/customize_mixed_algorithms.md b/docs/en/advanced_guides/customize_mixed_algorithms.md index c9e96dd86..17b928d12 100644 --- a/docs/en/advanced_guides/customize_mixed_algorithms.md +++ b/docs/en/advanced_guides/customize_mixed_algorithms.md @@ -1,11 +1,11 @@ # Customize mixed algorithms -Here we show how to customize mixed algorithms with our algorithm components. We take [AutoSlim ](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/pruning/mmcls/autoslim)as an example. +Here we show how to customize mixed algorithms with our algorithm components. We take [AutoSlim ](https://github.com/open-mmlab/mmrazor/tree/main/configs/pruning/mmcls/autoslim)as an example. ```{note} **Why is AutoSlim a mixed algorithm?** -In [AutoSlim](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/pruning/mmcls/autoslim), the sandwich rule and the inplace distillation will be introduced to enhance the training process, which is called as the slimmable training. The sandwich rule means that we train the model at smallest width, largest width and (n − 2) random widths, instead of n random widths. And the inplace distillation means that we use the predicted label of the model at the largest width as the training label for other widths, while for the largest width we use ground truth. So both the KD algorithm and the pruning algorithm are used in [AutoSlim](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/pruning/mmcls/autoslim). +In [AutoSlim](https://github.com/open-mmlab/mmrazor/tree/main/configs/pruning/mmcls/autoslim), the sandwich rule and the inplace distillation will be introduced to enhance the training process, which is called as the slimmable training. The sandwich rule means that we train the model at smallest width, largest width and (n − 2) random widths, instead of n random widths. And the inplace distillation means that we use the predicted label of the model at the largest width as the training label for other widths, while for the largest width we use ground truth. So both the KD algorithm and the pruning algorithm are used in [AutoSlim](https://github.com/open-mmlab/mmrazor/tree/main/configs/pruning/mmcls/autoslim). ``` 1. 
Register a new algorithm @@ -21,9 +21,9 @@ You can choose existing algorithm components in MMRazor, such as `OneShotChannel If these in MMRazor don't meet your needs, you can customize new algorithm components for your algorithm. Reference is as follows: -[Customize NAS algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_nas_algorithms.html) -[Customize Pruning algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_pruning_algorithms.html) -[Customize KD algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_kd_algorithms.html) +[Customize NAS algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_nas_algorithms.html) +[Customize Pruning algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_pruning_algorithms.html) +[Customize KD algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_kd_algorithms.html) ``` ```Python diff --git a/docs/en/advanced_guides/customize_quantization_algorithms.md b/docs/en/advanced_guides/customize_quantization_algorithms.md new file mode 100644 index 000000000..e1dd25eaf --- /dev/null +++ b/docs/en/advanced_guides/customize_quantization_algorithms.md @@ -0,0 +1,283 @@ +# Customize Quantization algorithms + +Here we show how to develop new QAT algorithms, taking LSQ on the OpenVINO backend as an example. + +This document is mainly aimed at QAT because the PTQ process is relatively fixed and the components we provide can meet most needs. We will first give an overview of the required development components and then introduce the specific implementation step by step. + +## Overall + +To better support the OpenMMLab environment, the MMRazor quantization pipeline pre-configures most of the code modules for users. You can configure all the components directly in the config file; an example of how to configure them can be found in this [config file](https://github.com/open-mmlab/mmrazor/blob/quantize/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py). + +```Python +global_qconfig = dict( + w_observer=dict(), + a_observer=dict(), + w_fake_quant=dict(), + a_fake_quant=dict(), + w_qscheme=dict(), + a_qscheme=dict(), +) +model = dict( + type='mmrazor.MMArchitectureQuant', + architecture=resnet, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict())) +train_cfg = dict(type='mmrazor.LSQEpochBasedLoop') +``` + +For `algorithm` and `tracer`, we recommend that you use the default `MMArchitectureQuant` and `CustomTracer` provided by us. These two modules are built specifically for the OpenMMLab environment; for the other modules, refer to the following steps and choose or develop new components according to your needs. + +To adapt to different backends, you need to select a different `quantizer`. + +To develop new quantization algorithms, you need to define a new `observer` and `fakequant`. + +If the existing `loop` does not meet your needs, you may need to make some changes to it based on your algorithm. + +## Detailed steps + +1. Select a quantization algorithm + +We recommend that you directly use `MMArchitectureQuant` in `mmrazor/models/algorithms/quantization/mm_architecture.py`. The class `MMArchitectureQuant` inherits from class `BaseAlgorithm`. + +This structure is built for models in OpenMMLab repositories.
If you have other requirements, you can also refer to this [document](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_architectures.html#develop-common-model-components) to design the overall framework. + +2. Select quantizer + +At present, the quantizers we support are `NativeQuantizer`, `OpenVINOQuantizer`, `TensorRTQuantizer` and `AcademicQuantizer` in `mmrazor/models/quantizers/`. `AcademicQuantizer` and `NativeQuantizer` inherit from class `BaseQuantizer` in `mmrazor/models/quantizers/base.py`: + +```Python +class BaseQuantizer(BaseModule): + def __init__(self, tracer): + super().__init__() + self.tracer = TASK_UTILS.build(tracer) + @abstractmethod + def prepare(self, model, graph_module): + """tmp.""" + pass + def swap_ff_with_fxff(self, model): + pass +``` + +`NativeQuantizer` is the quantizer we developed to adapt PyTorch's official quantization logic to the MMRazor environment. `AcademicQuantizer` is a quantizer designed for academic research that gives users more room to experiment. + +The classes `OpenVINOQuantizer` and `TensorRTQuantizer` inherit from class `NativeQuantizer`. They adapt the `OpenVINO` and `TensorRT` backends respectively. You can also try to develop a quantizer for other backends according to your own needs. + +3. Select tracer + +For the tracer, we use `CustomTracer` in `mmrazor/models/task_modules/tracer/fx/custom_tracer.py`. You can inherit this class and customize your own tracer. + +4. Develop a new fakequant method (optional) + +You can use the fakequants provided by PyTorch in `mmrazor/models/fake_quants/torch_fake_quants.py` as core function providers. If you want to use fakequant methods from other papers, you can also define them yourself. Let's take LSQ as an example: + +a. Create a new file `mmrazor/models/fake_quants/lsq.py`, where class `LearnableFakeQuantize` inherits from class `FakeQuantizeBase`. + +b. Finish the functions you need, e.g. `observe_quant_params`, `calculate_qparams` and so on. + +```Python +from mmrazor.registry import MODELS +from torch.ao.quantization import FakeQuantizeBase + +@MODELS.register_module() +class LearnableFakeQuantize(FakeQuantizeBase): + def __init__(self, + observer, + quant_min=0, + quant_max=255, + scale=1., + zero_point=0., + use_grad_scaling=True, + zero_point_trainable=False, + **observer_kwargs): + super(LearnableFakeQuantize, self).__init__() + pass + + def observe_quant_params(self): + pass + + def calculate_qparams(self): + pass + + def forward(self, X): + pass +``` + +c. Import the module in `mmrazor/models/fake_quants/__init__.py`. + +```Python +from .lsq import LearnableFakeQuantize + +__all__ = ['LearnableFakeQuantize'] +``` + +5. Develop a new observer (optional) + +You can directly use the observers provided by PyTorch in `mmrazor/models/observers/torch_observers.py` or use observers customized by yourself. Let's take `LSQObserver` as an example: + +a. Create a new observer file `mmrazor/models/observers/lsq.py`, where class `LSQObserver` inherits from `MinMaxObserver` and `LSQObserverMixIn`. These two parent classes provide `zero_point` and `scale`, respectively. + +b. Finish the functions you need, e.g. `calculate_qparams` and so on. + +```Python +from mmrazor.registry import MODELS +from torch.ao.quantization.observer import MinMaxObserver + +class LSQObserverMixIn: + def __init__(self): + self.tensor_norm = None + + @torch.jit.export + def _calculate_scale(self): + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + sync_tensor(scale) + return scale + +@MODELS.register_module() +class LSQObserver(MinMaxObserver, LSQObserverMixIn): + """LSQ observer. + Paper: Learned Step Size Quantization. + """ + def __init__(self, *args, **kwargs): + MinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the running minimum, maximum and tensor_norm of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + self.tensor_norm = x.abs().mean() + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = MinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point +``` + +c. Import the module in `mmrazor/models/observers/__init__.py`. + +```Python +from .lsq import LSQObserver + +__all__ = ['LSQObserver'] +``` + +6. Select a loop or develop a new loop + +At present, the quantization loops we support are `PTQLoop` and `QATEpochBasedLoop` in `mmrazor/engine/runner/quantization_loops.py`. We can develop a new `LSQEpochBasedLoop` that inherits from class `QATEpochBasedLoop` and implements the functions needed by the LSQ method. + +```Python +from mmengine.runner import EpochBasedTrainLoop + +@LOOPS.register_module() +class LSQEpochBasedLoop(QATEpochBasedLoop): + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + val_begin: int = 1, + val_interval: int = 1, + freeze_bn_begin: int = -1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: + super().__init__( + runner, + dataloader, + max_epochs, + val_begin, + val_interval, + freeze_bn_begin=freeze_bn_begin, + dynamic_intervals=dynamic_intervals) + + self.is_first_batch = True + + def prepare_for_run_epoch(self): + pass + + def prepare_for_val(self): + pass + + def run_epoch(self) -> None: + pass +``` + +Then import the module in `mmrazor/engine/runner/__init__.py`. + +```Python +from .quantization_loops import LSQEpochBasedLoop + +__all__ = ['LSQEpochBasedLoop'] +``` + +7. Use the algorithm in your config file + +After completing the above steps, we have all the components of the QAT algorithm, and we can now combine them in the config file. + +a. First, `_base_` stores the location of the model config that needs to be quantized. + +b. Second, configure the observer, fakequant and qscheme in `global_qconfig` in detail. +You can configure the required quantization bit width and quantization method in `qscheme`, such as symmetric or asymmetric quantization. + +c. Third, build the whole MMRazor model in `model`. + +d. Finally, complete all the remaining required configuration files.
+ +```Python +_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_ckpt, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + is_qat=True, + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +# learning policy +optim_wrapper = dict() +param_scheduler = dict() +model_wrapper_cfg = dict() + +# train, val, test setting +train_cfg = dict(type='mmrazor.LSQEpochBasedLoop') +val_cfg = dict() +test_cfg = val_cfg +``` diff --git a/docs/en/advanced_guides/index.rst b/docs/en/advanced_guides/index.rst index 7d46576ef..349dc5902 100644 --- a/docs/en/advanced_guides/index.rst +++ b/docs/en/advanced_guides/index.rst @@ -20,5 +20,6 @@ Development tutorials customize_nas_algorithms.md customize_pruning_algorithms.md customize_kd_algorithms.md + customize_quantization_algorithms.md customize_mixed_algorithms.md apply_existing_algorithms_to_new_tasks.md diff --git a/docs/en/advanced_guides/mutable.md b/docs/en/advanced_guides/mutable.md index c8f180c1a..8d59cfe71 100644 --- a/docs/en/advanced_guides/mutable.md +++ b/docs/en/advanced_guides/mutable.md @@ -13,7 +13,7 @@ To understand it better, we take the mutable module as an example to explain as As shown in the figure above, `Mutable` is a container that holds some candidate operations, thus it can sample candidates to constitute the subnet. `Supernet` usually consists of multiple `Mutable`, therefore, `Supernet` will be searchable with the help of `Mutable`. And all candidate operations in `Mutable` constitute the search space of `SuperNet`. ```{note} -If you want to know more about the relationship between Mutable and Mutator, please refer to [Mutator](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/mutator.html) +If you want to know more about the relationship between Mutable and Mutator, please refer to [Mutator](https://mmrazor.readthedocs.io/en/main/advanced_guides/mutator.html) ``` ### Features diff --git a/docs/en/advanced_guides/mutator.md b/docs/en/advanced_guides/mutator.md index ff28c4c91..aa28a199d 100644 --- a/docs/en/advanced_guides/mutator.md +++ b/docs/en/advanced_guides/mutator.md @@ -25,7 +25,7 @@ In MMRazor, we have implemented some mutators, their relationship is as below. `ModuleMuator`/ `ChannelMutator`: Two different types mutators are for handling mutable module and mutable channel respectively. ```{note} -Please refer to [Mutable](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/mutable.html) for more details about different types of mutable. +Please refer to [Mutable](https://mmrazor.readthedocs.io/en/main/advanced_guides/mutable.html) for more details about different types of mutable. 
``` `OneShotModuleMutator` / `DiffModuleMutator`: Inherit from `ModuleMuator`, they are for implementing different types algorithms, such as [SPOS](https://arxiv.org/abs/1904.00420), [Darts](https://arxiv.org/abs/1806.09055) and so on. diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md index a27105006..24550650a 100644 --- a/docs/en/get_started/installation.md +++ b/docs/en/get_started/installation.md @@ -1,138 +1,66 @@ # Installation -## Prepare Environment +## Prerequisites -Create a conda virtual environment and activate it. +In this section we demonstrate how to prepare an environment with PyTorch. -```Python -conda create -n openmmlab python=3.7 -y -conda activate openmmlab -``` - -Install PyTorch and torchvision following the [official instructions](https://pytorch.org/). - -```{note} -Make sure that your compilation CUDA version and runtime CUDA version match. You can check the supported CUDA version for precompiled packages on the [PyTorch website](https://pytorch.org/). If you build PyTorch from source instead of installing the prebuilt package, you can use more CUDA versions such as 9.0. -``` - -## Customize Installation - -It is recommended to install MMRazor with [MIM](https://github.com/open-mmlab/mim), which automatically handles the dependencies of OpenMMLab projects, including mmcv and other python packages. - -Or you can still install MMRazor manually +MMRazor works on Linux, Windows and macOS. It requires Python 3.6+, CUDA 9.2+ and PyTorch 1.8+. -1. Install mmcv. +**Note:** +If you are experienced with PyTorch and have already installed it, just skip this part and jump to the [next section](##installation). Otherwise, you can follow these steps for the preparation. -You can install mmcv with MIM, pip, or build it from source. +**Step 0.** Download and install Miniconda from the [official website](https://docs.conda.io/en/latest/miniconda.html). -- Install mmcv with MIM (recommend). +**Step 1.** Create a conda environment and activate it. -```Python -pip install openmim -mim install 'mmcv>=2.0.0rc1' +```shell +conda create --name openmmlab python=3.8 -y +conda activate openmmlab ``` -- Install mmcv with pip. - -```Python -pip install 'mmcv>=2.0.0rc1' -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html -``` +**Step 2.** Install PyTorch following [official instructions](https://pytorch.org/get-started/locally/), e.g. -Please replace `{cu_version}` and `{torch_version}` in the url to your desired one. For example, to install the latest `mmcv` with `CUDA 10.2` and `PyTorch 1.10.0`, use the following command: +On GPU platforms: -```Python -pip install 'mmcv>=2.0.0rc1' -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.10.0/index.html +```shell +conda install pytorch torchvision -c pytorch ``` -See [here](https://github.com/open-mmlab/mmcv#installation) for different versions of MMCV compatible to different PyTorch and CUDA versions. - -- Build mmcv from source. +On CPU platforms: -```bash -MMCV_WITH_OPS=0 pip install -e . -v -# install mmcv-lite, do not compile operators -MMCV_WITH_OPS=1 pip install -e . -v -# install mmcv (originally called mmcv-full), compile operators -pip install -e . -v -# install mmcv with compiled operators, +```shell +conda install pytorch torchvision cpuonly -c pytorch ``` -- For windows platform, try `set MMCV_WITH_OPS=1` instead. +## Installation -2. Install MMEngine. +We recommend that users follow our best practices to install MMRazor. 
-You can install mmengine with MIM or build it from source. +### Best Practices -- Install MMEngine with MIM. +**Step 0.** Install [MMCV](https://github.com/open-mmlab/mmcv) using [MIM](https://github.com/open-mmlab/mim). -```bash -pip install openmim +```shell +pip install -U openmim mim install mmengine +mim install "mmcv>=2.0.0" ``` -- Compile MMEngine from source. - -```Python -git clone https://github.com/open-mmlab/mmengine.git -cd mmengine -pip install -v -e . -``` +**Step 1.** Install MMRazor. -3. Install MMRazor. +Case a: If you develop and run mmrazor directly, install it from source: -If you would like to install MMRazor in `dev` mode, run following: - -```Python -git clone https://github.com/open-mmlab/mmrazor.git +```shell +git clone -b main https://github.com/open-mmlab/mmrazor.git cd mmrazor -git fetch origin -git checkout -b dev-1.x origin/dev-1.x -# The new version is released in branch ``dev-1.x`` pip install -v -e . -# "-v" means verbose, or more output -# "-e" means installing a project in editable mode, +# '-v' means verbose, or more output +# '-e' means installing a project in editable mode, # thus any local modifications made to the code will take effect without reinstallation. ``` -```{note} -When MMRazor is installed on `dev` mode, any local modifications made to the code will take effect without the need to reinstall it. -``` - -## A from-scratch Setup Script - -```Python -conda create -n openmmlab python=3.7 -y -conda activate openmmlab - -conda install pytorch torchvision cudatoolkit=10.2 -c pytorch -# install the latest mmcv -pip install 'mmcv>=2.0.0rc1' -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.10.0/index.html -# install mmrazor -git clone https://github.com/open-mmlab/mmrazor.git -cd mmrazor -git fetch origin -git checkout -b dev-1.x origin/dev-1.x -pip install -v -e . -``` - -## Install Other Libraries - -MMRazor can easily collaborate with other OpenMMLab libraries. MMRazor requires the use of other libraries for different tasks. For example, `MMClassification` is required for image classification tasks, `MMDetection` for object detection, and `MMSegmentation` for semantic segmentation. - -We provide the installation of the above three libraries using `MIM`. - -```bash -pip install openmim -# mmcv is required for all libraries -mim install 'mmcv>=2.0.0rc1' -# install mmcls -mim install 'mmcls>=1.0.0rc0' -# install mmdet -mim install 'mmdet>=3.0.0rc0' -# install mmseg -mim install 'mmseg>=1.0.0rc0' -``` +Case b: If you use mmrazor as a dependency or third-party package, install it with pip: -```{note} -Not all of above libraries are required by MMRazor. Please install according to your requirements. 
+```shell +pip install "mmrazor>=1.0.0" ``` diff --git a/docs/en/get_started/model_zoo.md b/docs/en/get_started/model_zoo.md index 3a5d5f22e..fce863881 100644 --- a/docs/en/get_started/model_zoo.md +++ b/docs/en/get_started/model_zoo.md @@ -2,23 +2,29 @@ ## Baselines -| Type | Name | Link | -| ------- | :-------------: | :---------------------------------------------------------------------------------------------------: | -| nas | SPOS | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/nas/mmcls/spos) | -| nas | DARTS | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/nas/mmcls/darts) | -| nas | DetNAS | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/nas/mmdet/detnas) | -| pruning | AutoSlim | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/pruning/mmcls/autoslim) | -| ditill | ABLoss | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/abloss) | -| ditill | BYOT | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/byot) | -| ditill | DAFL | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/dafl) | -| ditill | DFAD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/dfad) | -| ditill | DKD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/dkd) | -| ditill | Factor Transfer | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/factor_transfer) | -| ditill | FitNets | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/fitnets) | -| ditill | KD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/kd) | -| ditill | OFD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/ofd) | -| ditill | RKD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/rkd) | -| ditill | WSLD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/wsld) | -| ditill | ZSKT | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmcls/zskt) | -| ditill | CWD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmdet/cwd) | -| ditill | FBKD | [README.md](https://github.com/open-mmlab/mmrazor/tree/dev-1.x/configs/distill/mmdet/fbkd) | +| Type | Name | Link | +| ------------ | :-------------: | :------------------------------------------------------------------------------------------------: | +| nas | SPOS | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/nas/mmcls/spos) | +| nas | DARTS | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/nas/mmcls/darts) | +| nas | DetNAS | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/nas/mmdet/detnas) | +| pruning | AutoSlim | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/pruning/mmcls/autoslim) | +| pruning | L1-norm | [README.md](https://github.com/open-mmlab/mmrazor/tree/main//configs/pruning/mmcls/l1-norm) | +| pruning | Group Fisher | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/pruning/base/group_fisher) | +| pruning | DMCP | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/pruning/mmcls/dmcp) | +| ditill | ABLoss | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/abloss) | +| ditill | BYOT | 
[README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/byot) | +| ditill | DAFL | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/dafl) | +| ditill | DFAD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/dfad) | +| ditill | DKD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/dkd) | +| ditill | Factor Transfer | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/factor_transfer) | +| ditill | FitNets | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/fitnets) | +| ditill | KD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/kd) | +| ditill | OFD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/ofd) | +| ditill | RKD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/rkd) | +| ditill | WSLD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/wsld) | +| ditill | ZSKT | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmcls/zskt) | +| ditill | CWD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmdet/cwd) | +| ditill | FBKD | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/distill/mmdet/fbkd) | +| quantization | PTQ | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/quantization/ptq/base) | +| quantization | QAT | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/quantization/qat/base) | +| quantization | LSQ | [README.md](https://github.com/open-mmlab/mmrazor/tree/main/configs/quantization/qat/lsq) | diff --git a/docs/en/get_started/overview.md b/docs/en/get_started/overview.md index 99931121c..4249192fc 100644 --- a/docs/en/get_started/overview.md +++ b/docs/en/get_started/overview.md @@ -7,9 +7,9 @@ MMRazor is a model compression toolkit for model slimming, which includes 4 main - Neural Architecture Search (NAS) - Pruning - Knowledge Distillation (KD) -- Quantization (come soon) +- Quantization -It is a part of the [OpenMMLab](https://openmmlab.com/) project. If you want to use it now, please refer to [Installation](https://mmrazor.readthedocs.io/en/dev-1.x/get_started/installation.html). +It is a part of the [OpenMMLab](https://openmmlab.com/) project. If you want to use it now, please refer to [Installation](https://mmrazor.readthedocs.io/en/main/get_started/installation.html). 
### Major features: @@ -59,26 +59,26 @@ For better understanding and using MMRazor, it is highly recommended to read the **Global** -- [Algorithm](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/algorithm.html) +- [Algorithm](https://mmrazor.readthedocs.io/en/main/advanced_guides/algorithm.html) **NAS & Pruning** -- [Mutator](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/mutator.html) -- [Mutable](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/mutable.html) +- [Mutator](https://mmrazor.readthedocs.io/en/main/advanced_guides/mutator.html) +- [Mutable](https://mmrazor.readthedocs.io/en/main/advanced_guides/mutable.html) **KD** -- [Delivery](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/delivery.html) -- [Recorder](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/recorder.html) +- [Delivery](https://mmrazor.readthedocs.io/en/main/advanced_guides/delivery.html) +- [Recorder](https://mmrazor.readthedocs.io/en/main/advanced_guides/recorder.html) ## User guide If you want to run mmrazor quickly, you can refer to as the follows. -- [Learn about Configs](https://mmrazor.readthedocs.io/en/dev-1.x/user_guides/1_learn_about_config.html) -- [Train different types algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/user_guides/2_train_different_types_algorithms.html) -- [Train with different devices](https://mmrazor.readthedocs.io/en/dev-1.x/user_guides/3_train_with_different_devices.html) -- [Test a model](https://mmrazor.readthedocs.io/en/dev-1.x/user_guides/4_test_a_model.html) +- [Learn about Configs](https://mmrazor.readthedocs.io/en/main/user_guides/1_learn_about_config.html) +- [Train different types algorithms](https://mmrazor.readthedocs.io/en/main/user_guides/2_train_different_types_algorithms.html) +- [Train with different devices](https://mmrazor.readthedocs.io/en/main/user_guides/3_train_with_different_devices.html) +- [Test a model](https://mmrazor.readthedocs.io/en/main/user_guides/4_test_a_model.html) ## Tutorials @@ -86,20 +86,20 @@ We provide the following general tutorials according to some typical requirement **Tutorial list** -- [Customize Architectures](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_architectures.html) -- [Customize NAS algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_nas_algorithms.html) -- [Customize Pruning algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_pruning_algorithms.html) -- [Customize KD algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_kd_algorithms.html) -- [Customize mixed algorithms](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_mixed_algorithms.html) -- [Apply existing algorithms to new tasks](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/apply_existing_algorithms_to_new_tasks.html) +- [Customize Architectures](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_architectures.html) +- [Customize NAS algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_nas_algorithms.html) +- [Customize Pruning algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_pruning_algorithms.html) +- [Customize KD algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_kd_algorithms.html) +- [Customize mixed algorithms](https://mmrazor.readthedocs.io/en/main/advanced_guides/customize_mixed_algorithms.html) +- [Apply existing algorithms to new 
tasks](https://mmrazor.readthedocs.io/en/main/advanced_guides/apply_existing_algorithms_to_new_tasks.html) ## F&Q -If you encounter some trouble using MMRazor, you can find whether your question has existed in [F&Q](https://mmrazor.readthedocs.io/en/dev-1.x/notes/faq.html). If not existed, welcome to open a [Github issue](https://github.com/open-mmlab/mmrazor/issues) for getting support, we will reply it as soon. +If you encounter some trouble using MMRazor, you can find whether your question has existed in [F&Q](https://mmrazor.readthedocs.io/en/main/notes/faq.html). If not existed, welcome to open a [Github issue](https://github.com/open-mmlab/mmrazor/issues) for getting support, we will reply it as soon. ## Get support and contribute back MMRazor is maintained on the [MMRazor Github repository](https://github.com/open-mmlab/mmrazor). We collect feedback and new proposals/ideas on Github. You can: - Open a [GitHub issue](https://github.com/open-mmlab/mmrazor/issues) for bugs and feature requests. -- Open a [pull request](https://github.com/open-mmlab/mmrazor/pulls) to contribute code (make sure to read the [contribution guide](https://mmrazor.readthedocs.io/en/dev-1.x/notes/contribution_guide.html) before doing this). +- Open a [pull request](https://github.com/open-mmlab/mmrazor/pulls) to contribute code (make sure to read the [contribution guide](https://mmrazor.readthedocs.io/en/main/notes/contribution_guide.html) before doing this). diff --git a/docs/en/notes/changelog.md b/docs/en/notes/changelog.md index 825c32f0d..338c9425d 100644 --- a/docs/en/notes/changelog.md +++ b/docs/en/notes/changelog.md @@ -1 +1,278 @@ -# Changelog +# Changelog of v1.x + +## v1.0.0 (24/04/2023) + +We are excited to announce the first official release of MMRazor 1.0. + +### Highlights + +- MMRazor quantization is released, which has got through task models and model deployment. With its help, we can quantize and deploy pre-trained models in OpenMMLab to specified backend quickly. + +### New Features & Improvements + +#### NAS + +- Update searchable model. (https://github.com/open-mmlab/mmrazor/pull/438) +- Update NasMutator to build search_space in NAS. (https://github.com/open-mmlab/mmrazor/pull/426) + +#### Pruning + +- Add a new pruning algorithm named GroupFisher. We support the full pipeline for GroupFisher, including pruning, finetuning and deployment.(https://github.com/open-mmlab/mmrazor/pull/459) + +#### KD + +- Support stopping distillation after a certain epoch. (https://github.com/open-mmlab/mmrazor/pull/455) +- Support distilling rtmdet with mmrazor, refer to here. (https://github.com/open-mmlab/mmyolo/pull/544) +- Add mask channel in MGD Loss. (https://github.com/open-mmlab/mmrazor/pull/461) + +#### Quantization + +- Support two quantization types: QAT and PTQ (https://github.com/open-mmlab/mmrazor/pull/513) +- Support various quantization bits. (https://github.com/open-mmlab/mmrazor/pull/513) +- Support various quantization methods, such as per_tensor / per_channel, symmetry / asymmetry and so on. (https://github.com/open-mmlab/mmrazor/pull/513) +- Support deploy quantized models to multiple backends, such as OpenVINO, TensorRT and so on. (https://github.com/open-mmlab/mmrazor/pull/513) +- Support applying quantization algorithms to multiple task repos directly, such as mmcls, mmdet and so on. (https://github.com/open-mmlab/mmrazor/pull/513) + +### Bug Fixes + +- Fix split in Darts config. (https://github.com/open-mmlab/mmrazor/pull/451) +- Fix a bug in Recorders. 
(https://github.com/open-mmlab/mmrazor/pull/446)
+- Fix a bug when using get_channel_unit.py. (https://github.com/open-mmlab/mmrazor/pull/432)
+- Fix a bug when deploying a pruned model to cuda. (https://github.com/open-mmlab/mmrazor/pull/495)
+
+### Contributors
+
+A total of 10 developers contributed to this release.
+Thanks @415905716 @gaoyang07 @humu789 @LKJacky @HIT-cwh @aptsunny @cape-zck @vansin @twmht @wm901115nwpu
+
+## v1.0.0rc2 (06/01/2023)
+
+We are excited to announce the release of MMRazor 1.0.0rc2.
+
+### New Features
+
+#### NAS
+
+- Add Performance Predictor: Support 4 performance predictors with 4 basic machine learning algorithms, which can be used to directly predict model accuracy without evaluation.(https://github.com/open-mmlab/mmrazor/pull/306)
+
+- Support [Autoformer](https://arxiv.org/pdf/2107.00651.pdf), a one-shot architecture search algorithm dedicated to vision transformer search.(https://github.com/open-mmlab/mmrazor/pull/315)
+
+- Support [BigNAS](https://arxiv.org/pdf/2003.11142), a NAS algorithm which searches the following items in MobileNetV3 with the one-shot paradigm: kernel_sizes, out_channels, expand_ratios, block_depth and input sizes. (https://github.com/open-mmlab/mmrazor/pull/219)
+
+#### Pruning
+
+- Support [DCFF](https://arxiv.org/abs/2107.06916), a filter channel pruning algorithm dedicated to efficient image classification.(https://github.com/open-mmlab/mmrazor/pull/295)
+
+- We release a powerful tool to automatically analyze channel dependency, named ChannelAnalyzer. Here is an example as shown below. (https://github.com/open-mmlab/mmrazor/pull/371)
+
+Now, ChannelAnalyzer supports most CNN models in torchvision, mmcls, mmseg and mmdet. We will continue to support more models.
+
+```python
+from mmrazor.models.task_modules import ChannelAnalyzer
+from mmengine.hub import get_model
+import json
+
+model = get_model('mmdet::retinanet/retinanet_r18_fpn_1x_coco.py')
+unit_configs: dict = ChannelAnalyzer().analyze(model)
+unit_config0 = list(unit_configs.values())[0]
+print(json.dumps(unit_config0, indent=4))
+# # short version of the config
+# {
+# "channels": {
+# "input_related": [
+# {"name": "backbone.layer2.0.bn1"},
+# {"name": "backbone.layer2.0.conv2"}
+# ],
+# "output_related": [
+# {"name": "backbone.layer2.0.conv1"},
+# {"name": "backbone.layer2.0.bn1"}
+# ]
+# },
+#}
+```
+
+#### KD
+
+- Support [MGD](https://arxiv.org/abs/2205.01529), a detection distillation algorithm.(https://github.com/open-mmlab/mmrazor/pull/381)
+
+### Bug Fixes
+
+- Fix `FpnTeacherDistill` teacher forward from `backbone + neck + head` to `backbone + neck`. (#387)
+- Fix some expired configs and checkpoints. (#373 #372 #422)
+
+### Ongoing Changes
+
+We will release Quantization in the next version (1.0.0rc3)!
+
+### Contributors
+
+A total of 11 developers contributed to this release: @wutongshenqiu @sunnyxiaohu @aptsunny @humu789 @TinyTigerPan @FreakieHuang @LKJacky @wilxy @gaoyang07 @spynccat @yivona08.
+
+## v1.0.0rc1 (27/10/2022)
+
+We are excited to announce the release of MMRazor 1.0.0rc1.
+
+### Highlights
+
+- **New Pruning Framework**: We have systematically refactored the Pruning module. The new Pruning module can more automatically resolve the dependencies between channels and cover more corner cases.
+
+### New Features
+
+#### Pruning
+
+- A new pruning framework is released in this release. (#311, #313)
+  It consists of five core modules, including Algorithm, `ChannelMutator`, `MutableChannelUnit`, `MutableChannel` and `DynamicOp`.
+ +- MutableChannelUnit is introduced for the first time. Each MutableChannelUnit manages all channels with channel dependency. + + ```python + from mmrazor.registry import MODELS + + ARCHITECTURE_CFG = dict( + _scope_='mmcls', + type='ImageClassifier', + backbone=dict(type='MobileNetV2', widen_factor=1.5), + neck=dict(type='GlobalAveragePooling'), + head=dict(type='mmcls.LinearClsHead', num_classes=1000, in_channels=1920)) + model = MODELS.build(ARCHITECTURE_CFG) + from mmrazor.models.mutators import ChannelMutator + + channel_mutator = ChannelMutator() + channel_mutator.prepare_from_supernet(model) + units = channel_mutator.mutable_units + print(units[0]) + # SequentialMutableChannelUnit( + # name=backbone.conv1.conv_(0, 48)_48 + # (output_related): ModuleList( + # (0): Channel(backbone.conv1.conv, index=(0, 48), is_output_channel=true, expand_ratio=1) + # (1): Channel(backbone.conv1.bn, index=(0, 48), is_output_channel=true, expand_ratio=1) + # (2): Channel(backbone.layer1.0.conv.0.conv, index=(0, 48), is_output_channel=true, expand_ratio=1) + # (3): Channel(backbone.layer1.0.conv.0.bn, index=(0, 48), is_output_channel=true, expand_ratio=1) + # ) + # (input_related): ModuleList( + # (0): Channel(backbone.conv1.bn, index=(0, 48), is_output_channel=false, expand_ratio=1) + # (1): Channel(backbone.layer1.0.conv.0.conv, index=(0, 48), is_output_channel=false, expand_ratio=1) + # (2): Channel(backbone.layer1.0.conv.0.bn, index=(0, 48), is_output_channel=false, expand_ratio=1) + # (3): Channel(backbone.layer1.0.conv.1.conv, index=(0, 48), is_output_channel=false, expand_ratio=1) + # ) + # (mutable_channel): SquentialMutableChannel(num_channels=48, activated_channels=48) + # ) + ``` + +Our new pruning algorithm can help you develop pruning algorithm more fluently. Pelease refer to our documents [PruningUserGuide](<./docs/en/user_guides/../../pruning/%5Bpruning_user_guide.md%5D(http://pruning_user_guide.md/)>) for model detail. + +#### Distillation + +- Support [CRD](https://arxiv.org/abs/1910.10699), a distillation algorithm based on contrastive representation learning. (#281) + +- Support [PKD](https://arxiv.org/abs/2207.02039), a distillation algorithm that can be used in `MMDetection` and `MMDetection3D`. #304 + +- Support [DEIT](https://arxiv.org/abs/2012.12877), a classic **Transformer** distillation algorithm.(#332) + +- Add a more powerful baseline setting for [KD](https://arxiv.org/abs/1503.02531). (#305) + +- Add `MethodInputsRecorder` and `FuncInputsRecorder` to record the input of a class method or a function.(#320) + +#### NAS + +- Support [DSNAS](https://arxiv.org/pdf/2002.09128.pdf), a nas algorithm that does not require retraining. (#226 ) + +#### Tools + +- Support configurable immediate feature map visualization. (#293 ) + A useful tool is supported in this release to visualize the immediate features of a neural network. Please refer to our documents [VisualizationUserGuide](http://./docs/zh_cn/user_guides/visualization.md) for more details. + +### Bug Fixes + +- Fix the bug that `FunctionXXRecorder` and `FunctionXXDelivery` can not be pickled. (#320) + +### Ongoing changes + +- Quantization: We are developing the basic interface of PTQ and QAT. RFC(Request for Comments) will be released soon. +- AutoSlim: AutoSlim is not yet available and is being refactored. +- Fx Pruning Tracer: Currently, the model topology can only be resolved through the backward tracer. In the future, both backward tracer and fx tracer will be supported. 
+- More Algorithms: BigNAS、AutoFormer、GreedyNAS and Resrep will be released in the next few versions. +- Documentation: we will add more design docs, tutorials, and migration guidance so that the community can deep dive into our new design, participate the future development, and smoothly migrate downstream libraries to MMRazor 1.x. + +### Contributors + +A total of 12 developers contributed to this release. +Thanks @FreakieHuang @gaoyang07 @HIT-cwh @humu789 @LKJacky @pppppM @pprp @spynccat @sunnyxiaohu @wilxy @kitecats @SheffieldCao + +## v1.0.0rc0 (31/8/2022) + +We are excited to announce the release of MMRazor 1.0.0rc0. +MMRazor 1.0.0rc0 is the first version of MMRazor 1.x, a part of the OpenMMLab 2.0 projects. +Built upon the new [training engine](https://github.com/open-mmlab/mmengine), +MMRazor 1.x simplified the interaction with other OpenMMLab repos, and upgraded the basic APIs of KD / Pruning / NAS. +It also provides a series of knowledge distillation algorithms. + +### Highlights + +- **New engines**. MMRazor 1.x is based on [MMEngine](https://github.com/open-mmlab/mmengine), which provides a general and powerful runner that allows more flexible customizations and significantly simplifies the entrypoints of high-level interfaces. + +- **Unified interfaces**. As a part of the OpenMMLab 2.0 projects, MMRazor 1.x unifies and refactors the interfaces and internal logic of train, testing, datasets, models, evaluation, and visualization. All the OpenMMLab 2.0 projects share the same design in those interfaces and logic to allow the emergence of multi-task/modality algorithms. + +- **More configurable KD**. MMRazor 1.x add [Recorder](../advanced_guides/recorder.md) to get the data needed for KD more automatically,[Delivery ](../advanced_guides/delivery.md) to automatically pass the teacher's intermediate results to the student, and connector to handle feature dimension mismatches between teacher and student. + +- **More kinds of KD algorithms**. Benefitting from the powerful APIs of KD, we have added several categories of KD algorithms, data-free distillation, self-distillation, and zero-shot distillation. + +- **Unify the basic interface of NAS and Pruning**. We refactored [Mutable](../advanced_guides/mutable.md), adding mutable value and mutable channel. Both NAS and Pruning can be developed based on mutables. + +- **More documentation and tutorials**. We add a bunch of documentation and tutorials to help users get started more smoothly. Read it [here](https://mmrazor.readthedocs.io/en/1.0.0rc0/). + +### Breaking Changes + +#### Training and testing + +- MMRazor 1.x runs on PyTorch>=1.6. We have deprecated the support of PyTorch 1.5 to embrace the mixed precision training and other new features since PyTorch 1.6. Some models can still run on PyTorch 1.5, but the full functionality of MMRazor 1.x is not guaranteed. +- MMRazor 1.x uses Runner in [MMEngine](https://github.com/open-mmlab/mmengine) rather than that in MMCV. The new Runner implements and unifies the building logic of dataset, model, evaluation, and visualizer. Therefore, MMRazor 1.x no longer maintains the building logics of those modules in `mmdet.train.apis` and `tools/train.py`. Those code have been migrated into [MMEngine](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py). +- The Runner in MMEngine also supports testing and validation. The testing scripts are also simplified, which has similar logic as that in training scripts to build the runner. 
+ +#### Configs + +- The [Runner in MMEngine](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py) uses a different config structures +- Config and model names + +#### Components + +- Algorithms +- Distillers +- Mutators +- Mutables +- Hooks + +### Improvements + +- Support mixed precision training of all the models. However, some models may got Nan results due to some numerical issues. We will update the documentation and list their results (accuracy of failure) of mixed precision training. + +### Bug Fixes + +- AutoSlim: Models of different sizes will no longer have the same size checkpoint + +### New Features + +- Support [Activation Boundaries Loss](https://arxiv.org/pdf/1811.03233.pdf) +- Support [Be Your Own Teacher](https://arxiv.org/abs/1905.08094) +- Support [Data-Free Learning of Student Networks](https://doi.org/10.1109/ICCV.2019.00361) +- Support [Data-Free Adversarial Distillation](https://arxiv.org/pdf/1912.11006.pdf) +- Support [Decoupled Knowledge Distillation](https://arxiv.org/pdf/2203.08679.pdf) +- Support [Factor Transfer](https://arxiv.org/abs/1802.04977) +- Support [FitNets](https://arxiv.org/abs/1412.6550) +- Support [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531) +- Support [Overhaul](https://arxiv.org/abs/1904.01866) +- Support [Zero-shot Knowledge Transfer via Adversarial Belief Matching](https://arxiv.org/abs/1905.09768) + +### Ongoing changes + +- Quantization: We are developing the basic interface of PTQ and QAT. RFC(Request for Comments) will be released soon. +- AutoSlim: AutoSlim is not yet available and is being refactored. +- Fx Pruning Tracer: Currently, the model topology can only be resolved through the backward tracer. In the future, both backward tracer and fx tracer will be supported. +- More Algorithms: BigNAS、AutoFormer、GreedyNAS and Resrep will be released in the next few versions. +- Documentation: we will add more design docs, tutorials, and migration guidance so that the community can deep dive into our new design, participate the future development, and smoothly migrate downstream libraries to MMRazor 1.x. + +### Contributors + +A total of 13 developers contributed to this release. +Thanks @FreakieHuang @gaoyang07 @HIT-cwh @humu789 @LKJacky @pppppM @pprp @spynccat @sunnyxiaohu @wilxy @wutongshenqiu @NickYangMin @Hiwyl +Special thanks to @Davidgzx for his contribution to the data-free distillation algorithms diff --git a/docs/en/notes/changelog_v1.md b/docs/en/notes/changelog_v1.md deleted file mode 100644 index d768bae74..000000000 --- a/docs/en/notes/changelog_v1.md +++ /dev/null @@ -1,79 +0,0 @@ -# Changelog of v1.x - -## v1.0.0rc0 (31/8/2022) - -We are excited to announce the release of MMRazor 1.0.0rc0. -MMRazor 1.0.0rc0 is the first version of MMRazor 1.x, a part of the OpenMMLab 2.0 projects. -Built upon the new [training engine](https://github.com/open-mmlab/mmengine), -MMRazor 1.x simplified the interaction with other OpenMMLab repos, and upgraded the basic APIs of KD / Pruning / NAS. -It also provides a series of knowledge distillation algorithms. - -### Highlights - -- **New engines**. MMRazor 1.x is based on [MMEngine](https://github.com/open-mmlab/mmengine), which provides a general and powerful runner that allows more flexible customizations and significantly simplifies the entrypoints of high-level interfaces. - -- **Unified interfaces**. 
As a part of the OpenMMLab 2.0 projects, MMRazor 1.x unifies and refactors the interfaces and internal logic of train, testing, datasets, models, evaluation, and visualization. All the OpenMMLab 2.0 projects share the same design in those interfaces and logic to allow the emergence of multi-task/modality algorithms. - -- **More configurable KD**. MMRazor 1.x add [Recorder](../advanced_guides/recorder.md) to get the data needed for KD more automatically,[Delivery ](../advanced_guides/delivery.md) to automatically pass the teacher's intermediate results to the student, and connector to handle feature dimension mismatches between teacher and student. - -- **More kinds of KD algorithms**. Benefitting from the powerful APIs of KD, we have added several categories of KD algorithms, data-free distillation, self-distillation, and zero-shot distillation. - -- **Unify the basic interface of NAS and Pruning**. We refactored [Mutable](../advanced_guides/mutable.md), adding mutable value and mutable channel. Both NAS and Pruning can be developed based on mutables. - -- **More documentation and tutorials**. We add a bunch of documentation and tutorials to help users get started more smoothly. Read it [here](https://mmrazor.readthedocs.io/en/1.0.0rc0/). - -### Breaking Changes - -#### Training and testing - -- MMRazor 1.x runs on PyTorch>=1.6. We have deprecated the support of PyTorch 1.5 to embrace the mixed precision training and other new features since PyTorch 1.6. Some models can still run on PyTorch 1.5, but the full functionality of MMRazor 1.x is not guaranteed. -- MMRazor 1.x uses Runner in [MMEngine](https://github.com/open-mmlab/mmengine) rather than that in MMCV. The new Runner implements and unifies the building logic of dataset, model, evaluation, and visualizer. Therefore, MMRazor 1.x no longer maintains the building logics of those modules in `mmdet.train.apis` and `tools/train.py`. Those code have been migrated into [MMEngine](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py). -- The Runner in MMEngine also supports testing and validation. The testing scripts are also simplified, which has similar logic as that in training scripts to build the runner. - -#### Configs - -- The [Runner in MMEngine](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py) uses a different config structures -- Config and model names - -#### Components - -- Algorithms -- Distillers -- Mutators -- Mutables -- Hooks - -### Improvements - -- Support mixed precision training of all the models. However, some models may got Nan results due to some numerical issues. We will update the documentation and list their results (accuracy of failure) of mixed precision training. 
- -### Bug Fixes - -- AutoSlim: Models of different sizes will no longer have the same size checkpoint - -### New Features - -- Support [Activation Boundaries Loss](https://arxiv.org/pdf/1811.03233.pdf) -- Support [Be Your Own Teacher](https://arxiv.org/abs/1905.08094) -- Support [Data-Free Learning of Student Networks](https://doi.org/10.1109/ICCV.2019.00361) -- Support [Data-Free Adversarial Distillation](https://arxiv.org/pdf/1912.11006.pdf) -- Support [Decoupled Knowledge Distillation](https://arxiv.org/pdf/2203.08679.pdf) -- Support [Factor Transfer](https://arxiv.org/abs/1802.04977) -- Support [FitNets](https://arxiv.org/abs/1412.6550) -- Support [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531) -- Support [Overhaul](https://arxiv.org/abs/1904.01866) -- Support [Zero-shot Knowledge Transfer via Adversarial Belief Matching](https://arxiv.org/abs/1905.09768) - -### Ongoing changes - -- Quantization: We are developing the basic interface of PTQ and QAT. RFC(Request for Comments) will be released soon. -- AutoSlim: AutoSlim is not yet available and is being refactored. -- Fx Pruning Tracer: Currently, the model topology can only be resolved through the backward tracer. In the future, both backward tracer and fx tracer will be supported. -- More Algorithms: BigNAS、AutoFormer、GreedyNAS and Resrep will be released in the next few versions. -- Documentation: we will add more design docs, tutorials, and migration guidance so that the community can deep dive into our new design, participate the future development, and smoothly migrate downstream libraries to MMRazor 1.x. - -### Contributors - -A total of 13 developers contributed to this release. -Thanks @FreakieHuang @gaoyang07 @HIT-cwh @humu789 @LKJacky @pppppM @pprp @spynccat @sunnyxiaohu @wilxy @wutongshenqiu @NickYangMin @Hiwyl -Special thanks to @Davidgzx for his contribution to the data-free distillation algorithms diff --git a/docs/en/user_guides/index.rst b/docs/en/user_guides/index.rst index 622987867..96ebc0a6e 100644 --- a/docs/en/user_guides/index.rst +++ b/docs/en/user_guides/index.rst @@ -10,6 +10,15 @@ Train & Test 3_train_with_different_devices.md 4_test_a_model.md +Quantization +************ + +.. toctree:: + :maxdepth: 1 + + quantization_user_guide.md + Useful Tools ************ - please refer to upstream applied repositories' docs + +please refer to upstream applied repositories' docs diff --git a/docs/en/user_guides/quantization_user_guide.md b/docs/en/user_guides/quantization_user_guide.md new file mode 100644 index 000000000..35680630f --- /dev/null +++ b/docs/en/user_guides/quantization_user_guide.md @@ -0,0 +1,238 @@ +# Quantization + +## Introduction + +MMRazor's quantization is OpenMMLab's quantization toolkit, which has got through task models and model deployment. With its help, we can quantize and deploy pre-trained models in OpenMMLab to specified backend quickly. Of course, it can also contribute to implementing some custom quantization algorithms easier. + +### Major features + +- **Ease of use**. Benefited from PyTorch fx, we can quantize our model without modifying the original model, but with user-friendly config. +- **Multiple backends deployment support**. Because of the specificity of each backend, a gap in performance usually exists between before and after deployment. We provided some common backend deployment support to reduce the gap as much. 
+- **Multiple task repos support.** Benefiting from OpenMMLab 2.0, our quantization can support all task repos of OpenMMLab without extra code.
+- **Compatible with PyTorch's core quantization modules**. Some core modules in PyTorch can be used directly in mmrazor, such as `Observer`, `FakeQuantize`, `BackendConfig` and so on.
+
+## Quick run
+
+```{note}
+MMRazor's quantization is based on `torch==1.13`. Other requirements are the same as MMRazor's.
+```
+
+Model quantization is done in mmrazor, while quantized model deployment is done in mmdeploy. So we need to use the following branch if we want to deploy our quantized model:
+
+mmdeploy: https://github.com/open-mmlab/mmdeploy/tree/for_mmrazor
+
+```{note}
+If you try to compress mmdet's models and have used `dense_heads`, you can use this branch:
+https://github.com/HIT-cwh/mmdetection/tree/for_mmrazor to avoid the problem that some code cannot be traced by `torch.fx.tracer`.
+```
+
+1. Quantize the float model in mmrazor.
+
+```Shell
+# For QAT (Quantization Aware Training)
+python tools/train.py ${CONFIG_PATH} [optional arguments]
+
+# For PTQ (Post-training quantization)
+python tools/ptq.py ${CONFIG_PATH} [optional arguments]
+```
+
+2. Evaluate the quantized model. (optional)
+
+```Shell
+python tools/test.py ${CONFIG_PATH} ${CHECKPOINT_PATH}
+```
+
+3. Export the quantized model to a specific backend in mmdeploy. (required by model deployment)
+
+```Shell
+# MODEL_CFG_PATH is the used config in mmrazor.
+python ./tools/deploy.py \
+    ${DEPLOY_CFG_PATH} \
+    ${MODEL_CFG_PATH} \
+    ${MODEL_CHECKPOINT_PATH} \
+    ${INPUT_IMG} \
+    [optional arguments]
+```
+
+This step is the same as how to export an OpenMMLab model to a specific backend. For more details, please refer to [How to convert model](https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/02-how-to-run/convert_model.md).
+
+4. Evaluate the quantized backend model. (optional)
+
+```Shell
+python tools/test.py \
+    ${DEPLOY_CFG} \
+    ${MODEL_CFG} \
+    --model ${BACKEND_MODEL_FILES} \
+    [optional arguments]
+```
+
+This step is the same as evaluating backend models. For more details, please refer to [How to evaluate model](https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/02-how-to-run/profile_model.md).
+
+## How to quantize your own model quickly
+
+If you want to quantize your own model quickly, you just need to learn how to modify our provided configs.
+
+**Case 1: If the model you want to quantize is in our provided configs.**
+
+You can refer to the previous chapter, Quick Run.
+
+**Case 2: If the model you want to quantize is not in our provided configs.**
+
+Let us take `resnet50` as an example to show how to handle case 2.
+ +```Python +_base_ = [ + 'mmcls::resnet/resnet18_8xb32_in1k.py', + '../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] + +val_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) +``` + +This is a config that quantize `resnet18` with OpenVINO backend. You just need to modify two args: `_base_` and `float_checkpoint`. + +```Python +# before +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' + +# after +_base_ = ['mmcls::resnet/resnet50_8xb32_in1k.py'] +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' +``` + +- `_base_` will be called from mmcls by mmengine, so you can just use mmcls provided configs directly. Other repos are similar. +- `float_checkpoint ` is a pre-trained float checkpoint by OpenMMLab. You can find it in the corresponding repo. + +After modifying required config, we can use it the same as case 1. + +## How to improve your quantization performance + +If you can not be satisfied with quantization performance by applying our provided configs to your own model, you can try to improve it with our provided various quantization schemes by modifying `global_qconfig`. + +```Python +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) +``` + +As shown above, `global_qconfig` contains server common core args as follows: + +- Observes + +In `forward`, they will update the statistics of the observed Tensor. And they should provide a `calculate_qparams` function that computes the quantization parameters given the collected statistics. + +```{note} +Whether it is per channel quantization depends on whether `PerChannel` is in the observer name. 
+``` + +Because mmrazor's quantization has been compatible with PyTorch's observers, we can use observers in PyTorch and our custom observers. + +Supported observers list in Pytorch. + +```Python +FixedQParamsObserver +HistogramObserver +MinMaxObserver +MovingAverageMinMaxObserver +MovingAveragePerChannelMinMaxObserver +NoopObserver +ObserverBase +PerChannelMinMaxObserver +PlaceholderObserver +RecordingObserver +ReuseInputObserver +UniformQuantizationObserverBase +``` + +- Fake quants + +In `forward`, they will update the statistics of the observed Tensor and fake quantize the input. They should also provide a `calculate_qparams` function that computes the quantization parameters given the collected statistics. + +Because mmrazor's quantization has been compatible with PyTorch's fakequants, we can use fakequants in PyTorch and our custom fakequants. + +Supported fakequants list in Pytorch. + +```Python +FakeQuantize +FakeQuantizeBase +FixedQParamsFakeQuantize +FusedMovingAvgObsFakeQuantize +``` + +- Qschemes + +Include some basic quantization configurations. + +`qdtype`: to specify whether quantized data type is sign or unsign. It can be chosen from \[ 'qint8', 'quint8' \] + +```{note} +If your model need to be deployed, `qdtype` must be consistent with the dtype in the corresponding backendconfig. Otherwise fakequant will not be inserted in front of the specified OPs. + +backendconfigs dir: +mmrazor/mmrazor/structures/quantization/backend_config +``` + +`bit`: to specify the quantized data bit. It can be chosen from \[1 ~ 16\]. + +`is_symmetry`: to specify whether to use symmetry quantization. It can be chosen from \[ True, False \] + +The specified qscheme is actually implemented by observers, so how to configurate other args needs to be based on the given observers, such as `is_symmetric_range` and `averaging_constant`. + +## How to customize your quantization algorithm + +If you try to customize your quantization algorithm, you can refer to the following link for more details. + +[Customize Quantization algorithms](https://github.com/open-mmlab/mmrazor/blob/quantize/docs/en/advanced_guides/customize_quantization_algorithms.md) diff --git a/mmrazor/__init__.py b/mmrazor/__init__.py index 74d91b8fa..fc5acaaeb 100644 --- a/mmrazor/__init__.py +++ b/mmrazor/__init__.py @@ -1,62 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings - import mmcv import mmengine -from packaging.version import parse +from mmengine.utils import digit_version from .version import __version__ - -def digit_version(version_str: str, length: int = 4): - """Convert a version string into a tuple of integers. - - This method is usually used for comparing two versions. For pre-release - versions: alpha < beta < rc. - - Args: - version_str (str): The version string. - length (int): The maximum number of version levels. Default: 4. - - Returns: - tuple[int]: The version info in digits (integers). 
- """ - version = parse(version_str) - assert version.release, f'failed to parse version {version_str}' - release = list(version.release) - release = release[:length] - if len(release) < length: - release = release + [0] * (length - len(release)) - if version.is_prerelease: - mapping = {'a': -3, 'b': -2, 'rc': -1} - val = -4 - # version.pre can be None - if version.pre: - if version.pre[0] not in mapping: - warnings.warn(f'unknown prerelease version {version.pre[0]}, ' - 'version checking may go wrong') - else: - val = mapping[version.pre[0]] - release.extend([val, version.pre[-1]]) - else: - release.extend([val, 0]) - - elif version.is_postrelease: - release.extend([1, version.post]) # type: ignore - else: - release.extend([0, 0]) - return tuple(release) - - mmcv_minimum_version = '2.0.0rc1' -mmcv_maximum_version = '2.0.0' +mmcv_maximum_version = '2.1.0' mmcv_version = digit_version(mmcv.__version__) mmengine_minimum_version = '0.1.0' mmengine_maximum_version = '1.0.0' mmengine_version = digit_version(mmengine.__version__) - assert (mmcv_version >= digit_version(mmcv_minimum_version) and mmcv_version <= digit_version(mmcv_maximum_version)), \ f'MMCV=={mmcv.__version__} is used but incompatible. ' \ diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index da6cec34d..8b0d4a692 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -4,15 +4,16 @@ from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, - GreedySamplerTrainLoop, SelfDistillValLoop, + GreedySamplerTrainLoop, LSQEpochBasedLoop, PTQLoop, + QATEpochBasedLoop, QATValLoop, SelfDistillValLoop, SingleTeacherDistillValLoop, SlimmableValLoop, SubnetValLoop) __all__ = [ - 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', - 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', - 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', - 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'StopDistillHook', - 'DMCPSubnetHook' + 'DMCPSubnetHook', 'StopDistillHook', 'SeparateOptimWrapperConstructor', + 'DumpSubnetHook', 'SingleTeacherDistillValLoop', + 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', + 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', + 'SelfDistillValLoop', 'AutoSlimGreedySearchLoop', 'SubnetValLoop', + 'PTQLoop', 'QATEpochBasedLoop', 'LSQEpochBasedLoop', 'QATValLoop' ] diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 10eb2b598..5fe2fd524 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -4,6 +4,8 @@ from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop from .evolution_search_loop import EvolutionSearchLoop from .iteprune_val_loop import ItePruneValLoop +from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, + QATValLoop) from .slimmable_val_loop import SlimmableValLoop from .subnet_sampler_loop import GreedySamplerTrainLoop from .subnet_val_loop import SubnetValLoop @@ -12,5 +14,6 @@ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', - 'ItePruneValLoop', 'AutoSlimGreedySearchLoop' + 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', + 'PTQLoop', 
'LSQEpochBasedLoop', 'QATValLoop' ] diff --git a/mmrazor/engine/runner/iteprune_val_loop.py b/mmrazor/engine/runner/iteprune_val_loop.py index bbca5d53a..2a627f398 100644 --- a/mmrazor/engine/runner/iteprune_val_loop.py +++ b/mmrazor/engine/runner/iteprune_val_loop.py @@ -52,7 +52,6 @@ def _save_fix_subnet(self): file.write(fix_subnet) torch.save({'state_dict': static_model.state_dict()}, osp.join(self.runner.work_dir, weight_name)) - self.runner.logger.info( 'export finished and ' f'{subnet_name}, ' diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py new file mode 100644 index 000000000..58d91cf18 --- /dev/null +++ b/mmrazor/engine/runner/quantization_loops.py @@ -0,0 +1,399 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +from mmengine.evaluator import Evaluator +from mmengine.logging import print_log +from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop + +try: + from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) + from torch.nn.intrinsic.qat import freeze_bn_stats +except ImportError: + from mmrazor.utils import get_placeholder + + disable_observer = get_placeholder('torch>=1.13') + enable_fake_quant = get_placeholder('torch>=1.13') + enable_observer = get_placeholder('torch>=1.13') + freeze_bn_stats = get_placeholder('torch>=1.13') + +from mmengine.dist import all_reduce_params, is_distributed +from torch.utils.data import DataLoader + +from mmrazor.models import register_torch_fake_quants, register_torch_observers +from mmrazor.models.fake_quants import (enable_param_learning, + enable_static_estimate, enable_val) +from mmrazor.registry import LOOPS + +TORCH_observers = register_torch_observers() +TORCH_fake_quants = register_torch_fake_quants() + + +@LOOPS.register_module() +class QATEpochBasedLoop(EpochBasedTrainLoop): + """`EpochBasedLoop` for `QuantizationAwareTraining` + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + max_epochs (int): Total training epochs. + val_begin (int): The epoch that begins validating. Defaults to 1. + val_interval (int): Validation interval. Defaults to 1. + disable_observer_begin (int): The number of total epochs to update + observers. Defaults to -1, which means observers are enabled + all the time. + freeze_bn_begin (int): The number of total epochs to update batch norm + stats. Defaults to -1, which means no need to freeze bn. + dynamic_intervals (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. 
+ """ + + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + val_begin: int = 1, + val_interval: int = 1, + disable_observer_begin: int = -1, + freeze_bn_begin: int = -1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: + super().__init__(runner, dataloader, max_epochs, val_begin, + val_interval, dynamic_intervals) + + self.disable_observer_begin = disable_observer_begin + self.freeze_bn_begin = freeze_bn_begin + + def prepare_for_run_epoch(self): + """Toggle the state of the observers and fake quantizers before qat + training.""" + self.runner.model.apply(enable_fake_quant) + + # The initialized _epoch equals to 0 so _epoch + 1 + # equal to the current epoch + if (self.disable_observer_begin > 0 + and self._epoch + 1 >= self.disable_observer_begin): + self.runner.model.apply(disable_observer) + else: + self.runner.model.apply(enable_observer) + + if (self.freeze_bn_begin > 0 + and self._epoch + 1 >= self.freeze_bn_begin): + self.runner.model.apply(freeze_bn_stats) + + def prepare_for_val(self): + """Toggle the state of the observers and fake quantizers before + validation.""" + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) + + def run(self): + """Launch training.""" + self.runner.call_hook('before_train') + + while self._epoch < self._max_epochs: + self.prepare_for_run_epoch() + self.run_epoch() + + self._decide_current_val_interval() + if (self.runner.val_loop is not None + and self._epoch >= self.val_begin + and self._epoch % self.val_interval == 0): + self.runner.val_loop.run() + + self.runner.call_hook('after_train') + + def run_epoch(self) -> None: + """Iterate one epoch.""" + self.runner.call_hook('before_train_epoch') + self.runner.model.train() + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + self.runner.model.sync_qparams(src_mode='loss') + # Make sure the registered buffer such as `observer_enabled` is + # correct in the saved checkpoint. + self.prepare_for_val() + self.runner.call_hook('after_train_epoch') + self._epoch += 1 + + +@LOOPS.register_module() +class LSQEpochBasedLoop(QATEpochBasedLoop): + """`EpochBasedLoop` for `LEARNED STEP SIZE QUANTIZATION` + + Paper: Learned Step Size Quantization. + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + max_epochs (int): Total training epochs. + val_begin (int): The epoch that begins validating. Defaults to 1. + val_interval (int): Validation interval. Defaults to 1. + freeze_bn_begin (int): The number of total epochs to update batch norm + stats. Defaults to -1, which means no need to freeze bn. + dynamic_intervals (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. 
+ """ + + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + val_begin: int = 1, + val_interval: int = 1, + freeze_bn_begin: int = -1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: + super().__init__( + runner, + dataloader, + max_epochs, + val_begin, + val_interval, + freeze_bn_begin=freeze_bn_begin, + dynamic_intervals=dynamic_intervals) + + self.is_first_batch = True + self.distributed = is_distributed() + + def prepare_for_run_epoch(self): + """Toggle the state of the observers and fake quantizers before qat + training.""" + if (self.freeze_bn_begin > 0 + and self._epoch + 1 >= self.freeze_bn_begin): + self.runner.model.apply(freeze_bn_stats) + + self.runner.model.apply(enable_param_learning) + + def prepare_for_val(self): + """Toggle the state of the observers and fake quantizers before + validation.""" + self.runner.model.apply(enable_val) + + def run_epoch(self) -> None: + """Iterate one epoch.""" + self.runner.call_hook('before_train_epoch') + self.runner.model.train() + + for idx, data_batch in enumerate(self.dataloader): + if self.is_first_batch: + # lsq observer init + self.runner.model.apply(enable_static_estimate) + + self.run_iter(idx, data_batch) + + if self.is_first_batch: + # In the first batch, scale in LearnableFakeQuantize is + # calculated through lsq observer. As the values of `scale` of + # different observers in different rank are usually different, + # we have to sync the `scale` here. + if self.distributed: + all_reduce_params( + self.runner.model.parameters(), op='mean') + + # Change back to param learning mode + self.is_first_batch = False + self.runner.model.apply(enable_param_learning) + + self.runner.model.sync_qparams(src_mode='loss') + # Make sure the registered buffer such as `observer_enabled` is + # correct in the saved checkpoint. + self.prepare_for_val() + self.runner.call_hook('after_train_epoch') + self._epoch += 1 + + +@LOOPS.register_module() +class QATValLoop(ValLoop): + """`ValLoop` for `QuantizationAwareTraining` + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 validation. Defaults to + False. 
+ """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False) -> None: + super().__init__(runner, dataloader, evaluator, fp16) + if self.runner.distributed: + assert hasattr(self.runner.model.module, 'architecture') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.module.data_preprocessor + self.architecture = self.runner.model.module.architecture + self.architecture.data_preprocessor = data_preprocessor + + else: + assert hasattr(self.runner.model, 'architecture') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.data_preprocessor + self.architecture = self.runner.model.architecture + self.architecture.data_preprocessor = data_preprocessor + + def run(self) -> dict: + """Launch validation.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch, self.runner.model) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[qat_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=qat_metrics) + + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch, self.architecture) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[ori_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=qat_metrics) + + self.runner.call_hook('after_val') + return qat_metrics + + @torch.no_grad() + def run_iter(self, idx, data_batch: Sequence[dict], model): + """Iterate one mini-batch. + + Args: + data_batch (Sequence[dict]): Batch of data + from dataloader. + """ + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # outputs should be sequence of BaseDataElement + + outputs = model.val_step(data_batch) + self.evaluator.process(data_samples=outputs, data_batch=data_batch) + self.runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) + + +@LOOPS.register_module() +class PTQLoop(TestLoop): + """`TestLoop` for Post Training Quantization. + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool, optional): Enable FP16 training mode. Defaults to False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + calibrate_dataloader: Union[DataLoader, Dict], + calibrate_steps=32, + fp16: bool = False, + only_val=False): + super().__init__(runner, dataloader, evaluator, fp16) + if isinstance(calibrate_dataloader, dict): + # Determine whether or not different ranks use different seed. 
+            diff_rank_seed = runner._randomness_cfg.get(
+                'diff_rank_seed', False)
+            self.calibrate_dataloader = runner.build_dataloader(
+                calibrate_dataloader,
+                seed=runner.seed,
+                diff_rank_seed=diff_rank_seed)
+        else:
+            self.calibrate_dataloader = calibrate_dataloader
+
+        self.calibrate_steps = calibrate_steps
+        self.only_val = only_val
+
+    def run(self) -> dict:
+        """Launch test."""
+        self.runner.call_hook('before_test')
+        self.runner.call_hook('before_test_epoch')
+
+        self.runner.model.eval()
+
+        if not self.only_val:
+            self.runner.model.apply(enable_fake_quant)
+            self.runner.model.apply(enable_observer)
+
+            print_log('Start calibration...')
+            for idx, data_batch in enumerate(self.calibrate_dataloader):
+                if idx == self.calibrate_steps:
+                    break
+                self.run_iter(idx, data_batch)
+            print_log('Finish calibration!')
+
+            self.runner.model.apply(enable_fake_quant)
+            self.runner.model.apply(disable_observer)
+
+            save_dir = os.path.join(self.runner.work_dir,
+                                    self.runner.timestamp)
+            self.runner.save_checkpoint(
+                save_dir,
+                'model_ptq.pth',
+                file_client_args=None,
+                save_optimizer=False,
+                save_param_scheduler=False)
+            print_log(f'Quantized model is saved in {save_dir}')
+
+        print_log('Start evaluating the quantized model...')
+        self.runner.model.apply(enable_fake_quant)
+        self.runner.model.apply(disable_observer)
+        metrics = self.runner.val_loop.run()
+        self.runner.call_hook('after_test_epoch', metrics=metrics)
+        self.runner.call_hook('after_test')
+
+        return metrics
+
+    @torch.no_grad()
+    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
+        """Iterate one mini-batch.
+
+        Args:
+            data_batch (Sequence[dict]): Batch of data from dataloader.
+        """
+        self.runner.call_hook(
+            'before_test_iter', batch_idx=idx, data_batch=data_batch)
+
+        _ = self.runner.model.calibrate_step(data_batch)
+
+        self.runner.call_hook(
+            'after_test_iter',
+            batch_idx=idx,
+            data_batch=data_batch,
+            outputs=None)
diff --git a/mmrazor/models/__init__.py b/mmrazor/models/__init__.py
index f5295aa9e..e5b9ec451 100644
--- a/mmrazor/models/__init__.py
+++ b/mmrazor/models/__init__.py
@@ -2,7 +2,11 @@
 from .algorithms import *  # noqa: F401,F403
 from .architectures import *  # noqa: F401,F403
 from .distillers import *  # noqa: F401,F403
+from .fake_quants import *  # noqa: F401,F403
 from .losses import *  # noqa: F401,F403
 from .mutables import *  # noqa: F401,F403
 from .mutators import *  # noqa: F401,F403
+from .observers import *  # noqa: F401,F403
+from .quantizers import *  # noqa: F401,F403
 from .task_modules import *  # noqa: F401,F403
+from .utils import *  # noqa: F401,F403
diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py
index 3cef96dfe..178cc6535 100644
--- a/mmrazor/models/algorithms/__init__.py
+++ b/mmrazor/models/algorithms/__init__.py
@@ -7,6 +7,7 @@
                   BigNAS, BigNASDDP, Darts, DartsDDP)
 from .pruning import DCFF, DMCP, DMCPDDP, SlimmableNetwork, SlimmableNetworkDDP
 from .pruning.ite_prune_algorithm import ItePruneAlgorithm
+from .quantization import MMArchitectureQuant, MMArchitectureQuantDDP
 
 __all__ = [
     'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
@@ -14,5 +15,6 @@
     'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation',
     'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
     'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS',
-    'BigNASDDP', 'DMCP', 'DMCPDDP'
+    'BigNASDDP', 'DMCP', 'DMCPDDP', 'MMArchitectureQuant',
+    'MMArchitectureQuantDDP'
 ]
diff --git a/mmrazor/models/algorithms/quantization/__init__.py b/mmrazor/models/algorithms/quantization/__init__.py
new file mode 100644
index 000000000..03a9538e2
--- /dev/null
+++ b/mmrazor/models/algorithms/quantization/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .mm_architecture import MMArchitectureQuant, MMArchitectureQuantDDP
+
+__all__ = ['MMArchitectureQuant', 'MMArchitectureQuantDDP']
diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py
new file mode 100644
index 000000000..ce6d926d0
--- /dev/null
+++ b/mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -0,0 +1,427 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from mmengine.config import Config
+from mmengine.model import MMDistributedDataParallel
+from mmengine.runner import load_checkpoint
+from mmengine.structures import BaseDataElement
+from torch import nn
+
+from mmrazor.models.utils import pop_rewriter_function_record
+from mmrazor.registry import MODEL_WRAPPERS, MODELS
+from mmrazor.structures.quantization import QConfigHandler
+from ..base import BaseAlgorithm, BaseModel
+
+try:
+    from torch.ao.quantization import (FakeQuantizeBase, MinMaxObserver,
+                                       PerChannelMinMaxObserver,
+                                       disable_observer)
+except ImportError:
+    from mmrazor.utils import get_placeholder
+
+    FakeQuantizeBase = get_placeholder('torch>=1.13')
+    MinMaxObserver = get_placeholder('torch>=1.13')
+    PerChannelMinMaxObserver = get_placeholder('torch>=1.13')
+    disable_observer = get_placeholder('torch>=1.13')
+
+LossResults = Dict[str, torch.Tensor]
+TensorResults = Union[Tuple[torch.Tensor], torch.Tensor]
+PredictResults = List[BaseDataElement]
+ForwardResults = Union[LossResults, TensorResults, PredictResults]
+
+
+@MODELS.register_module()
+class MMArchitectureQuant(BaseAlgorithm):
+    """General quantization for OpenMMLab's models.
+
+    Args:
+        architecture (Union[Dict, BaseModel]): The config of the model to be
+            quantized.
+        quantizer (Union[Dict, BaseModel]): The quantizer supporting
+            different backend types.
+        deploy_cfg (Union[str, Dict]): Deployment config file or Config
+            object. Defaults to None.
+        data_preprocessor (Optional[Dict]): The pre-process config of
+            :class:`BaseDataPreprocessor`. Defaults to None.
+        forward_modes (Tuple): The modes of the forward method in OpenMMLab
+            architectures, which could be 'tensor', 'predict' or 'loss'.
+            Each mode generates a different graph of the quantized model.
+        float_checkpoint (Optional[str]): The path of the pretrained FP
+            checkpoint. Quantization is different from other tasks; we
+            recommend using `float_checkpoint` as the pretrained model.
+            Defaults to None.
+        input_shapes (Tuple): The shape of the dummy input used to build the
+            quantized graphs. Defaults to (1, 3, 224, 224).
+        init_cfg (Optional[Dict]): The weight initialization config for
+            :class:`BaseModule`.
+
+    Note:
+        forward_modes (Tuple): In OpenMMLab architectures, different modes
+            will trace different graphs of the quantized model.
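+
+    Example:
+        An illustrative config snippet. ``_architecture_cfg_`` and
+        ``_quantizer_cfg_`` are placeholders for a real OpenMMLab model
+        config and a quantizer config, and the checkpoint path is assumed::
+
+            model = dict(
+                type='mmrazor.MMArchitectureQuant',
+                architecture=_architecture_cfg_,
+                quantizer=_quantizer_cfg_,
+                float_checkpoint='path/to/float_checkpoint.pth')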
+ """ + + def __init__(self, + architecture: Union[Dict, BaseModel], + quantizer: Union[Dict, BaseModel], + deploy_cfg: Optional[Union[str, Dict]] = None, + data_preprocessor: Optional[Dict] = None, + forward_modes: Tuple = ('tensor', 'predict', 'loss'), + float_checkpoint: Optional[str] = None, + input_shapes: Tuple = (1, 3, 224, 224), + init_cfg: Optional[Dict] = None): + + super().__init__(architecture, data_preprocessor, init_cfg) + + self.quantizer = MODELS.build(quantizer) + self.input_shapes = input_shapes + self.forward_modes = forward_modes + if isinstance(deploy_cfg, str): + deploy_cfg = Config.fromfile(deploy_cfg) + self.deploy_cfg = deploy_cfg + + # Replace syncbn and _BatchNormXd (in mmengine) with batchnorm2d + self.quantizer.convert_batchnorm2d(self.architecture) + + # If we have a float_checkpoint, we load it as pretrain. + if float_checkpoint: + _ = load_checkpoint(self.architecture, float_checkpoint) + self.architecture._is_init = True + + self.qmodels = self._build_qmodels(self.architecture) + self.sync_qparams('tensor') + self.reset_observer_and_fakequant_statistics(self) + + def reset_observer_and_fakequant_statistics(self, model): + """Reset the statistics in observers and fake quantizers. + + The forward computation in `_build_qmodels` can modify the original + statistics in observers and fake quantizers. + """ + for module in model.modules(): + if isinstance(module, (MinMaxObserver, PerChannelMinMaxObserver)): + module.reset_min_max_vals() + elif isinstance(module, FakeQuantizeBase): + module.scale.data = torch.ones_like(module.scale) + module.zero_point.data = torch.zeros_like(module.zero_point) + + def sync_qparams(self, src_mode: str): + """Sync all quantize parameters in different `forward_modes`. We could + have more than one forward mode to generate graphs, each mode will + generate one graph. But in training, only one graph will be update, so + we need to sync qparams in the other graphs. + + Args: + src_mode (str): The modes of forward method. + + Note: + `traverse()` method recursively traverses all modules to sync + quantized graph generated from different `forward_modes`. + This is because We have different mode ('tensor', 'predict', + 'loss') in OpenMMLab architecture which have different graph + in some subtle ways, so we need to sync them here. 
+ """ + + def traverse(module, prefix): + for name, child in module._modules.items(): + if module is None: + continue + child_name = f'{prefix}{name}' + if isinstance(child, FakeQuantizeBase): + for name, param in child.named_parameters(): + param_name = f'{child_name}.{name}' + src_param = src_state_dict[param_name] + if src_param.shape == param.shape: + param.data.copy_(src_param) + else: + requirs_grad = param.requires_grad + param.requires_grad = False + param.resize_(src_param.shape) + param.requires_grad = requirs_grad + param.data.copy_(src_param) + for name, buffer in child.named_buffers(): + buffer_name = f'{child_name}.{name}' + src_buffer = src_state_dict[buffer_name] + if src_buffer.shape == buffer.shape: + buffer.data.copy_(src_buffer) + else: + buffer.resize_(src_buffer.shape) + buffer.data.copy_(src_buffer) + else: + traverse(child, f'{child_name}.') + + src_state_dict = self.qmodels[src_mode].state_dict() + for mode in self.forward_modes: + if mode == src_mode: + continue + traverse(self.qmodels[mode], '') + + def _get_rewriter_context_in_mmdeploy(self, deploy_cfg): + """Get rewriter context in mmdeploy according to the deploy related + config.""" + from mmdeploy.apis.onnx.passes import optimize_onnx + from mmdeploy.codebase import import_codebase + from mmdeploy.core import RewriterContext + from mmdeploy.utils import (IR, Backend, get_backend, get_codebase, + get_dynamic_axes, get_ir_config, + get_onnx_config) + from mmdeploy.utils.config_utils import get_codebase_external_module + + codebase = get_codebase(deploy_cfg) + custom_module_list = get_codebase_external_module(deploy_cfg) + import_codebase(codebase, custom_module_list) + + def _add_or_update(cfg: dict, key: str, val: Any): + if key in cfg and isinstance(cfg[key], dict) and isinstance( + val, dict): + cfg[key].update(val) + else: + cfg[key] = val + + context_info = dict() + deploy_cfg = copy.deepcopy(deploy_cfg) + + backend = get_backend(deploy_cfg).value + + onnx_cfg = get_onnx_config(deploy_cfg) + opset_version = onnx_cfg.get('opset_version', 11) + + input_names = onnx_cfg['input_names'] + output_names = onnx_cfg['output_names'] + axis_names = input_names + output_names + dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names) + + verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get( + 'verbose', False) + keep_initializers_as_inputs = onnx_cfg.get( + 'keep_initializers_as_inputs', True) + optimize = onnx_cfg.get('optimize', False) + if backend == Backend.NCNN.value: + """NCNN backend needs a precise blob counts, while using onnx + optimizer will merge duplicate initilizers without reference + count.""" + optimize = False + + ir_config = dict( + type='onnx', + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + verbose=verbose, + keep_initializers_as_inputs=keep_initializers_as_inputs) + + _add_or_update(deploy_cfg, 'ir_config', ir_config) + ir = IR.get(get_ir_config(deploy_cfg)['type']) + if isinstance(backend, Backend): + backend = backend.value + backend_config = dict(type=backend) + _add_or_update(deploy_cfg, 'backend_config', backend_config) + + context_info['cfg'] = deploy_cfg + context_info['ir'] = ir + if 'backend' not in context_info: + context_info['backend'] = backend + if 'opset' not in context_info: + context_info['opset'] = opset_version + + if 'onnx_custom_passes' not in context_info: + onnx_custom_passes = optimize_onnx if optimize else None + context_info['onnx_custom_passes'] = onnx_custom_passes + + return 
RewriterContext(**context_info) + + def _pop_function_record_in_rewriter_context(self, rewriter_context): + """Delete user-specific rewriters from + `RewriterContext._rewriter_manager`. We use the model which is + rewritten by mmdeploy to build quantized models. However not all the + functions rewritten by mmdeploy need to be rewritten in mmrazor. For + example, mmdeploy rewrite + `mmcls.models.classifiers.ImageClassifier.forward` and + `mmcls.models.classifiers.BaseClassifier.forward` for deployment. But + they can't be rewritten by mmrazor as ptq and qat are done in mmrazor. + So to ensure ptq and qat proceed normally, we have to remove these + record from `RewriterContext._rewriter_manager`. + + Args: + rewriter_context (RewriterContext): The RewriterContext used in + mmdeploy. + """ + skipped_methods = getattr(self.quantizer.tracer, 'skipped_methods', []) + function_record_to_pop = self.deploy_cfg.get('function_record_to_pop', + []) + function_record_to_pop.extend(skipped_methods) + return pop_rewriter_function_record(rewriter_context, + function_record_to_pop) + + def _build_qmodels(self, model: BaseModel): + """Build quantized models from the given model. + + Args: + model (BaseModel): the given fp model. + + Example: + The main body of the graph is all the same, but the last one or two + op will have difference, as shown below. + + self.qmodels['tensor'].graph.print_tabular() + opcode target args + call_module head.fc (activation_post_process_38,) + output output (head_fc,) + + self.qmodels['loss'].graph.print_tabular() + opcode target args + call_method _get_loss (head, head_fc, data_samples) + output output (_get_loss,) + + self.qmodels['predict'].graph.print_tabular() + opcode target args + call_method _get_predictions (head, head_fc, data_samples) + output output (_get_predictions,) + """ + + rewriter_context = self._get_rewriter_context_in_mmdeploy( + self.deploy_cfg) if self.deploy_cfg is not None else None + + if rewriter_context is not None: + # Pop function records in `quantizer.tracer.skipped_method` + # temporarily + function_record_backup = \ + self._pop_function_record_in_rewriter_context(rewriter_context) + + qmodels = nn.ModuleDict() + for mode in self.forward_modes: + concrete_args = {'mode': mode} + + if rewriter_context is not None: + with rewriter_context: + observed_module = self.quantizer.prepare( + model, concrete_args) + else: + observed_module = self.quantizer.prepare(model, concrete_args) + + qmodels[mode] = observed_module + + if rewriter_context is not None: + # Add these popped function records back. + rewriter_context._rewriter_manager.function_rewriter. \ + _registry._rewrite_records.update(function_record_backup) + + # data_samples can not be None in detectors during prediction. + # But we need to make the dummy prediction in _build_qmodels. + # It is more convenient to use `tensor` mode. + is_training = qmodels['tensor'].training + # Avoid random input changing bn's statistics + qmodels['tensor'].eval() + # Originally, the steps to train a qat model is as follows: + # 1. build qmodels 2. convert the model to ddpmodel 3. forward backward + # The shape of `scale` and `zero_point` can be modified during forward. + # We initialize these parameters with per-tensor mode by default for + # convenience. Their shape will be modified during forward if + # per-channel mode is used. It's hacky. Hence we need to input a + # dummy input to make sure the shape has been modified. 
+ device = next(qmodels.parameters()).device + dummy_input = torch.randn(self.input_shapes).to(device) + qmodels['tensor'](dummy_input, None, 'tensor') + qmodels['tensor'].train(mode=is_training) + + return qmodels + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor') -> ForwardResults: + """Forward with qmodels in quantization.""" + + if mode in self.qmodels: + qmodel = self.qmodels[mode] + return qmodel(inputs, data_samples, mode) + else: + return self.architecture(inputs, data_samples, mode) + + def calibrate_step(self, data: Union[Dict, Tuple, List]): + """PTQ method need calibrate by cali data.""" + + data = self.data_preprocessor(data, False) + return self._run_forward(data, mode='predict') + + def get_deploy_model(self): + """Prepare for deploy to the backend with mmdeploy, which will be used + in mmdeploy, and usually includes as follows: + + 1. prepare for the float model rewritten by mmdeploy. + 2. load checkpoint consists of float weight and quantized params in + mmrazor. + 3. post process weight fakequant for exporting .onnx that meet + the backend's requirement. + """ + device = next(self.parameters()).device + quantized_state_dict = self.qmodels['predict'].state_dict() + fp32_model = self.architecture + self.quantizer.convert_batchnorm2d(fp32_model) + observed_model = self.quantizer.prepare(fp32_model) + observed_model.load_state_dict(quantized_state_dict) + + self.quantizer.post_process_for_deploy( + observed_model, + device=device, + keep_w_fake_quant=True, + update_weight_with_fakequant=True) + + # replace various activation fakequant with base fakequant, which + # contributes to deploy our model to various backends. + for node in observed_model.graph.nodes: + if 'activation_post_process_' in node.name: + module_name = node.target + module = getattr(observed_model, module_name) + fakequant_new = QConfigHandler.replace_fakequant( + module, + self.quantizer.qconfig.a_qscheme, + update_qparams=True) + setattr(observed_model, module_name, fakequant_new) + + observed_model.apply(disable_observer) + + return observed_model + + +@MODEL_WRAPPERS.register_module() +class MMArchitectureQuantDDP(MMDistributedDataParallel): + """DDPwapper for MMArchitectureQuant. + + Args: + device_ids (Optional[Union[List, int, torch.device]]): devices to run + ddp. + """ + + def __init__(self, + *, + device_ids: Optional[Union[List, int, torch.device]] = None, + **kwargs) -> None: + + if device_ids is None: + if os.environ.get('LOCAL_RANK') is not None: + device_ids = [int(os.environ['LOCAL_RANK'])] + super().__init__(device_ids=device_ids, **kwargs) + # After moving all model parameters and buffers to the GPU + # (`model.cuda()`), the buffers in model are different. + self.module.qmodels = self.module._build_qmodels( + self.module.architecture) + self.module.sync_qparams('tensor') + self.module.reset_observer_and_fakequant_statistics(self) + + def calibrate_step(self, data: Union[Dict, Tuple, List]): + """PTQ method need calibrate by cali data.""" + + return self.module.calibrate_step(data) + + def sync_qparams(self, src_mode: str): + """Same as in 'MMArchitectureQuant'. Sync all quantize parameters in + different `forward_modes`. We could have several modes to generate + graphs, but in training, only one graph will be update, so we need to + sync qparams on the other graphs. + + Args: + src_mode (str): The src modes of forward method. 
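+
+        Example:
+            The call is simply forwarded to the wrapped module;
+            ``ddp_model`` is an assumed ``MMArchitectureQuantDDP`` instance::
+
+                >>> ddp_model.sync_qparams(src_mode='loss')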
+ """ + + self.module.sync_qparams(src_mode) diff --git a/mmrazor/models/fake_quants/__init__.py b/mmrazor/models/fake_quants/__init__.py new file mode 100644 index 000000000..950821210 --- /dev/null +++ b/mmrazor/models/fake_quants/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseFakeQuantize +from .lsq import (LearnableFakeQuantize, enable_param_learning, + enable_static_estimate, enable_val) +from .torch_fake_quants import register_torch_fake_quants + +__all__ = [ + 'BaseFakeQuantize', 'register_torch_fake_quants', 'LearnableFakeQuantize', + 'enable_val', 'enable_param_learning', 'enable_static_estimate' +] diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py new file mode 100644 index 000000000..45aed7421 --- /dev/null +++ b/mmrazor/models/fake_quants/base.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + from torch.ao.quantization import FakeQuantize +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantize = get_placeholder('torch>=1.13') + +BaseFakeQuantize = FakeQuantize diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py new file mode 100644 index 000000000..1689d0393 --- /dev/null +++ b/mmrazor/models/fake_quants/lsq.py @@ -0,0 +1,313 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parameter import Parameter + +from mmrazor.registry import MODELS + +try: + from torch.ao.quantization import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') + + +def enable_param_learning(mod): + """Enables learning of quantization parameters, if applicable. Example + usage:: + + # model is any PyTorch model model.apply(enable_param_learning) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_param_learning() + + +def enable_static_estimate(mod): + """Enables static observer estimates, if applicable. Example usage:: + + # model is any PyTorch model model.apply(enable_static_estimate) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_static_estimate() + + +def enable_val(mod): + """Enable validation, if applicable. Example usage:: + + # model is any PyTorch model model.apply(enable_val) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_val() + + +@MODELS.register_module() +class LearnableFakeQuantize(FakeQuantizeBase): + """This is an extension of the FakeQuantize module in fake_quantize.py, + which supports learning of the scale and zero point parameters through + backpropagation. + + In addition to the attributes in the original FakeQuantize module, the + LearnableFakeQuantize module also includes the following attributes to + support quantization parameter learning. + + * :attr:`fake_quant_enabled` defines the flag for enabling fake + quantization on the output. + + * :attr:`static_enabled` defines the flag for using observer's static + estimation for scale and zero point. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation + for scale and zero point. + + Args: + observer (module): Module for observing statistics on input tensors and + calculating scale and zero-point. + quant_min (int): Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max (int): Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + scale (float): The initial value of the floating-point scale factor. + Defaults to 1. 
+ zero_point (float): The initial value of the floating-point zero-point. + Defaults to 0. + use_grad_scaling (bool): Whether the gradients for scale and zero point + are normalized by the constant, which is proportional to the square + root of the number of elements in the tensor. The related + literature justifying the use of this particular constant can be + found here: https://openreview.net/pdf?id=rkgO66VKDS. Defaults to + True. + zero_point_trainable (bool): Whether the zero_point is trainable. + Defaults to False. + observer_kwargs (dict | optional): Arguments for the observer module. + """ + + def __init__(self, + observer, + quant_min=0, + quant_max=255, + scale=1., + zero_point=0., + use_grad_scaling=True, + zero_point_trainable=False, + **observer_kwargs): + super(LearnableFakeQuantize, self).__init__() + assert quant_min < quant_max, \ + 'quant_min must be strictly less than quant_max.' + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs['quant_min'] = quant_min + observer_kwargs['quant_max'] = quant_max + self.use_grad_scaling = use_grad_scaling + + self.scale = Parameter(torch.tensor([scale])) + self.zero_point_trainable = zero_point_trainable + if zero_point_trainable: + self.zero_point = Parameter(torch.tensor([zero_point])) + else: + self.register_buffer('zero_point', torch.tensor([zero_point])) + + self.activation_post_process = observer(**observer_kwargs) + assert \ + torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \ + 'quant_min out of bound' + assert \ + quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \ + 'quant_max out of bound' + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = self.activation_post_process.ch_axis \ + if hasattr(self.activation_post_process, 'ch_axis') else -1 + self.register_buffer('fake_quant_enabled', + torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('static_enabled', + torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('learning_enabled', + torch.tensor([0], dtype=torch.uint8)) + + bitrange = torch.tensor(quant_max - quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.register_buffer('eps', + torch.tensor([torch.finfo(torch.float32).eps])) + + @torch.jit.export + def enable_param_learning(self): + """Enables learning of quantization parameters and disables static + observer estimates. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=True) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=False) + return self + + @torch.jit.export + def enable_static_estimate(self): + """Enables static observer estimates and disables learning of + quantization parameters. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def enable_val(self): + """Disables static observer accumulating data from input and doesn't + update the quantization parameters. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=False) + + @torch.jit.export + def enable_static_observation(self): + """Enables static observer accumulating data from input but doesn't + update the quantization parameters. + + Forward path returns the original X. 
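+
+        Example usage, where ``fakequant`` is an assumed
+        ``LearnableFakeQuantize`` instance::
+
+            fakequant.enable_static_observation()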
+ """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=False) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def toggle_observer_update(self, enabled=True): + """Toggles whether static observer accumulates data from input.""" + self.static_enabled[0] = int(enabled) + return self + + @torch.jit.export + def enable_observer(self, enabled=True): + """Enables static observer accumulating data from input.""" + self.toggle_observer_update(enabled) + + @torch.jit.export + def toggle_qparam_learning(self, enabled=True): + """Toggles whether the quantization parameters are learnable.""" + self.learning_enabled[0] = int(enabled) + self.scale.requires_grad = enabled + if self.zero_point_trainable: + self.zero_point.requires_grad = enabled + return self + + @torch.jit.export + def toggle_fake_quant(self, enabled=True): + """Toggles whether the fake quantization is enabled.""" + self.fake_quant_enabled[0] = int(enabled) + return self + + @torch.jit.export + def observe_quant_params(self): + """Shows the quantization parameters.""" + print('LearnableFakeQuantize Scale: {}'.format(self.scale.detach())) + print('LearnableFakeQuantize Zero Point: {}'.format( + self.zero_point.detach())) + + @torch.jit.export + def calculate_qparams(self): + """Calculate the quantization parameters.""" + self.scale.data.clamp_(min=self.eps.item()) + scale = self.scale.detach() + zero_point = self.zero_point.detach().round().clamp( + self.quant_min, self.quant_max).long() + return scale, zero_point + + def forward(self, X): + """Forward computation. + + Forward path returns fake quantized X. + """ + if self.static_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = \ + self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + + if self.qscheme in (torch.per_channel_symmetric, + torch.per_channel_affine): + self.scale.data = torch.ones_like(_scale) + self.zero_point.data = torch.zeros_like(_zero_point.float()) + + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point) + else: + self.scale.data.clamp_(min=self.eps.item()) + + if self.fake_quant_enabled[0] == 1: + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5 + else: + grad_factor = 1.0 + if self.qscheme in (torch.per_channel_symmetric, + torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + if not (self.quant_min <= self.zero_point <= self.quant_max): + print(self.quant_min, self.zero_point, self.quant_max) + X = torch._fake_quantize_learnable_per_tensor_affine( + X, self.scale, self.zero_point, self.quant_min, + self.quant_max, grad_factor) + + return X + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Removing this function throws an error that the the size of the + loaded tensor does not match the original size i.e., These buffers + start out with numel 0 and become numel 1 once they have their first + forward pass. 
+ + Modified from https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fake_quantize.py # noqa:E501 + """ + local_state = ['scale', 'zero_point'] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == 'scale': + self.scale.data = self.scale.data.resize_(val.shape) + else: + assert name == 'zero_point' + self.zero_point.data = self.zero_point.data.resize_( + val.shape) + # For torchscript module we need to update the attributes here + # since we do not call the `_load_from_state_dict` function + # defined module.py + if torch.jit.is_scripting(): + if name == 'scale': + self.scale.copy_(val) + else: + assert name == 'zero_point' + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super(LearnableFakeQuantize, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) + + @torch.jit.export + def extra_repr(self): + """The printable representational string.""" + repr_str = f'static_enabled={self.static_enabled}, ' + repr_str += f'fake_quant_enabled={self.fake_quant_enabled}, ' + repr_str += f'quant_min={self.activation_post_process.quant_min}, ' + repr_str += f'quant_max={self.activation_post_process.quant_max}, ' + repr_str += f'dtype={self.dtype}, ' + repr_str += f'qscheme={self.qscheme}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'zero_point={self.zero_point}, ' + repr_str += f'zero_point_trainable={self.zero_point_trainable}' + return repr_str diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py new file mode 100644 index 000000000..06e325b32 --- /dev/null +++ b/mmrazor/models/fake_quants/torch_fake_quants.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from typing import List + +from mmrazor.registry import MODELS + +try: + import torch.ao.quantization.fake_quantize as torch_fake_quant_src +except ImportError: + from mmrazor.utils import get_package_placeholder + torch_fake_quant_src = get_package_placeholder('torch>=1.13') + + +# TORCH_fake_quants = register_torch_fake_quants() +# TORCH_fake_quants including: +# FakeQuantize +# FakeQuantizeBase +# FixedQParamsFakeQuantize +# FusedMovingAvgObsFakeQuantize +def register_torch_fake_quants() -> List[str]: + """Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the + ``MODELS`` registry. + + Returns: + List[str]: A list of registered fake_quants' name. 
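+
+    Example:
+        After registration, the PyTorch fake quantize classes can be fetched
+        from the registry by class name; the name used here is illustrative::
+
+            from mmrazor.registry import MODELS
+            fake_quant_cls = MODELS.get('FakeQuantize')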
+ """ + torch_fake_quants = [] + for module_name in dir(torch_fake_quant_src): + if module_name.startswith('__') or module_name.startswith('_') or \ + module_name.startswith('default'): + continue + _fake_quant = getattr(torch_fake_quant_src, module_name) + if inspect.isclass(_fake_quant) and issubclass( + _fake_quant, torch_fake_quant_src.FakeQuantizeBase): + if MODELS.get(module_name) is None: + MODELS.register_module(module=_fake_quant) + torch_fake_quants.append(module_name) + return torch_fake_quants diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py index d0a2deff0..251214f70 100644 --- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. """This module defines MutableChannelUnit.""" import abc -from collections import Set +# from collections import set from typing import Dict, List, Type, TypeVar import torch @@ -72,7 +72,7 @@ def process_container(container: MutableChannelContainer, if isinstance(derived_choices, torch.Tensor): derived_choices = derived_choices.sum().item() if isinstance(mutable, DerivedMutable): - source_mutables: Set = \ + source_mutables: set = \ mutable._trace_source_mutables() source_channel_mutables = [ mutable for mutable in source_mutables diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py new file mode 100644 index 000000000..84d1677dd --- /dev/null +++ b/mmrazor/models/observers/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseObserver +from .lsq import LSQObserver, LSQPerChannelObserver +from .torch_observers import register_torch_observers + +__all__ = [ + 'BaseObserver', 'register_torch_observers', 'LSQObserver', + 'LSQPerChannelObserver' +] diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py new file mode 100644 index 000000000..ce226cb48 --- /dev/null +++ b/mmrazor/models/observers/base.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + from torch.ao.quantization.observer import UniformQuantizationObserverBase +except ImportError: + from mmrazor.utils import get_placeholder + UniformQuantizationObserverBase = get_placeholder('torch>=1.13') + +BaseObserver = UniformQuantizationObserverBase diff --git a/mmrazor/models/observers/lsq.py b/mmrazor/models/observers/lsq.py new file mode 100644 index 000000000..ccab3b0e6 --- /dev/null +++ b/mmrazor/models/observers/lsq.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math + +import torch +import torch.distributed as dist + +from mmrazor.registry import MODELS + +try: + from torch.ao.quantization.observer import (MinMaxObserver, + PerChannelMinMaxObserver) +except ImportError: + from mmrazor.utils import get_placeholder + MinMaxObserver = get_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_placeholder('torch>=1.13') + + +def sync_tensor(tensor): + """Synchronize the target tensor during distributed training.""" + if torch.distributed.is_initialized() and tensor.is_cuda: + tensor.data = tensor.data / dist.get_world_size() + dist.all_reduce(tensor.data) + return tensor + + +class LSQObserverMixIn: + """A mixin class for LSQObserver which can provide the initialized + floating-point scale factor.""" + + def __init__(self): + self.tensor_norm = None + + @torch.jit.export + def _calculate_scale(self): + """Calculate the initialized floating-point scale factor. + + Each layer of weights and each layer of activations has a distinct step + size, represented as a fp32 value, initialized to 2<|v|> / sqrt(Q_p), + computed on either the initial weights values or the first batch of + activations, respectively. + """ + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + sync_tensor(scale) + return scale + + +@MODELS.register_module() +class LSQObserver(MinMaxObserver, LSQObserverMixIn): + """LSQ observer. + + Paper: Learned Step Size Quantization. + """ + + def __init__(self, *args, **kwargs): + MinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the running minimum, maximum and tensor_norm of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + self.tensor_norm = x.abs().mean() + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = MinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point + + +@MODELS.register_module() +class LSQPerChannelObserver(PerChannelMinMaxObserver, LSQObserverMixIn): + """LSQ per-channel observer. + + Paper: Learned Step Size Quantization. 
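+
+    Example:
+        Typically referenced by name inside a qconfig; the other qconfig
+        fields (fake quantizers, qschemes) are omitted and the layout is
+        illustrative::
+
+            global_qconfig = dict(
+                w_observer=dict(type='mmrazor.LSQPerChannelObserver'),
+                a_observer=dict(type='mmrazor.LSQObserver'))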
+ """ + + def __init__(self, *args, **kwargs): + PerChannelMinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the per-channel running minimum, maximum and tensor_norm of + ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + + self.tensor_norm = y.abs().mean(1) + + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = PerChannelMinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py new file mode 100644 index 000000000..4e540667a --- /dev/null +++ b/mmrazor/models/observers/torch_observers.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from typing import List + +import torch + +from mmrazor.registry import MODELS + +try: + import torch.ao.quantization.observer as torch_observer_src + from torch.ao.quantization.observer import PerChannelMinMaxObserver +except ImportError: + from mmrazor.utils import get_package_placeholder + torch_observer_src = get_package_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_package_placeholder('torch>=1.13') + + +@torch.jit.export +def reset_min_max_vals(self): + """Resets the min/max values. + + `min_val` and `max_val` are always be on cpu in the pytorch version of this + method. + """ + min_val = torch.rand(0, ) + max_val = torch.rand(0, ) + self.min_val.resize_(min_val.shape).copy_(min_val) + self.max_val.resize_(max_val.shape).copy_(max_val) + + +PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals + + +# TORCH_observers = register_torch_observers() +# TORCH_observers including: +# FixedQParamsObserver +# HistogramObserver +# MinMaxObserver +# MovingAverageMinMaxObserver +# MovingAveragePerChannelMinMaxObserver +# NoopObserver +# ObserverBase +# PerChannelMinMaxObserver +# PlaceholderObserver +# RecordingObserver +# ReuseInputObserver +# UniformQuantizationObserverBase +def register_torch_observers() -> List[str]: + """Register observers in ``torch.ao.quantization.observer`` to the + ``MODELS`` registry. + + Returns: + List[str]: A list of registered observers' name. 
+ """ + torch_observers = [] + for module_name in dir(torch_observer_src): + if module_name.startswith('__') or module_name.startswith('_') or \ + module_name.startswith('default'): + continue + _observer = getattr(torch_observer_src, module_name) + if inspect.isclass(_observer) and issubclass( + _observer, torch_observer_src.ObserverBase): + if MODELS.get(module_name) is None: + MODELS.register_module(module=_observer) + torch_observers.append(module_name) + return torch_observers diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py new file mode 100644 index 000000000..a26bb1322 --- /dev/null +++ b/mmrazor/models/quantizers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .academic_quantizer import AcademicQuantizer +from .base import BaseQuantizer +from .native_quantizer import TorchNativeQuantizer +from .openvino_quantizer import OpenVINOQuantizer +from .tensorrt_quantizer import TensorRTQuantizer + +__all__ = [ + 'BaseQuantizer', 'AcademicQuantizer', 'TorchNativeQuantizer', + 'TensorRTQuantizer', 'OpenVINOQuantizer' +] diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py new file mode 100644 index 000000000..0dbe6dcdd --- /dev/null +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + +import torch + +from mmrazor.models.task_modules.tracer import build_graphmodule +from mmrazor.models.utils import str2class +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler +from .base import BaseQuantizer + +try: + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, + PrepareCustomConfig) + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + prepare = get_placeholder('torch>=1.13') + FuseCustomConfig = get_placeholder('torch>=1.13') + PrepareCustomConfig = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + +GLOBAL_DICT_KEY = '_global_' +OBJECT_TYPE_DICT_KEY = 'object_type' +MODULE_NAME_DICT_KEY = 'module_name' + +# keys can be used in `prepare_custom_config` of `AcademicQuantizer`. +FLOAT_TO_OBSERVED_DICT_KEY = 'float_to_observed_custom_module_class' +PRESERVED_ATTRIBUTES_DICT_KEY = 'preserved_attributes' + + +@MODELS.register_module() +class AcademicQuantizer(BaseQuantizer): + """Quantizer for academic researching. Different from some quantizers for + deploying, `AcademicQuantizer` is without the interfaces for deployment, + but it has more flexible functions for quantizing your model. With its + help, you can custom configuration qconfig for differenet OP by + `qconfig_mapping` to implement customized experiments, including using + custom fakquant, trying mixed precision quantization, comparing different + quantization scheme and so on. + + Args: + qconfig_mapping (Dict): Mapping from model ops to qconfig to configure + how a model is quantized. 
You can specify qconfigs using the + following keys (in increasing match priority): + ``_global_`` : sets the global (default) qconfig + ``object_type`` : sets the qconfig for a given module type, + function, or method name + ``module_name`` : sets the qconfig for modules matching the + given module name + tracer (Dict): It can be used to trace the float model to generate the + corresponding graph, which contributes to prepare for quantizing + the float model with code-free. Default to + `dict(type='mmrazor.CustomTracer')`. + prepare_custom_config (Optional[Dict]): Custom configuration for + :func:`~torch.ao.quantization.fx.prepare`. You can specify the + follow: + ``float_to_observed_custom_module_class`` : a list of dict that + mapping from float module classes to observed module + classes, e.g. + `[('FloatCustomModule', 'ObservedCustomModule')]` + ``preserved_attributes``: a list of attributes that persist + even if they are not used in ``forward``, e.g. + `['attr1', 'attr2']` + """ + + def __init__(self, + qconfig_mapping: Dict, + tracer: Dict = dict(type='mmrazor.CustomTracer'), + prepare_custom_config: Optional[Dict] = None): + super().__init__(tracer) + self.qconfig_mapping = self.gen_qconfig_mapping(qconfig_mapping) + self.prepare_custom_config = self.gen_prepare_custom_config( + prepare_custom_config) + self.backend_config = BackendConfigs[self.backend] + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + @property + def backend(self): + """The key of the corresponding backend config.""" + return 'academic' + + def prepare(self, model, concrete_args=None): + """Prepare for quantizing model, which includes as follows: + + 1. Swap floatfunctional with FXFloatFunctional; + 2. Trace model to generate `GraphModule`; + 2. Fuse some OPs combination, such as conv + bn, conv + relu and so on; + 3. Swap some conv or linear module with QAT Modules which contain + weight fakequant nodes; + 4. Insert required fakequant nodes for activation. + step 3 and step 4 are implemented in + :func:`~torch.ao.quantization.fx.prepare` + """ + self.swap_ff_with_fxff(model) + traced_graph = self.tracer.trace(model, concrete_args=concrete_args) + graph_module = build_graphmodule(model, traced_graph) + preserved_attributes = self.prepare_custom_config.preserved_attributes + for attr_name in preserved_attributes: + setattr(graph_module, attr_name, getattr(model, attr_name)) + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + preserved_attributes) + + # set the training modes of all modules to True to `_fuse_fx` correctly + # todo: check freezebn + self.sync_module_training_mode(graph_module, mode=True) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + fuse_custom_config=fuse_custom_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + prepare_custom_config=self.prepare_custom_config, + backend_config=self.backend_config) + for attr_name in preserved_attributes: + setattr(prepared, attr_name, getattr(model, attr_name)) + + return prepared + + def gen_qconfig_mapping(self, qconfig_mapping: Dict): + """Convert qconfig_mapping in config file to `QConfigMapping`. + + `QConfigMapping` is a custom class for mapping from model ops to + :class:`torch.ao.quantization.QConfig` s. 
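+
+        Example:
+            An illustrative ``qconfig_mapping``; ``global_qconfig`` and
+            ``conv_qconfig`` stand in for qconfig dicts in the
+            ``QConfigHandler`` format and are not defined here::
+
+                qconfig_mapping = dict(
+                    _global_=global_qconfig,
+                    object_type=[('torch.nn.Conv2d', conv_qconfig)],
+                    module_name=[('backbone.layer1', conv_qconfig)])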
+ """ + conf = QConfigMapping() + if GLOBAL_DICT_KEY in qconfig_mapping: + qconfig = QConfigHandler( + qconfig_mapping[GLOBAL_DICT_KEY]).convert() + conf.set_global(qconfig) + + for object_type, qconfig in qconfig_mapping.get( + OBJECT_TYPE_DICT_KEY, []): + qconfig = QConfigHandler(qconfig).convert() + conf.set_object_type(str2class(object_type), qconfig) + + for module_name, qconfig in qconfig_mapping.get( + MODULE_NAME_DICT_KEY, []): + qconfig = QConfigHandler(qconfig).convert() + conf.set_module_name(module_name, qconfig) + + return conf + + def gen_prepare_custom_config(self, prepare_custom_config: Optional[Dict]): + """Convert prepare_custom_config in config file to + `PrepareCustomConfig`. + + `PrepareCustomConfig` is a custom class for custom configurating + :func:`~torch.ao.quantization.fx.prepare`. + """ + conf = PrepareCustomConfig() + if prepare_custom_config is None: + return conf + else: + for float_class_str, observed_class_str in prepare_custom_config.get( # noqa: E501 + FLOAT_TO_OBSERVED_DICT_KEY, []): + float_class = MODELS.get(float_class_str) + observed_class = MODELS.get(observed_class_str) + conf.set_float_to_observed_mapping(float_class, observed_class) + conf.set_preserved_attributes( + prepare_custom_config.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])) + return conf diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py new file mode 100644 index 000000000..78c8163c7 --- /dev/null +++ b/mmrazor/models/quantizers/base.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Dict + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from mmengine.model.utils import _BatchNormXd + +from mmrazor.registry import TASK_UTILS + + +class BaseQuantizer(BaseModule): + """Base class for quantizers. Its role for several subclass is as follows: + 1. Provide tracer for tracing model for all subclass. + 2. Define some common abstract methods, such as `prepare`. + 3. Provide some common functional interfaces, such as `swap_ff_with_fxff`. + + Args: + tracer (Dict): It can be used to trace the float model to generate the + corresponding graph, which contributes to prepare for quantizing + the float model with code-free. + """ + + def __init__(self, tracer: Dict): + super().__init__() + self.tracer = TASK_UTILS.build(tracer) + + def sync_module_training_mode(self, model, mode=True): + """Synchronize the training modes. + + Note that modes of conv and bn must be the same during ``_fuse_fx``. + """ + for module in model.modules(): + module.training = mode + return + + @staticmethod + def convert_batchnorm2d(model): + """Helper function to convert all :attr:`_BatchNormXd` layers and + :class:`torch.nn.SyncBatchNorm` layers in the model to + :class:`torch.nn.BatchNorm2d` layers. + """ + # todo: Convert all `_BatchNormXd` and `SyncBatchNorm` + # layers to `BatchNorm2d` layers but they may be :attr:`BatchNorm*D` + # layers + module_checklist = [nn.modules.batchnorm.SyncBatchNorm, _BatchNormXd] + + def traverse(module: nn.Module): + for child_name, child in module.named_children(): + if isinstance(child, tuple(module_checklist)): + bn = nn.BatchNorm2d(child.num_features, child.eps, + child.momentum, child.affine, + child.track_running_stats) + setattr(module, child_name, bn) + else: + traverse(child) + + traverse(model) + + @abstractmethod + def prepare(self, model): + """Prepare for quantizing model, which usually includes as follows: + + 1. 
Swap floatfunctional with FXFloatFunctional; + 2. Trace model to generate `GraphModule`; + 2. Fuse some OPs combination, such as conv + bn, conv + relu and so on; + 3. Swap some conv or linear module with QAT Modules which contain + weight fakequant nodes; + 4. Insert required fakequant nodes for activation. + 5. (Optional) Delete some redundant fakequant nodes according to the + special requirement of the backend for deployment. + """ + pass + + def swap_ff_with_fxff(self, model: torch.nn.Module): + """Swap FloatFunctional with FXFloatFunctional.""" + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self.swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() diff --git a/mmrazor/models/quantizers/exporters/__init__.py b/mmrazor/models/quantizers/exporters/__init__.py new file mode 100644 index 000000000..b8153289d --- /dev/null +++ b/mmrazor/models/quantizers/exporters/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .openvino_quantize_exporter import OpenVinoQuantizeExportor +from .tensorrt_quantize_exporter import TensorRTExplicitExporter + +__all__ = ['OpenVinoQuantizeExportor', 'TensorRTExplicitExporter'] diff --git a/mmrazor/models/quantizers/exporters/base_quantize_exporter.py b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py new file mode 100644 index 000000000..6527d3207 --- /dev/null +++ b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import print_log + +from .optim_utils import ONNXOptimUtils + +try: + import onnx + from onnx import numpy_helper +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') + numpy_helper = get_package_placeholder('No module named onnx.numpy_helper') + +SUPPORT_QWEIGHT_NODE = ['Gemm', 'Conv', 'ConvTranspose'] + +PERCHANNEL_FAKEQUANTIZER = [ + 'FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine' +] +PERTENSOR_FAKEQUANTIZER = ['LearnablePerTensorAffine', 'FixedPerTensorAffine'] + +ALL_FAKEQUANTIZER = PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER + + +def _parse_attrs(node_attrs): + attrs = {} + for attr in node_attrs: + if attr.type == onnx.AttributeProto.AttributeType.INTS: + attrs[attr.name] = tuple(attr.ints) + elif attr.type == onnx.AttributeProto.AttributeType.INT: + attrs[attr.name] = attr.i + elif attr.type == onnx.AttributeProto.AttributeType.FLOATS: + attrs[attr.name] = tuple(attr.floats) + elif attr.type == onnx.AttributeProto.AttributeType.FLOAT: + attrs[attr.name] = attr.f + elif attr.type == onnx.AttributeProto.AttributeType.TENSOR: + attrs[attr.name] = numpy_helper.to_array(attr.t) + elif attr.type == onnx.AttributeProto.AttributeType.STRING: + attrs[attr.name] = str(attr.s) + elif attr.type == onnx.AttributeProto.AttributeType.STRINGS: + attrs[attr.name] = tuple([str(x) for x in attr.strings]) + else: + raise Exception('ATTR Type [{}] Not Supported!'.format(attr.type)) + return attrs + + +class BaseQuantizeExportor(): + + optimizer = ONNXOptimUtils + + def __init__(self, onnx_model, export_path) -> None: + + if isinstance(onnx_model, str): + self.onnx_model = onnx.load(onnx_model) + elif isinstance(onnx_model, onnx.ModelProto): + self.onnx_model = onnx_model + 
else: + raise TypeError + + self.export_path = export_path + self._init_mappings_from_onnx(self.onnx_model) + + self.optimizer.remove_fake_pad_op(self.onnx_model, self.name2data, + self.input2node, self.output2node) + + self._remap_input_and_node() + self._remap_output_and_node() + + @property + def graph(self): + """The onnx model's graph.""" + return self.onnx_model.graph + + def _init_mappings_from_onnx(self, onnx_model): + """Build necessary mappings in a onnx model.""" + + self.input2node = self.optimizer.map_input_and_node(onnx_model) + self.output2node = self.optimizer.map_output_and_node(onnx_model) + self.name2data = self.optimizer.map_name_and_data(onnx_model) + + def _remap_input_and_node(self): + """Rebuild the mapping from input name to a (node, input index) + tuple.""" + self.input2node = self.optimizer.map_input_and_node(self.onnx_model) + + def _remap_output_and_node(self): + """Rebuild the mapping from a node's output name to this node.""" + self.output2node = self.optimizer.map_output_and_node(self.onnx_model) + + def parse_qparams(self, node: onnx.NodeProto): + """Parse the quantize-related parameters based on a node.""" + tensor_name, scale, zero_point = node.input[:3] + + scale, zero_point = self.name2data[scale], self.name2data[zero_point] + if len(node.input) > 3: + qmin, qmax = node.input[-2:] + qmin, qmax = self.name2data[qmin], self.name2data[qmax] + elif len(node.attribute) > 0: + qparams = _parse_attrs(node.attribute) + qmin = qparams['quant_min'] + qmax = qparams['quant_max'] + else: + print_log(f'qmin and qmax are not found for <{node.name}>!') + qmax = qmin = None + return tensor_name, scale, zero_point, qmin, qmax + + def collect_symbolic_nodes(self, onnx_model: onnx.ModelProto): + """Collect all the fakequant nodes from a onnx model.""" + symbolic_nodes = list() + for node in onnx_model.graph.node: + if node.op_type in ALL_FAKEQUANTIZER: + symbolic_nodes.append(node) + return symbolic_nodes + + def _get_constant_inputs(self, node: onnx.NodeProto): + """Get the constant input node for the current node.""" + constant_nodes = list() + output2node = self.output2node + for inp in node.input: + if inp in output2node and output2node[inp].op_type == 'Constant': + cnode = output2node[inp] + + constant_nodes.append(cnode) + return constant_nodes + + def _collect_symbolic_constant_inputs(self, symbolic_nodes: List): + """Collect these constant nodes which is the input of all the symbolic + node.""" + + collected_constant_names = set() + constant_inputs = list() + for node in symbolic_nodes: + constant_inputs = self._get_constant_inputs(node) + for constant in constant_inputs: + if constant.name in collected_constant_names: + continue + constant_inputs.append(constant) + collected_constant_names.add(constant.name) + return constant_inputs + + def _remove_symbolic_related_from_onnx(self, symbolic_nodes: List, + symbolic_constant_inputs: List): + """Remove these out of date fakequant nodes and theirs constant input + nodes.""" + for node in symbolic_nodes: + self.onnx_model.graph.node.remove(node) + + # Remove symbolic related constant nodes. The constant node which is + # only used by those symbolic nodes can be removed. + + def _is_standalone_constant_node(constant): + for node in self.onnx_model.graph.node: + for input_name in node.input: + # A constant node always has one output. 
+ if input_name == constant.output[0]: + return False + return True + + for constant in symbolic_constant_inputs: + if _is_standalone_constant_node(constant): + self.onnx_model.graph.node.remove(constant) + + def export(self): + """Export end to end onnx model.""" + # todo: is it a abstract method? + raise NotImplementedError diff --git a/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py new file mode 100644 index 000000000..e706251ca --- /dev/null +++ b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py @@ -0,0 +1,159 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import List + +import numpy as np +from google.protobuf.internal.containers import RepeatedScalarFieldContainer + +try: + import onnx + from onnx import helper, numpy_helper +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') + numpy_helper = get_package_placeholder('No module named onnx.numpy_helper') + helper = get_package_placeholder('No module named onnx.helper') + +from .base_quantize_exporter import BaseQuantizeExportor + + +class OpenVinoQuantizeExportor(BaseQuantizeExportor): + + def __init__(self, onnx_model, export_path) -> None: + super().__init__(onnx_model, export_path) + + def _build_backend_node_from_symbolic(self, node: onnx.NodeProto, + tensor_name: str, qmin: np.ndarray, + qmax: np.ndarray): + """Build new onnx nodes which can be deployed to the specific backend. + + These nodes will be used to replace those symbolic nodes in the + original onnx model. + """ + qmax = int(qmax) + qmin = int(qmin) + levels = qmax - qmin + 1 + # adjust weight levels + # if levels == 128: + # levels = 256 + # qmax = qmax * 2 + 1 + # qmin = qmin * 2 + output_name = node.output[0] + # Create a node (FakeQuantize) + keys = ['input_min', 'input_max', 'output_min', 'output_max'] + input_names = [f'{tensor_name}_{key}' for key in keys] + backend_node = helper.make_node( + 'FakeQuantize', # node name + [tensor_name, *input_names], # inputs + [output_name], # outputs + levels=levels, # Attributes + domain='org.openvinotoolkit', + name=node.name) + return backend_node + + def _build_backend_initializer(self, + names: RepeatedScalarFieldContainer[str], + scale: np.ndarray, zero_point: np.ndarray, + qmin: np.ndarray, qmax: np.ndarray, + shape: List[int]): + """Build onnx initializers which can be deployed to specific + backend.""" + + scale = np.abs(np.asarray(scale, dtype=np.float64).reshape(-1)) + zero_point = np.clip( + np.asarray(np.round(zero_point), dtype=np.int32).reshape(-1), + a_min=qmin, + a_max=qmax) + + qrange = float(qmax - qmin) + input_range = scale * qrange + input_high = (qmax - zero_point).astype( + np.float64) * input_range / qrange + input_low = input_high - input_range + input_low_size = input_low.size + + if input_low_size != 1: + input_low = input_low.reshape(*shape) + input_high = input_high.reshape(*shape) + + input_low = input_low.astype(np.float32) + input_high = input_high.astype(np.float32) + + initializers = list() + for init_name, value_tensor in zip( + names, [input_low, input_high, input_low, input_high]): + init = numpy_helper.from_array(value_tensor) + init.name = init_name + initializers.append(init) + return initializers + + def build_backend_nodes_and_initializers(self, symbolic_nodes: List): + """Build new onnx nodes and initializers which can be deployed to + specific backend.""" + backend_nodes = list() + 
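# Worked numbers for ``_build_backend_initializer`` below (values chosen only
# for illustration): with scale=0.1, zero_point=0, qmin=-128, qmax=127 the
# FakeQuantize node gets levels = 127 - (-128) + 1 = 256,
# input_range = 0.1 * 255 = 25.5, input_high = (127 - 0) * 25.5 / 255 = 12.7
# and input_low = 12.7 - 25.5 = -12.8, i.e. the float range that the OpenVINO
# FakeQuantize op clamps to before requantizing.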
backend_initializers = list() + for node in symbolic_nodes: + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams( + node) + new_node = self._build_backend_node_from_symbolic( + node, tensor_name, qmin, qmax) + backend_nodes.append(new_node) + + try: + # If the successor node (such as a conv node) has weight, + # we need get the length of the weight's shape. And ensure + # the length of the weight's shape and the new node's + # input shape (such as input_low and input_high) is the same. + next_node = self.input2node[node.output[0]][0][0] + # node for save weights + fake_node = self.output2node[next_node.input[1]] + tensor = self.name2data[fake_node.input[0]] + shape_length = len(tensor.shape) + new_shape = [-1] + [1] * (shape_length - 1) + except Exception: + new_shape = [-1] + + # The first element of new_node.input is the tensor name. + new_init_names = new_node.input[1:] + new_initializers = self._build_backend_initializer( + new_init_names, scale, zero_point, qmin, qmax, new_shape) + backend_initializers.extend(new_initializers) + return backend_nodes, backend_initializers + + def _insert_initializers_to_onnx(self, initializers: List): + """Insert onnx initializers to the onnx graph.""" + inserted_init_names = set() + for init in initializers: + if init.name in inserted_init_names: + continue + + self.onnx_model.graph.initializer.append(init) + inserted_init_names.add(init.name) + + def _replace_symbolic_related(self): + """Replacing symbolic related nodes and initializers in the original + onnx model with new nodes and initializers that can be deployed to the + specific backend.""" + + symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model) + + collect_func = self._collect_symbolic_constant_inputs + # Usually different activation fakequants share the same constant + # input, and different weight fakequants share the same constant input. + symbolic_constant_inputs = collect_func(symbolic_nodes) + + build_func = self.build_backend_nodes_and_initializers + new_nodes, new_initializers = build_func(symbolic_nodes) + + self._insert_initializers_to_onnx(new_initializers) + + self._remove_symbolic_related_from_onnx(symbolic_nodes, + symbolic_constant_inputs) + + self.onnx_model.graph.node.extend(new_nodes) + self.optimizer.optimize(self.onnx_model) + + def export(self): + """Export end to end onnx model.""" + self._replace_symbolic_related() + onnx.save(self.onnx_model, self.export_path) diff --git a/mmrazor/models/quantizers/exporters/optim_utils.py b/mmrazor/models/quantizers/exporters/optim_utils.py new file mode 100644 index 000000000..f4adc5ee1 --- /dev/null +++ b/mmrazor/models/quantizers/exporters/optim_utils.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Optional + +from mmengine import print_log + +try: + import onnx + from onnx import numpy_helper +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') + numpy_helper = get_package_placeholder('No module named onnx.numpy_helper') + + +class ONNXOptimUtils(): + + @classmethod + def map_name_and_data(cls, onnx_model: onnx.ModelProto): + """Build the mapping from a data's name to the data itself.""" + params = {} + for init in onnx_model.graph.initializer: + params[init.name] = numpy_helper.to_array(init) + for node in onnx_model.graph.node: + # If two zero_points are identity, one is a reference to the other + # after optimized by onnx. 
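# End-to-end usage sketch of the exporter defined above; the file names are
# hypothetical:
# >>> from mmrazor.models.quantizers.exporters import OpenVinoQuantizeExportor
# >>> exporter = OpenVinoQuantizeExportor('model_symbolic.onnx',
# ...                                     'model_openvino.onnx')
# >>> exporter.export()  # rewrites fakequant nodes, then saves the new model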
+ if node.op_type == 'Identity' and len(node.input) == 1 and \ + node.input[0] in params: + params[node.output[0]] = copy.deepcopy(params[node.input[0]]) + if node.op_type == 'Constant': + for attr in node.attribute: + if attr.name == 'value': + params[node.output[0]] = numpy_helper.to_array(attr.t) + return params + + @classmethod + def map_name_and_initializer(cls, + onnx_model: onnx.ModelProto, + allow_redundant=True): + """Build the mapping from a initializer's output name to this + initializer.""" + + initializers = dict() + + for idx, init in enumerate(onnx_model.graph.initializer): + initializers[init.name] = (init, idx) + + return initializers + + @classmethod + def map_output_and_node(cls, onnx_model: onnx.ModelProto): + """Build the mapping from a node's output name to this node.""" + output2node = dict() + for node in onnx_model.graph.node: + for output_name in node.output: + output2node[output_name] = node + return output2node + + @classmethod + def map_input_and_node(cls, onnx_model: onnx.ModelProto): + """Build the mapping from input name to a (node, input index) tuple.""" + + input2node: Dict[str, List] = dict() + for node in onnx_model.graph.node: + for idx, input_name in enumerate(node.input): + if input_name not in input2node: + input2node[input_name] = [] + input2node[input_name].append([node, idx]) + return input2node + + @classmethod + def remove_node_from_onnx(cls, node: onnx.NodeProto, + onnx_model: onnx.ModelProto): + """Removes a node from node list.""" + onnx_model.graph.node.remove(node) + + @classmethod + def remove_initializer_from_onnx(cls, initializer: onnx.TensorProto, + onnx_model: onnx.ModelProto): + """Inserts the initializer at the specified position.""" + onnx_model.graph.initializer.remove(initializer) + + @classmethod + def remove_fake_pad_op(cls, onnx_model, name2data, inp2node, out2node): + nodes_to_be_removed = [] + for idx, node in enumerate(onnx_model.graph.node): + if node.op_type == 'Pad': + pads = name2data[node.input[1]] + if all([x == 0 for x in pads]): + print_log(f'Remove pad op: <{node.name}>.') + next_nodes = inp2node[node.output[0]] + for next_node, idx in next_nodes: + next_node.input[idx] = node.input[0] + nodes_to_be_removed.append(node) + + for node in nodes_to_be_removed: + onnx_model.graph.node.remove(node) + + @classmethod + def insert_node_to_onnx(cls, + node: onnx.NodeProto, + onnx_model: onnx.ModelProto, + idx: int = 0): + """Inserts the node at the specified position.""" + onnx_model.graph.node.insert(idx, node) + + @classmethod + def find_standalone_nodes(cls, + onnx_model: onnx.ModelProto, + input2node: Optional[Dict] = None, + output2node: Optional[Dict] = None): + """Find unused nodes.""" + + if input2node is None: + input2node = cls.map_input_and_node(onnx_model) + if output2node is None: + output2node = cls.map_output_and_node(onnx_model) + + def _is_standalone_node(node, input2node, output2node): + for input_name in node.input: + if input_name in output2node: + return False + + for out_node in node.output: + if out_node in input2node: + return False + + return True + + standalone_nodes = list() + for node in onnx_model.graph.node: + + if _is_standalone_node(node, input2node, output2node): + standalone_nodes.append(node) + return standalone_nodes + + @classmethod + def find_redundant_initializers(cls, + onnx_model: onnx.ModelProto, + input2node: Optional[Dict] = None): + """Find unused initializers.""" + if input2node is None: + input2node = cls.map_input_and_node(onnx_model) + + initializers = 
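# Sketch of what ``remove_fake_pad_op`` does: a ``Pad`` whose pads are all
# zero is a no-op, so its consumers are rewired to the Pad's data input and
# the node is dropped (names are illustrative):
#   before:  x -> Pad(pads=[0, 0, 0, 0]) -> Conv -> y
#   after:   x -> Conv -> y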
cls.map_name_and_initializer(onnx_model) + redundant_initializers = list() + redundant_set = set() + for name, init_and_idx in initializers.items(): + if name not in input2node and name not in redundant_set: + # init_and_idx[0] is onnx.onnx_ml_pb2.TensorProto + # init_and_idx[1] is a integer index + redundant_initializers.append(init_and_idx[0]) + redundant_set.add(name) + return redundant_initializers + + @classmethod + def topo_sort(cls, + onnx_model: onnx.ModelProto, + initializers: Optional[Dict] = None, + inplace: bool = True): + """Topologically sort the nodes in a directed acyclic graph. + + Note that nodes in a directed acyclic graph may be out of order + after replacing symbolic related nodes with new nodes. + + Args: + onnx_model (onnx.ModelProto): The onnx model to be sorted + topologically. + initializers (Dict | Optional): The mapping from name to + initializers. Default to None. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + """ + + if inplace: + _onnx_model = onnx_model + else: + _onnx_model = copy.deepcopy(onnx_model) + + if initializers is None: + initializers = cls.map_name_and_initializer( + _onnx_model, allow_redundant=True) + + # A node may have multiple outputs. The first output name of a node + # named `/conv/Conv` is `/conv/Conv_output_0` + output_name2node = {} + for node in _onnx_model.graph.node: + for output_name in node.output: + output_name2node[output_name] = node + for node in _onnx_model.graph.input: + output_name2node[node.name] = node + + name2node = {node.name: node for node in _onnx_model.graph.node} + + graph: Dict[str, + List] = {node.name: [] + for node in _onnx_model.graph.node} + for node in _onnx_model.graph.input: + graph[node.name] = [] + + indegree = {node.name: 0 for node in _onnx_model.graph.node} + + # Build graph + for i, node in enumerate(_onnx_model.graph.node): + for input_name in node.input: + if input_name not in initializers: + indegree[node.name] += 1 + prev_node = output_name2node[input_name] + graph[prev_node.name].append(node) + + graph_input = [node.name for node in _onnx_model.graph.input] + root = graph_input.copy() + sorted_nodes = [] + + # There are some nodes whose input are all initializers. 
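# ``topo_sort`` below is a Kahn-style topological sort over ``graph.node``:
# graph inputs and zero-indegree nodes seed the queue, every emitted node
# decrements the indegree of its consumers, and a RuntimeError is raised if
# some nodes are never emitted (the graph contains a cycle). A typical call,
# assuming ``onnx_model`` is an ``onnx.ModelProto``:
# >>> sorted_model = ONNXOptimUtils.topo_sort(onnx_model, inplace=True)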
+ for node_name, in_degree in indegree.items(): + if in_degree == 0: + root.append(node_name) + + while root: + node_name = root.pop() + # There is no intersection between graph_input and + # _onnx_model.graph.node + if node_name not in graph_input: + node = name2node[node_name] + sorted_nodes.append(node) + for next_node in graph[node_name]: + indegree[next_node.name] -= 1 + if indegree[next_node.name] == 0: + root.append(next_node.name) + + num_nodes = len(_onnx_model.graph.node) + if len(sorted_nodes) != num_nodes: + raise RuntimeError('The graph is not a DAG.') + + for _ in range(num_nodes): + _onnx_model.graph.node.pop() + for node in sorted_nodes: + _onnx_model.graph.node.append(node) + + return _onnx_model + + @classmethod + def optimize(cls, onnx_model): + """Remove standalone nodes and redundant initializers, and + topologically sort the nodes in a directed acyclic graph.""" + + input2node = cls.map_input_and_node(onnx_model) + output2node = cls.map_output_and_node(onnx_model) + + standalone_nodes = cls.find_standalone_nodes(onnx_model, input2node, + output2node) + for node in standalone_nodes: + cls.remove_node_from_onnx(node, onnx_model) + print_log(f'Remove node {node.name}') + + redundant_inits = cls.find_redundant_initializers( + onnx_model, input2node) + for init in redundant_inits: + cls.remove_initializer_from_onnx(init, onnx_model) + print_log(f'Remove initializer {init.name}') + + sorted_onnx_model = cls.topo_sort(onnx_model) + + return sorted_onnx_model diff --git a/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py new file mode 100644 index 000000000..cde430b08 --- /dev/null +++ b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +try: + import onnx +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') + +from .base_quantize_exporter import BaseQuantizeExportor + + +class TensorRTExplicitExporter(BaseQuantizeExportor): + + def __init__(self, onnx_model, export_path) -> None: + super().__init__(onnx_model, export_path) + + def _build_backend_node_from_symbolic(self, node): + quantize_linear_node = onnx.helper.make_node( + 'QuantizeLinear', node.input[:3], [node.name + '_quantized_out'], + node.name + '_quantized') + dequantize_linear_node = onnx.helper.make_node( + 'DequantizeLinear', + [node.name + '_quantized_out'] + quantize_linear_node.input[1:3], + node.output, node.name + '_dequantized') + return [quantize_linear_node, dequantize_linear_node] + + def build_backend_nodes(self, symbolic_nodes): + backend_nodes = list() + for node in symbolic_nodes: + _, _, zero_point, qmin, qmax = self.parse_qparams(node) + assert qmax - qmin in ( + 2**8 - 1, 2**8 - + 2), 'Only 8 bit quantization support deployment to ONNX.' + assert not np.any(zero_point != 0), \ + 'This pass is only supposed to be used with TensorRT ' \ + 'Backend which does not support asymmetric quantization.' 
+ new_nodes = self._build_backend_node_from_symbolic(node) + backend_nodes.extend(new_nodes) + return backend_nodes + + def export(self): + symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model) + new_nodes = self.build_backend_nodes(symbolic_nodes) + for node in symbolic_nodes: + self.onnx_model.graph.node.remove(node) + self.onnx_model.graph.node.extend(new_nodes) + self.optimizer.optimize(self.onnx_model) + onnx.save(self.onnx_model, self.export_path) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py new file mode 100644 index 000000000..7b6f2f9ad --- /dev/null +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -0,0 +1,446 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from mmengine.config import Config + +try: + from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.graph_module import ObservedGraphModule + from torch.ao.quantization.qconfig_mapping import ( + _FIXED_QPARAMS_OP_TO_OBSERVER, FixedQParamsFakeQuantize, QConfig, + QConfigMapping, default_weight_fake_quant) + from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.fx.graph_module import GraphModule + from torch.nn.intrinsic.qat import modules as qat_fused_modules + from torch.nn.qat import modules as qat_modules + from torch.onnx import register_custom_op_symbolic +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + GraphModule = get_placeholder('torch>=1.13') + ObservedGraphModule = get_placeholder('torch>=1.13') + enable_fake_quant = get_placeholder('torch>=1.13') + disable_observer = get_placeholder('torch>=1.13') + enable_observer = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + qat_fused_modules = get_package_placeholder('torch>=1.13') + qat_modules = get_package_placeholder('torch>=1.13') + _FIXED_QPARAMS_OP_TO_OBSERVER = get_package_placeholder('torch>=1.13') + FixedQParamsFakeQuantize = get_package_placeholder('torch>=1.13') + QConfig = get_package_placeholder('torch>=1.13') + default_weight_fake_quant = get_package_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.task_modules.tracer import build_graphmodule +from mmrazor.models.task_modules.tracer.fx import ( + del_fakequant_after_function, del_fakequant_after_method, + del_fakequant_after_module, del_fakequant_after_op, + del_fakequant_before_function, del_fakequant_before_method, + del_fakequant_before_module, del_fakequant_before_op) +from mmrazor.models.utils import str2class +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler +from .base import BaseQuantizer + +if digit_version(torch.__version__) >= digit_version('1.13.0'): + SUPPORT_QAT_MODULES: Tuple = ( + qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, + qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, + qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, + qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, + qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, + qat_modules.Conv3d, qat_modules.Linear) + + MERGE_BN_MAPPINGS: Dict = { + qat_fused_modules.ConvBn1d: 
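# Hypothetical usage of the TensorRT exporter above, mirroring the OpenVINO
# one; the file names are placeholders:
# >>> from mmrazor.models.quantizers.exporters import TensorRTExplicitExporter
# >>> exporter = TensorRTExplicitExporter('model_symbolic.onnx',
# ...                                     'model_trt.onnx')
# >>> exporter.export()  # each fakequant becomes a QuantizeLinear /
# ...                    # DequantizeLinear pair for explicit quantization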
qat_modules.Conv1d, + qat_fused_modules.ConvBn2d: qat_modules.Conv2d, + qat_fused_modules.ConvBn3d: qat_modules.Conv3d, + qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, + qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, + qat_fused_modules.LinearBn1d: qat_modules.Linear + } + + def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, + quant_min, quant_max): + return g.op('mmrazor::FixedPerChannelAffine', x, scale, zero_point, + ch_axis, quant_min, quant_max) + + register_custom_op_symbolic('::fake_quantize_per_channel_affine', + fake_quantize_per_channel_affine, 11) + + def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, + quant_max): + return g.op('mmrazor::FixedPerTensorAffine', x, scale, zero_point, + quant_min, quant_max) + + register_custom_op_symbolic('::fake_quantize_per_tensor_affine', + fake_quantize_per_tensor_affine, 11) + +else: + SUPPORT_QAT_MODULES = () + MERGE_BN_MAPPINGS = {} + + +@MODELS.register_module() +class TorchNativeQuantizer(BaseQuantizer): + """Native class for quantizer. + + Args: + global_qconfig (Union[Dict, Config]): Config for quantization details + of weight and activation include observer, quantizer, and qscheme. + no_observer_modules (Optional[List]): Modules don't need observer. + To fit different backend, we need qconfig to determine the modules + which don't need observer. + tracer (Dict): Config for tracer to trace modules for torch fx . + + Raises: + NotImplementedError: _description_ + + Examples: + >>> global_qconfig = dict( + ... w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + ... a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + ... w_fake_quant=dict(type='mmrazor.FakeQuantize'), + ... a_fake_quant=dict(type='mmrazor.FakeQuantize'), + ... w_qscheme=dict( + ... qdtype='qint8', bit=8, is_symmetry=True, + ... is_symmetric_range=True), + ... a_qscheme=dict( + ... qdtype='quint8', bit=8, is_symmetry=True, + ... averaging_constant=0.1), +) + """ + + def __init__(self, + global_qconfig: Union[Dict, Config], + no_observer_modules: Optional[List] = None, + tracer: Dict = dict(type='CustomTracer'), + extra_redundant_fakequants: Dict = dict( + extra_module_prev_wo_fakequant=tuple(), + extra_module_next_wo_fakequant=tuple(), + extra_function_prev_wo_fakequant=tuple(), + extra_function_next_wo_fakequant=tuple(), + extra_method_prev_wo_fakequant=tuple(), + extra_method_next_wo_fakequant=tuple(), + extra_op_prev_wo_fakequant=tuple(), + extra_op_next_wo_fakequant=tuple())): + super().__init__(tracer) + self.qconfig = QConfigHandler(global_qconfig) + if self.qconfig.w_qscheme.is_per_channel: + w_mode = 'per_channel' + else: + w_mode = 'per_tensor' + if self.qconfig.a_qscheme.is_per_channel: + a_mode = 'per_channel' + else: + a_mode = 'per_tensor' + assert w_mode in self.support_w_modes + assert a_mode in self.support_a_modes + + self.qconfig_mapping = self.gen_qconfig_mapping( + self.qconfig, no_observer_modules) + self.no_observer_modules = no_observer_modules + + self.backend_config = BackendConfigs[self.backend] + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + self.extra_redundant_fakequants = extra_redundant_fakequants + + def gen_qconfig_mapping(self, qconfig, no_observer_modules): + """Convert qconfig in config file to `QConfigMapping`. + + `QConfigMapping` is a custom class for mapping from model ops to + :class:`torch.ao.quantization.QConfig` s. 
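# A sketch of building the quantizer from the registry with the
# ``global_qconfig`` shown in the docstring example above; whether the type
# needs the 'mmrazor.' scope prefix depends on the active default scope:
# >>> from mmrazor.registry import MODELS
# >>> quantizer = MODELS.build(
# ...     dict(type='mmrazor.TorchNativeQuantizer',
# ...          global_qconfig=global_qconfig))
# >>> quantizer.backend
# 'native'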
+ """ + qconfig_mapping = QConfigMapping().set_global(qconfig.convert()) + + if no_observer_modules is not None: + no_observer_modules = str2class(no_observer_modules) + for mod in no_observer_modules: + qconfig_mapping.set_object_type(mod, None) + + fixed_qparams_observer_to_qconfig = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items( + ): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[ + observer] + else: + activation = FixedQParamsFakeQuantize.with_args( + observer=observer) + + fixed_qparams_qconfig = QConfig( + activation=activation, weight=default_weight_fake_quant) + fixed_qparams_observer_to_qconfig[ + observer] = fixed_qparams_qconfig + qconfig_mapping.set_object_type(fixed_qparams_op, + fixed_qparams_qconfig) + + return qconfig_mapping + + @property + def backend(self): + """The key of the corresponding backend config.""" + return 'native' + + @property + def support_w_modes(self): + """Supported quantization modes for weight about per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') + + @property + def support_a_modes(self): + """Supported quantization modes for activation about per_tensor or + per_channel.""" + return ('per_tensor') + + def export_onnx(self, model: Union[torch.nn.Module, torch.jit.ScriptModule, + torch.jit.ScriptFunction], + args: Union[Tuple[Any, ...], + torch.Tensor], output_path: str, **kwargs): + """Export the onnx model that can be deployed to a native backend.""" + torch.onnx.export(model, args, output_path, **kwargs) + + def prepare(self, model, concrete_args=None): + """prepare graph to ObservedGraphModule. + + Returns: + ObservedGraphModule: GraphModules after fuse and observer. + + Notes: + 'graph_module' after '_fuse_fx()' function will fuse conv, BN, ReLU + into modules in SUPPORT_QAT_MODULES. + 'graph_module' after 'prepare()' function will become observed. + + Notes: + Keep `is_qat` is True is because in Pytorch when `is_qat` is false, + the `_fuse_fx()` function only fuse module into `nn.Squential`. + In mmrazor, we aim to add more ptq algorithm into our pipeline such + as Adaround, these kind of ptq method have some additional + fake_quant operations that we need it to be fused into our + `SUPPORT_QAT_MODULES` type, which is a tricky way to deal with it. + """ + self.swap_ff_with_fxff(model) + traced_graph = self.tracer.trace(model, concrete_args=concrete_args) + graph_module = build_graphmodule(model, traced_graph) + + # set the training modes of all modules to True to `_fuse_fx` correctly + # todo: check freezebn + self.sync_module_training_mode(graph_module, mode=True) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + prepared = self.del_redundant_fakequant(prepared) + + return prepared + + def post_process_for_deploy(self, + observed_module: ObservedGraphModule, + device: str = 'cpu', + update_weight_with_fakequant: bool = False, + keep_w_fake_quant: bool = False): + """weight fake-quant for supported QAT modules. + + Args: + observed_module (ObservedGraphModule): Modules after fused and + observed. + keep_w_fake_quant (bool, optional): Bool to determine whether to + keep weight fake-quant op, depending on the backend. Defaults + to False. 
+ + Note: + `post_process_weight_fakequant()` function is necessary that the + `SUPPORT_QAT_MODULES` will be convert to normal modules, and + BN will be really integrated into conv layers. + """ + + def traverse(module): + for name, child in module.named_children(): + # Trace `SUPPORT_QAT_MODULES` recursively. + if isinstance(child, SUPPORT_QAT_MODULES): + # We add w_fakequant once in case some ptq methods have + # specific operations such as Adaround. So we do Quantize + # to perform these operations and do dequantize to + # introduce quantization loss in advance. + weight_fakequant = child.weight_fake_quant + + # `to_float()` function fuse BN into conv or conv_relu, and + # also convert a qat module to a normal module. + # source url: https://github.com/pytorch/pytorch/blob/master/torch/nn/intrinsic/qat/modules/conv_fused.py # noqa: E501 + float_child = child.to_float() + + if update_weight_with_fakequant: + from torch.ao.nn.intrinsic import _FusedModule + if issubclass(type(float_child), _FusedModule): + float_child[0].weight.data = weight_fakequant( + float_child[0].weight.data) + else: + float_child.weight.data = weight_fakequant( + float_child.weight.data) + # This is decided by backend type, some backend need + # explicitly keep the fake quant structure, others don't. + # TODO add deploy doc link + if keep_w_fake_quant: + # make weight fakequant fixed as the consistent + # fakequant, it will help to deploy our model to + # various backends. + self.qconfig.fixed_w_fakequant() + for m in float_child.modules(): + setattr(m, 'qconfig', self.qconfig.convert()) + if type(child) in MERGE_BN_MAPPINGS: + cls = MERGE_BN_MAPPINGS[type(child)] + new_child = cls.from_float(float_child).to(device) + else: + new_child = type(child).from_float(float_child).to( + device) + + # because weight fakequants and observers are replaced + # with base fakequants and base observers, some + # initialized args need to be update by running + # weight_fake_quant. + enable_observer(new_child) + new_child.weight_fake_quant(new_child.weight) + disable_observer(new_child) + else: + new_child = float_child.to(device) + setattr(module, name, new_child) + else: + traverse(child) + + observed_module.apply(enable_fake_quant) + observed_module.apply(disable_observer) + traverse(observed_module) + + def del_redundant_fakequant(self, prepared: GraphModule): + """delete redundant fakequant op in prepared model. + + Returns: + prepared (GraphModule): prepared model after delete redundant + fakequant op. 
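# After calibration or QAT, converting the observed module back to deployable
# modules happens in place; a sketch, continuing from the ``prepared`` module
# above:
# >>> quantizer.post_process_for_deploy(
# ...     prepared, device='cpu', keep_w_fake_quant=True)
# BN layers are folded into their conv layers and the weight fakequants are
# re-initialized by running ``weight_fake_quant`` once on the folded weights.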
+ + Notes: + We can configure different ways to delete redundant nodes: + @property + def module_prev_wo_fakequant(self): + return (torch.nn.ReLU6, torch.nn.Identity) + """ + extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_module_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_module( + prepared, + self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant, + inplace=True) + + extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_module_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_module( + prepared, + self.module_next_wo_fakequant + extra_module_next_wo_fakequant, + inplace=True) + + extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_function_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_function( + prepared, + self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, + inplace=True) + + extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_function_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_function( + prepared, + self.function_next_wo_fakequant + extra_function_next_wo_fakequant, + inplace=True) + + extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_method_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_method( + prepared, + self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, + inplace=True) + + extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_method_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_method( + prepared, + self.method_next_wo_fakequant + extra_method_next_wo_fakequant, + inplace=True) + + extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_op_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_op( + prepared, + self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant, + inplace=True) + + extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_op_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_op( + prepared, + self.op_next_wo_fakequant + extra_op_next_wo_fakequant, + inplace=True) + return prepared + + @property + def module_prev_wo_fakequant(self): + """Configurate the modules that their previous nodes are redundant + fakequants.""" + return tuple() + + @property + def module_next_wo_fakequant(self): + """Configurate the modules that their next nodes are redundant + fakequants.""" + return tuple() + + @property + def function_prev_wo_fakequant(self): + """Configurate the functions that their previous nodes are redundant + fakequants.""" + return tuple() + + @property + def function_next_wo_fakequant(self): + """Configurate the functions that their next nodes are redundant + fakequants.""" + return tuple() + + @property + def method_prev_wo_fakequant(self): + """Configurate the methods that their previous nodes are redundant + fakequants.""" + return tuple() + + @property + def method_next_wo_fakequant(self): + """Configurate the methods that their next nodes are redundant + fakequants.""" + return tuple() + + @property + def op_prev_wo_fakequant(self): + """Configurate the OPs that their previous nodes are redundant + fakequants.""" + return tuple() + + @property + def op_next_wo_fakequant(self): + """Configurate the OPs that their next nodes are redundant + fakequants.""" + return tuple() diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py new file mode 100644 
index 000000000..8f5ef3873 --- /dev/null +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Optional, Tuple, Union + +import torch + +from mmrazor.registry import MODELS +from .native_quantizer import TorchNativeQuantizer + + +@MODELS.register_module() +class OpenVINOQuantizer(TorchNativeQuantizer): + """Quantizer for quantizing and deploying to Openvino backend. + + Each backend has its own features, for reducing the gap of quantized + performance between before and after deployment as possible, we should + match the backend's features in quantization. + + Openvino's some important features about quantization is as follows: + * support_w_mode = ('per_tensor', 'per_channel') + * support_a_mode = ('per_tensor') + * weight range should be symmetric, such as int 8 is [-127, 127] rather + than [-128, 127] + """ + + @property + def backend(self): + """The backend to deploy, also the key of the corresponding backend + config.""" + return 'openvino' + + @property + def support_w_modes(self): + """Supported quantization modes for weight about per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') + + @property + def support_a_modes(self): + """Supported quantization modes for activation about per_tensor or + per_channel.""" + return ('per_tensor') + + def export_onnx(self, + model: Union[torch.nn.Module, torch.jit.ScriptModule, + torch.jit.ScriptFunction], + args: Union[Tuple[Any, ...], torch.Tensor], + output_path: str, + opset_version: Optional[int] = 11, + **kwargs): + """Export the onnx model that can be deployed to OpenVino backend.""" + + symbolic_output_path = output_path.replace('.onnx', '_symbolic.onnx') + torch.onnx.export( + model, + args, + symbolic_output_path, + opset_version=opset_version, + **kwargs) + + from .exporters import OpenVinoQuantizeExportor + exporter = OpenVinoQuantizeExportor(symbolic_output_path, output_path) + exporter.export() + + @property + def module_prev_wo_fakequant(self): + """Configurate the modules that their previous nodes are redundant + fakequants.""" + return (torch.nn.ReLU6, torch.nn.Identity) + + @property + def module_next_wo_fakequant(self): + """Configurate the modules that their next nodes are redundant + fakequants.""" + return (torch.nn.MaxPool2d, ) + + @property + def method_next_wo_fakequant(self): + """Configurate the methods that their next nodes are redundant + fakequants.""" + return ('flatten', ) + + @property + def op_prev_wo_fakequant(self): + """Configurate the OPs that their previous nodes are redundant + fakequants.""" + return ('output', ) diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py new file mode 100644 index 000000000..be067fd4f --- /dev/null +++ b/mmrazor/models/quantizers/tensorrt_quantizer.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Optional, Tuple, Union + +import torch + +from mmrazor.registry import MODELS +from .native_quantizer import TorchNativeQuantizer + + +@MODELS.register_module() +class TensorRTQuantizer(TorchNativeQuantizer): + """Quantizer for quantizing and deploying to TensorRT backend. + + Each backend has its own features, for reducing the gap of quantized + performance between before and after deployment as possible, we should + match the backend's features in quantization. 
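# A hypothetical end-to-end export with the OpenVINO quantizer defined above;
# ``quantizer``/``prepared`` follow the earlier sketches and the output path
# is a placeholder:
# >>> quantizer.export_onnx(
# ...     prepared, torch.randn(1, 3, 224, 224), 'quantized_model.onnx')
# This first dumps 'quantized_model_symbolic.onnx' via torch.onnx.export and
# then rewrites it with OpenVinoQuantizeExportor into 'quantized_model.onnx'.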
+ + TensorRT's some important features about quantization is as follows: + * support_w_mode = ('per_tensor', 'per_channel') + * support_a_mode = ('per_tensor') + """ + + @property + def backend(self): + """The backend to deploy, also the key of the corresponding backend + config.""" + return 'tensorrt' + + @property + def support_w_modes(self): + """Supported quantization modes for weight about per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') + + @property + def support_a_modes(self): + """Supported quantization modes for activation about per_tensor or + per_channel.""" + return ('per_tensor') + + def export_onnx(self, + model: Union[torch.nn.Module, torch.jit.ScriptModule, + torch.jit.ScriptFunction], + args: Union[Tuple[Any, ...], torch.Tensor], + output_path: str, + opset_version: Optional[int] = 13, + **kwargs): + """Export the onnx model that can be deployed to OpenVino backend.""" + + symbolic_output_path = output_path.replace('.onnx', '_symbolic.onnx') + torch.onnx.export( + model, + args, + symbolic_output_path, + opset_version=opset_version, + **kwargs) + + from .exporters import TensorRTExplicitExporter + exporter = TensorRTExplicitExporter(symbolic_output_path, output_path) + exporter.export() + + @property + def module_prev_wo_fakequant(self): + """Configurate the modules that their previous nodes are redundant + fakequants.""" + return (torch.nn.ReLU6, torch.nn.Identity) + + @property + def module_next_wo_fakequant(self): + """Configurate the modules that their next nodes are redundant + fakequants.""" + return (torch.nn.MaxPool2d, ) + + @property + def method_next_wo_fakequant(self): + """Configurate the methods that their next nodes are redundant + fakequants.""" + return ('flatten', ) + + @property + def op_prev_wo_fakequant(self): + """Configurate the OPs that their previous nodes are redundant + fakequants.""" + return ('output', ) diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index c3ff8dd66..987030d81 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -2,6 +2,8 @@ from .backward_tracer import BackwardTracer from .channel_analyzer import ChannelAnalyzer # from .razor_tracer import RazorFxTracer +from .fx import (CustomTracer, UntracedMethodRegistry, build_graphmodule, + custom_symbolic_trace) from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, @@ -10,5 +12,6 @@ __all__ = [ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', - 'ChannelAnalyzer' + 'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry', + 'custom_symbolic_trace', 'build_graphmodule' ] diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py new file mode 100644 index 000000000..82f723f10 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .custom_tracer import (CustomTracer, UntracedMethodRegistry, + build_graphmodule, custom_symbolic_trace) +from .graph_utils import (del_fakequant_after_function, + del_fakequant_after_method, + del_fakequant_after_module, del_fakequant_after_op, + del_fakequant_before_function, + del_fakequant_before_method, + del_fakequant_before_module, del_fakequant_before_op) + +__all__ = [ + 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', + 'build_graphmodule', 'del_fakequant_before_module', + 'del_fakequant_after_module', 'del_fakequant_after_function', + 'del_fakequant_before_function', 'del_fakequant_after_op', + 'del_fakequant_before_op', 'del_fakequant_before_method', + 'del_fakequant_after_method' +] diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py new file mode 100644 index 000000000..68d5f0809 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -0,0 +1,477 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +from copy import deepcopy +from types import FunctionType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn + +try: + from torch._C import ScriptObject # type: ignore[attr-defined] + from torch.ao.quantization.quantize_fx import QuantizationTracer + from torch.fx import Graph, GraphModule, Tracer + from torch.fx._symbolic_trace import (_autowrap_check, + _patch_wrapped_functions, _Patcher) + from torch.fx.proxy import Proxy +except ImportError: + from mmrazor.utils import get_placeholder + ScriptObject = get_placeholder('torch>=1.13') + QuantizationTracer = get_placeholder('torch>=1.13') + GraphModule = get_placeholder('torch>=1.13') + Tracer = get_placeholder('torch>=1.13') + Graph = get_placeholder('torch>=1.13') + _autowrap_check = get_placeholder('torch>=1.13') + _patch_wrapped_functions = get_placeholder('torch>=1.13') + _Patcher = get_placeholder('torch>=1.13') + Proxy = get_placeholder('torch>=1.13') + +from mmengine.utils import import_modules_from_strings + +from mmrazor.registry import TASK_UTILS + +_orig_module_call: Callable = nn.Module.__call__ +_orig_module_getattr: Callable = nn.Module.__getattr__ + + +class UntracedMethodRegistry: + """A `Descriptor` class which records untraced methods. Thus, when the + class is traced with CustomTracer, the decorated method will be as a leaf + node, not be nested traced. + + Example: + >>> # `imported_cls` is the owner of the untraced method; + >>> # `method_str` is the name of the untraced method. + >>> method_registry = UntracedMethodRegistry(method) + >>> method_registry.__set_name__(imported_cls, method_str) + + Args: + method (FunctionType): Function to be registered. 
+ """ + method_dict: Dict = dict() + tracer = None + + def __init__(self, method: FunctionType): + self.method = method + self.owner = None + + def __set_name__(self, owner, name): + self.owner = owner + self.name = name + wrapped = self.method_wrapper() + self.method_dict[name] = dict(mod=self.owner, wrapped=wrapped) + + def method_wrapper(self): + + @functools.wraps(self.method) + def wrapped_method(mod, *args, **kwargs): + + def method(*args, **kwargs): + return self.method(mod, *args, **kwargs) + + return self.tracer.call_method(mod, self.name, method, args, + kwargs) + + return wrapped_method + + +def _prepare_module_dict(model: torch.nn.Module, fx_graph): + """If there is a class method that can not be traced by the symbolic + tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in + ``CustomTracer``. + + Example: + >>> class Model: + ... def __init__(self): + ... self.head = ClsHead() + ... + >>> class ClsHead(nn.Module): + ... def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + ... return feats[-1] + ... + ... def loss(self, feats: Tuple[torch.Tensor], + ... data_samples: List[ClsDataSample], **kwargs) -> dict: + ... cls_score = self(feats) + ... # The part can not be traced by torch.fx + ... losses = self._get_loss(cls_score, data_samples, **kwargs) + ... return losses + ... + ... def _get_loss(self, cls_score: torch.Tensor, + ... data_samples: List[ClsDataSample], **kwargs): + ... if 'score' in data_samples[0].gt_label: + ... xxx + ... else: + ... xxx + ... losses = xxx + ... return losses + + As the ``_get_loss`` can not be traced by torch.fx, ``Toy._get_loss`` need + to be added to ``skipped_methods`` in ``CustomTracer``. Hence the code + above will product the following Graph:: + + .. code-block:: text + ... ... + %head : [#users=1] = get_attr[target=head] + %_get_loss : [#users=1] = call_method[target=_get_loss](args = (%head, %head_fc, %data_samples), kwargs = {}) # noqa: E501 + return _get_loss + + Hence, the head module in the ``GraphModule`` and that in the original + model are the same one (refer to https://github.com/pytorch/pytorch/blob/master/torch/fx/graph_module.py#L346). # noqa: E501 + So changes made to the graph module (in ``prepare()``) will also modify + the original model. + + Args: + model (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It + contains the nodes this GraphModule should use for code generation. + """ + + def _get_attrs(target, attrs): + attrs = attrs.split('.') + for att in attrs: + target = getattr(target, att) + return target + + module_dict = dict() + special_nodes = [] + + for node in fx_graph.nodes: + if node.op == 'get_attr': + attr = _get_attrs(model, node.target) + if isinstance(attr, nn.Module): + module_dict[node.target] = nn.Module() + special_nodes.append(node) + elif node.op == 'call_method': + for special_node in special_nodes: + if special_node in node.args or \ + special_node in node.kwargs.values(): + origin_module = getattr(model, special_node.target) + setattr(module_dict[special_node.target], node.target, + getattr(origin_module, node.target)) + + return module_dict + + +def duplicate_reused_nodes(graph: Graph, modules: Dict[str, Any] = {}): + """Deepcopy the shared modules (e.g. shared detection head in RetinaNet) to + make sure modules can be fused correctly. 
+ + Modified from https://github.com/ModelTC/MQBench/blob/main/mqbench/prepare_by_platform.py # noqa: E501 + """ + _dup_prefix = '_dup' + target_dict = dict() + dup_modules = dict() + for node in graph.nodes: + if node.op == 'call_module': + if node.target not in target_dict: + target_dict[node.target] = [node] + else: + target_dict[node.target].append(node) + for key in target_dict: + if len(target_dict[key]) > 1: + for idx, node in enumerate(target_dict[key]): + if idx == 0: + continue + module = deepcopy(modules[node.target]) + node.target += _dup_prefix + str(idx) + dup_modules[node.target] = module + graph.lint() + return graph, dup_modules + + +def build_graphmodule(model: torch.nn.Module, + fx_graph, + name: str = 'GraphModule'): + """To build GraphModule with the generated graph by CustomTracer. The + implement of skipping methods in CustomTracer will cause the confliction of + that a node is both a leaf node and non-leaf node, which will lead that the + modification to the ``graph`` also change the original ``forward``. + + Args: + model (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It + contains the nodes this GraphModule should use for code generation. + name (str): The name of generated GraphModule. + + Returns: + GraphModule: GraphModule is an nn.Module generated from an fx.Graph. + Graphmodule has a ``graph`` attribute, as well as ``code`` and + ``forward`` attributes generated from that ``graph``. + + .. warning:: + When ``graph`` is reassigned, ``code`` and ``forward`` will be + automatically regenerated. However, if you edit the contents of the + ``graph`` without reassigning the ``graph`` attribute itself, you must + call ``recompile()`` to update the generated code. + """ + modules = dict(model.named_modules()) + module_dict = _prepare_module_dict(model, fx_graph) + fx_graph, duplicated_modules = duplicate_reused_nodes(fx_graph, modules) + modules.update(module_dict) + modules.update(duplicated_modules) + return GraphModule(modules, fx_graph, name) + + +@TASK_UTILS.register_module() +class CustomTracer(QuantizationTracer): + """Custom tracer based on QuantizationTracer of pytorch. It can not only + skip some modules and classes while tracing, but also skip some methods + untraced by torch.fx.Tracer. + + Args: + skipped_methods (List[str], optional): Methods to be skipped while + tracing. Defaults to None. + skipped_module_names (List[str], optional): Modules to be skipped + while tracing. Defaults to None. + skipped_module_classes (List[Callable], optional): Class to be skipped + while tracing. Defaults to None. 
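# A sketch of constructing the tracer so that an untraceable head method is
# kept as a single ``call_method`` node; the dotted method path is only an
# illustration in the spirit of the classification-head example above:
# >>> tracer = CustomTracer(
# ...     skipped_methods=['mmcls.models.heads.ClsHead._get_loss'])
# >>> graph = tracer.trace(model)          # ``model`` is any nn.Module
# >>> graph_module = build_graphmodule(model, graph)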
+ """ + + def __init__(self, + skipped_methods: List[str] = [], + skipped_module_names: List[str] = [], + skipped_module_classes: List[Callable] = [], + *args, + **kwargs): + super(CustomTracer, self).__init__(skipped_module_names, + skipped_module_classes) + UntracedMethodRegistry.tracer = self # type: ignore + self.skipped_methods = skipped_methods + if self.skipped_methods: + self.register_skipped_methods() + + @staticmethod + def _check_valid_source(source): + """Check if the source's format is valid.""" + if not isinstance(source, str): + raise TypeError(f'source should be a str ' + f'instance, but got {type(source)}') + + assert len(source.split('.')) > 1, \ + 'source must have at least one `.`' + + def register_skipped_methods(self): + """Register skipped methods to UntracedMethodRegistry.method_dict.""" + if not isinstance(self.skipped_methods, list): + self.skipped_methods = [self.skipped_methods] + for s_method in self.skipped_methods: + self._check_valid_source(s_method) + mod_str = '.'.join(s_method.split('.')[:-2]) + cls_str = s_method.split('.')[-2] + method_str = s_method.split('.')[-1] + + try: + mod = import_modules_from_strings(mod_str) + except ImportError: + raise ImportError(f'{mod_str} is not imported correctly.') + + imported_cls: type = getattr(mod, cls_str) + if not isinstance(imported_cls, type): + raise TypeError(f'{cls_str} should be a type ' + f'instance, but got {type(imported_cls)}') + assert hasattr(imported_cls, method_str), \ + f'{method_str} is not in {mod_str}.' + + method = getattr(imported_cls, method_str) + + method_registry = UntracedMethodRegistry(method) + method_registry.__set_name__(imported_cls, method_str) + + def call_method(self, m: torch.nn.Module, name: str, method: Callable, + args: Tuple, kwargs: Dict): + """Method that specifies the behavior of this ``Tracer`` when it + encounters a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf + module via ``is_leaf_module``. If it is, emit a ``call_module`` + node referring to ``m`` in the ``Graph``. Otherwise, call the + ``Module`` normally, tracing through the operations in its ``forward`` + function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + m (torch.nn.Module): The module for which a call is being emitted + name (str): The name of proxy to be created. + method (Callable): The method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a + ``call_module`` node was emitted, this is a ``Proxy`` value. + Otherwise, it is whatever value was returned from the ``Module`` + invocation. + """ + # module_qualified_name = self.path_of_module(m) + if not self.is_skipped_method(m): + return method(*args, **kwargs) + args_l = list(args) + args_l.insert(0, m) + args = tuple(args_l) + return self.create_proxy('call_method', name, args, kwargs) + + def trace(self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + """Trace ``root`` and return the corresponding FX ``Graph`` + representation. ``root`` can either be an ``nn.Module`` instance or a + Python callable. Note that after this call, ``self.root`` may be + different from the ``root`` passed in here. 
For example, when a free + function is passed to ``trace()``, we will create an ``nn.Module`` + instance to use as the root and add embedded constants to. + + Args: + root (Union[Module, Callable]): Either a ``Module`` or a function + to be traced through. Backwards-compatibility for this + parameter is guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that + should not be treated as Proxies. This parameter is + experimental and its backwards-compatibility is *NOT* + guaranteed. + + Returns: + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + if isinstance(root, torch.nn.Module): + self.root = root + fn = type(root).forward + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = { + mod: name + for name, mod in root.named_modules() + } + else: + self.root = nn.Module() + fn = root + + tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) + self.graph = Graph(tracer_cls=tracer_cls) + + # When we encounter a Tensor value that's not a parameter, we look if + # it is some other attribute on the model. Construct a dict mapping + # Tensor values to the qualified name here for efficiency. This is + # used downstream in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root(fn, isinstance(root, nn.Module), + concrete_args) + + # Reduce number of get_attr calls + parameter_proxy_cache: Dict[str, Proxy] = {} + + # Method dispatch on parameters is not recorded unless it's directly + # used. Thus, we need to insert a proxy when __getattr__ requests a + # parameter. 
+ @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self.getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, + getattr(getattr(mod, 'forward', mod), '__globals__', {}), + self._autowrap_function_ids) + return self.call_module(mod, forward, args, kwargs) + + with _Patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + nn.Module, + '__getattr__', + module_getattr_wrapper, + deduplicate=False) + patcher.patch_method( + nn.Module, '__call__', module_call_wrapper, deduplicate=False) + + for name, value in UntracedMethodRegistry.method_dict.items(): + wrapped = value['wrapped'] + patcher.patch_method( + value['mod'], name, wrapped, deduplicate=False) + + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check(patcher, module.__dict__, + self._autowrap_function_ids) + self.create_node( + 'output', + 'output', (self.create_arg(fn(*args)), ), {}, + type_expr=fn.__annotations__.get('return', None)) + + self.submodule_paths = None + + return self.graph + + def is_skipped_method(self, m: torch.nn.Module): + """Judge if ``m`` is registered skipped method.""" + mods = tuple(value['mod'] + for value in UntracedMethodRegistry.method_dict.values()) + custom = isinstance(m, mods) + return custom + + def is_leaf_module(self, m: torch.nn.Module, + module_qualified_name: str) -> bool: + """A method to specify whether a given ``nn.Module`` is a "leaf" + module. Leaf modules are the atomic units that appear in the IR, + referenced by ``call_module`` calls. By default, Modules in the PyTorch + standard library namespace (torch.nn) are leaf modules. All other + modules are traced through and their constituent ops are recorded, + unless specified otherwise via this parameter. + + Args: + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. + For example, if you have a module hierarchy where submodule + ``foo`` contains submodule ``bar``, which contains submodule + ``baz``, that module will appear with the qualified name + ``foo.bar.baz`` here. + """ + leaf = super().is_leaf_module(m, module_qualified_name) + return leaf + + +def custom_symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: + """Modified `symbolic_trace` function in pytorch. Given an ``nn.Module`` or + function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + Args: + root (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially + specialized. + + Returns: + GraphModule: a Module created from the recorded operations from + ``root``. 
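# Minimal sketch of tracing a toy module with ``custom_symbolic_trace``
# defined below; the toy model is made up for illustration:
# >>> import torch.nn as nn
# >>> class ToyModel(nn.Module):
# ...     def __init__(self):
# ...         super().__init__()
# ...         self.conv = nn.Conv2d(3, 8, 3)
# ...     def forward(self, x):
# ...         return self.conv(x).flatten(1)
# >>> gm = custom_symbolic_trace(ToyModel())
# >>> [node.op for node in gm.graph.nodes]
# ['placeholder', 'call_module', 'call_method', 'output']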
+ """ + tracer = CustomTracer() + graph = tracer.trace(root, concrete_args) + name = root.__class__.__name__ if isinstance( + root, torch.nn.Module) else root.__name__ + return GraphModule(tracer.root, graph, name) diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py new file mode 100644 index 000000000..ca1291711 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -0,0 +1,387 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Any, List, Tuple + +import torch + +try: + from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.fx import Node +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') + Node = get_placeholder('torch>=1.13') + + +def _get_attrs(target: torch.nn.Module, attr: str) -> Any: + """Get the attribute from target. + + Args: + target (torch.nn.Module): Get the attribute from target module. + attr (str): The target attribute. + + Returns: + Any: The target attribute. + """ + + attrs: List[str] = attr.split('.') + + for att in attrs: + target = getattr(target, att, None) + return target + + +def recursive_find_erased_nodes(node, prepared_model): + """Find FakeQuant before target node recursively. + + Examples: + head_fc = self.head.fc(activation_post_process_87); \ + activation_post_process_87 = None + activation_post_process_88 = \ + self.activation_post_process_88(head_fc); head_fc = None + head = self.head + _get_loss = head._get_loss(activation_post_process_88, + data_samples); \ + head = activation_post_process_88 = data_samples = None + return _get_loss + + node | node.args + -------------------- + output | (_get_loss, ) + _get_loss | (head, activation_post_process_88, + data_samples) + head | () + activation_post_process_88 | (head_fc, ) + data_samples | (None, ) + """ + if node is None: + return [] + + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + return [node] + + nodes_to_erase = [] + for prev_node in node.args: + if isinstance(prev_node, Node): + nodes_to_erase.extend( + recursive_find_erased_nodes(prev_node, prepared_model)) + for prev_node in node.kwargs.values(): + if isinstance(prev_node, Node): + nodes_to_erase.extend( + recursive_find_erased_nodes(prev_node, prepared_model)) + + return nodes_to_erase + + +def del_fakequant_before_op(prepared_model, + target_ops: Tuple, + inplace: bool = True): + """Delete useless fakequant before nodes whose ``op`` attribute (node.op) + is in `target_ops`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_ops (tuple): Fakequants before nodes whose op attribute + (node.op) is in `target_ops` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. 
+ """ + + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.op in target_ops: + nodes_to_erase: List[Node] = recursive_find_erased_nodes( + node, prepared_model) + for to_erase in nodes_to_erase: + assert to_erase.op == 'call_module' and isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase) and len(to_erase.args) == 1 + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_op(prepared_model, + target_ops: Tuple, + inplace: bool = True): + """Delete useless fakequant after nodes whose ``op`` attribute (node.op) is + in `target_ops`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_ops (tuple): Fakequants after nodes whose op attribute + (node.op) is in `target_ops` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + + target_nodes = [] + for node in new_graph.nodes: + if node.op in target_ops: + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_before_method(prepared_model, + method_patterns: Tuple, + inplace: bool = True): + """Delete useless fakequant before nodes whose op attribute (node.op) is + `call_method` and target attribute (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants before nodes whose op attribute + (node.op) is `call_method` and target attribute (node.target) is + in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.op == 'call_method' and node.target in method_patterns: + nodes_to_erase: List[Node] = recursive_find_erased_nodes( + node, prepared_model) + for to_erase in nodes_to_erase: + assert to_erase.op == 'call_module' and isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase) and len(to_erase.args) == 1 + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_method(prepared_model, + method_patterns: Tuple, + inplace: bool = True): + """Delete useless fakequant after nodes whose op attribute (node.op) is + `call_method` and target attribute (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. 
+ target_patterns (tuple): Fakequants after nodes whose op attribute + (node.op) is `call_method` and target attribute (node.target) + is in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + + target_nodes = [] + for node in new_graph.nodes: + if node.op == 'call_method' and node.target in method_patterns: + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_before_function(prepared_model, + function_patterns: Tuple, + inplace: bool = True): + """Delete useless fakequant before nodes whose op attribute (node.op) is + `call_function` and target attribute (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants before nodes whose op attribute + (node.op) is `call_function` and target attribute (node.target) is + in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.op == 'call_function' and node.target in function_patterns: + nodes_to_erase: List[Node] = recursive_find_erased_nodes( + node, prepared_model) + for to_erase in nodes_to_erase: + assert to_erase.op == 'call_module' and isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase) and len(to_erase.args) == 1 + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_function(prepared_model, + function_patterns: Tuple, + inplace: bool = True): + """Delete useless fakequant after nodes whose op attribute (node.op) is + `call_function` and target attribute (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + function_patterns (tuple): Fakequants after nodes whose op attribute + (node.op) is `call_function` and target attribute (node.target) is + in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. 
+ """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + + target_nodes = [] + for node in new_graph.nodes: + if node.op == 'call_function' and node.target in function_patterns: + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_before_module(prepared_model, + module_patterns: Tuple, + inplace: bool = True): + """Delete useless fakequant before modules whose type are in + `module_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants before modules whose type is in + `module_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), module_patterns): + to_erase = node.args[0] + if not (to_erase.op == 'call_module' and isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase)): + continue + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_module(prepared_model, + module_patterns: Tuple, + inplace: bool = True): + """Delete useless fakequant after modules whose type are in + `module_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants after modules whose type is in + `module_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + + Returns: + GraphModule: Prepared standalone module after deletion. 
+ """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + target_nodes = [] + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), module_patterns): + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model diff --git a/mmrazor/models/utils/__init__.py b/mmrazor/models/utils/__init__.py index 1b3eea2a1..e3be94946 100644 --- a/mmrazor/models/utils/__init__.py +++ b/mmrazor/models/utils/__init__.py @@ -3,9 +3,11 @@ from .misc import add_prefix from .optim_wrapper import reinitialize_optim_wrapper_count_status from .parse_values import parse_values +from .quantization_util import pop_rewriter_function_record, str2class from .utils import get_module_device, set_requires_grad __all__ = [ - 'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible', - 'get_module_device', 'set_requires_grad', 'parse_values' + 'make_divisible', 'add_prefix', 'reinitialize_optim_wrapper_count_status', + 'str2class', 'get_module_device', 'set_requires_grad', 'parse_values', + 'pop_rewriter_function_record' ] diff --git a/mmrazor/models/utils/quantization_util.py b/mmrazor/models/utils/quantization_util.py new file mode 100644 index 000000000..36d108372 --- /dev/null +++ b/mmrazor/models/utils/quantization_util.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import import_modules_from_strings + + +def pop_rewriter_function_record(rewriter_context, function_record_to_pop): + """Delete user-specific rewriters from `RewriterContext._rewriter_manager`. + + We use the model which is rewritten by mmdeploy to build quantized models. + However not all the functions rewritten by mmdeploy need to be rewritten in + mmrazor. For example, mmdeploy rewrite + `mmcls.models.classifiers.ImageClassifier.forward` and + `mmcls.models.classifiers.BaseClassifier.forward` for deployment. But they + can't be rewritten by mmrazor as ptq and qat are done in mmrazor. So to + ensure ptq and qat proceed normally, we have to remove these record from + `RewriterContext._rewriter_manager`. + """ + function_record_backup = {} + for record in function_record_to_pop: + records = rewriter_context._rewriter_manager.function_rewriter. 
\ + _registry._rewrite_records + if record in records: + function_record_backup[record] = records.pop(record) + return function_record_backup + + +def _check_valid_source(source): + """Check if the source's format is valid.""" + if not isinstance(source, str): + raise TypeError(f'source should be a str ' + f'instance, but got {type(source)}') + + assert len(source.split('.')) > 1, \ + 'source must have at least one `.`' + + +def str2class(str_inputs): + clss = [] + if not isinstance(str_inputs, tuple) and not isinstance(str_inputs, list): + str_inputs_list = [str_inputs] + else: + str_inputs_list = str_inputs + for s_class in str_inputs_list: + _check_valid_source(s_class) + mod_str = '.'.join(s_class.split('.')[:-1]) + cls_str = s_class.split('.')[-1] + try: + mod = import_modules_from_strings(mod_str) + except ImportError: + raise ImportError(f'{mod_str} is not imported correctly.') + imported_cls: type = getattr(mod, cls_str) + if not isinstance(imported_cls, type): + raise TypeError(f'{cls_str} should be a type ' + f'instance, but got {type(imported_cls)}') + clss.append(imported_cls) + if isinstance(str_inputs, list): + return clss + elif isinstance(str_inputs, tuple): + return tuple(clss) + else: + return clss[0] diff --git a/mmrazor/structures/__init__.py b/mmrazor/structures/__init__.py index 6dfcfbdc8..7f15c5d45 100644 --- a/mmrazor/structures/__init__.py +++ b/mmrazor/structures/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .quantization import * # noqa: F401,F403 from .subnet import * # noqa: F401,F403 diff --git a/mmrazor/structures/quantization/__init__.py b/mmrazor/structures/quantization/__init__.py new file mode 100644 index 000000000..cbf28034f --- /dev/null +++ b/mmrazor/structures/quantization/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backend_config import * # noqa: F401,F403 +from .qconfig import * # noqa: F401,F403 diff --git a/mmrazor/structures/quantization/backend_config/__init__.py b/mmrazor/structures/quantization/backend_config/__init__.py new file mode 100644 index 000000000..151968f8d --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .academic import (get_academic_backend_config, + get_academic_backend_config_dict) +from .mapping import BackendConfigs +from .native import get_native_backend_config, get_native_backend_config_dict +from .openvino import (get_openvino_backend_config, + get_openvino_backend_config_dict) +from .tensorrt import (get_tensorrt_backend_config, + get_tensorrt_backend_config_dict) + +__all__ = [ + 'BackendConfigs', + 'get_native_backend_config', + 'get_native_backend_config_dict', + 'get_academic_backend_config', + 'get_academic_backend_config_dict', + 'get_openvino_backend_config', + 'get_openvino_backend_config_dict', + 'get_tensorrt_backend_config', + 'get_tensorrt_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py new file mode 100644 index 000000000..6b4f0d598 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
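For orientation, a minimal usage sketch of the `str2class` helper defined above; the `torch.nn` dotted paths are purely illustrative and any importable path would do:

```python
# Illustrative only: resolve dotted strings to classes with str2class.
import torch.nn as nn

from mmrazor.models.utils import str2class

# A single string returns the class object itself.
assert str2class('torch.nn.Conv2d') is nn.Conv2d

# A list (or tuple) of strings returns a list (or tuple) of classes,
# mirroring the container type of the input.
assert str2class(['torch.nn.Linear', 'torch.nn.ReLU']) == [nn.Linear, nn.ReLU]
```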
+import torch
+
+try:
+    from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    BackendConfig = get_placeholder('torch>=1.13')
+    DTypeConfig = get_placeholder('torch>=1.13')
+
+from .common_operator_config_utils import (_get_conv_configs,
+                                           _get_linear_configs)
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_academic_backend_config() -> BackendConfig:
+    """Return the `BackendConfig` for academic research.
+
+    Note:
+        Learn more about BackendConfig, please refer to:
+        https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501
+    """
+
+    # ===================
+    # |  DTYPE CONFIGS  |
+    # ===================
+    # weighted op int8 dtype config
+    # this is config for ops that has quantized weights, like linear, conv
+    weighted_op_int8_dtype_config = DTypeConfig(
+        input_dtype=torch.quint8,
+        output_dtype=torch.quint8,
+        weight_dtype=torch.qint8,
+        bias_dtype=torch.float,
+    )
+
+    conv_dtype_configs = [weighted_op_int8_dtype_config]
+    linear_dtype_configs = [weighted_op_int8_dtype_config]
+
+    return BackendConfig('academic') \
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
+
+
+def get_academic_backend_config_dict():
+    """Return the `BackendConfig` for academic research in dictionary
+    form."""
+    return get_academic_backend_config().to_dict()
+
+
+__all__ = [
+    'get_academic_backend_config',
+    'get_academic_backend_config_dict',
+]
diff --git a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py
new file mode 100644
index 000000000..0a381d5d0
--- /dev/null
+++ b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py
@@ -0,0 +1,639 @@
+# Copyright (c) OpenMMLab. All rights reserved.
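As a rough illustration of how these backend-config helpers are meant to be consumed (a sketch assuming torch >= 1.13 is installed, so `BackendConfig`/`DTypeConfig` are the real classes rather than placeholders):

```python
# Sketch: build the academic BackendConfig and its dictionary form.
from mmrazor.structures.quantization.backend_config import (
    get_academic_backend_config, get_academic_backend_config_dict)

backend_config = get_academic_backend_config()

# The object form is what torch's FX entry points accept, e.g.
# prepare_fx(model, qconfig_mapping, example_inputs,
#            backend_config=backend_config)

# The dict form is simply BackendConfig.to_dict(), as implemented above.
backend_config_dict = get_academic_backend_config_dict()
assert isinstance(backend_config_dict, dict)
```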
+import operator +from collections import namedtuple +from typing import List + +import torch +import torch.nn as nn + +from mmrazor import digit_version + +try: + import torch.nn.functional as F + import torch.nn.intrinsic as nni + import torch.nn.intrinsic.qat as nniqat + import torch.nn.qat as nnqat + import torch.nn.quantized._reference as nnqr + from torch.ao.quantization.backend_config import (BackendPatternConfig, + DTypeConfig, + ObservationType) + from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize + from torch.ao.quantization.fuser_method_mappings import ( + fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, + reverse2, reverse3, reverse_sequential_wrapper2) + from torch.ao.quantization.qconfig_mapping import \ + _FIXED_QPARAMS_OP_TO_OBSERVER +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + F = get_package_placeholder('torch>=1.13') + nni = get_package_placeholder('torch>=1.13') + nniqat = get_package_placeholder('torch>=1.13') + nnqat = get_package_placeholder('torch>=1.13') + nnqr = get_package_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') + FixedQParamsFakeQuantize = get_placeholder('torch>=1.13') + fuse_conv_bn = get_placeholder('torch>=1.13') + fuse_conv_bn_relu = get_placeholder('torch>=1.13') + fuse_convtranspose_bn = get_placeholder('torch>=1.13') + fuse_linear_bn = get_placeholder('torch>=1.13') + reverse2 = get_placeholder('torch>=1.13') + reverse3 = get_placeholder('torch>=1.13') + reverse_sequential_wrapper2 = get_placeholder('torch>=1.13') + _FIXED_QPARAMS_OP_TO_OBSERVER = get_placeholder('torch>=1.13') + +_ConvMetadata = namedtuple('_ConvMetadata', [ + 'root', 'transpose', 'bn', 'reference', 'transpose_reference', + 'fused_conv_relu', 'fused_conv_bn', 'fused_conv_bn_relu', 'qat', + 'relu_qat', 'bn_qat', 'bn_relu_qat', 'func' +]) + +if digit_version(torch.__version__) >= digit_version('1.13.0'): + _Conv1dMetadata = _ConvMetadata( + nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d, + nnqr.ConvTranspose1d, nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, + nnqat.Conv1d, nniqat.ConvReLU1d, nniqat.ConvBn1d, nniqat.ConvBnReLU1d, + F.conv1d) + _Conv2dMetadata = _ConvMetadata( + nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nnqr.Conv2d, + nnqr.ConvTranspose2d, nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, + nnqat.Conv2d, nniqat.ConvReLU2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d, + F.conv2d) + _Conv3dMetadata = _ConvMetadata( + nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, nnqr.Conv3d, + nnqr.ConvTranspose3d, nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, + nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d, + F.conv3d) +else: + toy_val = _ConvMetadata(*[i for i in range(13)]) + _Conv1dMetadata = toy_val + _Conv2dMetadata = toy_val + _Conv3dMetadata = toy_val + + +def _get_binary_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + binary_op_configs: List[BackendPatternConfig] = [] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for 
op_with_quantized_bop_scalar_variant in [ + operator.add, torch.add, operator.mul, torch.mul + ]: + bop_patterns = [(torch.nn.ReLU, op_with_quantized_bop_scalar_variant), + (torch.nn.functional.relu, + op_with_quantized_bop_scalar_variant), + (torch.relu, op_with_quantized_bop_scalar_variant), + op_with_quantized_bop_scalar_variant] + for bop_pattern in bop_patterns: + binary_op_configs.append( + BackendPatternConfig(bop_pattern).set_dtype_configs( + dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping)) + # matmul + binary_op_configs.append( + BackendPatternConfig(torch.matmul).set_dtype_configs( + dtype_configs) # noqa: E131 + ) + return binary_op_configs + + +def _get_linear_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + """Return all configs related to linear modules and ops.""" + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + linear_configs: List[BackendPatternConfig] = [] + + # (1) Single linear modules/functions + # ------------------------------------- + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module( + nnqr.Linear).set_qat_module(nnqat.Linear)) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module(nnqr.Linear)) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 1, + 'bias': 2 + })) + + # (2) Linear + relu + # ------------------- + # 2.1 linear module + relu fusion config + # linear relu, linear module + relu module + linear_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + torch.nn.Linear)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + nni.LinearReLU)).set_fused_module(nni.LinearReLU)) + # linear relu, linear module + functional relu + linear_configs.append( + BackendPatternConfig( + (torch.nn.functional.relu, + torch.nn.Linear)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + nni.LinearReLU)).set_fused_module(nni.LinearReLU)) + + # 2.2 linear module + relu, fused module configs + # linear relu, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearReLU).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module( + nnqr.Linear).set_qat_module(nniqat.LinearReLU)) + # linear relu, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearReLU).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module(nnqr.Linear)) + # 2.3 functional linear + relu configs + # linear relu, functional linear + relu module + linear_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + F.linear)).set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + # linear relu, functional linear + functional relu + linear_configs.append( + BackendPatternConfig( + 
(F.relu, + F.linear)).set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # (3) Linear + batchnorm + # ------------------------ + # 3.1 linear bn fusion + linear_configs.append( + BackendPatternConfig( + (nn.BatchNorm1d, + nn.Linear)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse2(fuse_linear_bn)).set_fused_module( + nni.LinearBn1d)) + + # 3.2 linear bn fused + # linear bn, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearBn1d).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module( + nnqr.Linear).set_qat_module(nniqat.LinearBn1d)) + # linear bn, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearBn1d).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module(nnqr.Linear)) + return linear_configs + + +def _get_conv_configs(dtype_configs): + """Return all configs related to conv modules and ops.""" + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]: + + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module( + convs.reference).set_qat_module(convs.qat)) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': + 1, + 'bias': + 2 + })) + + # (2) Conv + relu + # ----------------- + # 2.1 conv module + relu fusion configs + # conv relu fusion, conv module + relu module + conv_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + convs.root)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method( + reverse_sequential_wrapper2( + convs.fused_conv_relu)).set_fused_module( + convs.fused_conv_relu)) + # conv relu fusion, conv module + functional relu + conv_configs.append( + BackendPatternConfig( + (F.relu, + convs.root)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method( + reverse_sequential_wrapper2( + convs.fused_conv_relu)).set_fused_module( + convs.fused_conv_relu)) + # 2.2 conv module + relu fused module configs + # conv relu, fused module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module( + convs.reference).set_qat_module(convs.relu_qat)) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + # 2.3 functional conv + relu configs + # conv relu, functional conv + relu module + conv_configs.append( + BackendPatternConfig( + 
(torch.nn.ReLU, convs.func)).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + # conv relu, functional conv + functional relu + conv_configs.append( + BackendPatternConfig((F.relu, convs.func)).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat)) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_root_module(convs.root).set_reference_quantized_module( + convs.reference)) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # 3.1 conv bn fusion configs + # conv + bn fusion + conv_configs.append( + BackendPatternConfig( + (convs.bn, + convs.root)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse2(fuse_conv_bn)).set_fused_module( + convs.fused_conv_bn)) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig( + (nn.ReLU, + (convs.bn, + convs.root))).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse3(fuse_conv_bn_relu)).set_fused_module( + convs.fused_conv_bn_relu)) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig( + (F.relu, + (convs.bn, + convs.root))).set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root).set_fuser_method( + reverse3(fuse_conv_bn_relu)).set_fused_module( + convs.fused_conv_bn_relu)) + # TODO: we can add fusion for torch.relu as well + + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat)) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat)) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + + # (4) conv transpose and its fusion + # 4.1 conv transpose config + conv_configs.append( + BackendPatternConfig(convs.transpose).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_root_module(convs.transpose).set_reference_quantized_module( + convs.transpose_reference)) + + # 4.2 conv transpose + bn fusion + conv_configs.append( + BackendPatternConfig( + (convs.bn, convs.transpose)).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_fuser_method(reverse2(fuse_convtranspose_bn)).set_root_module( + convs.transpose).set_reference_quantized_module( + convs.transpose_reference)) + + return conv_configs + + +def _get_cat_config(dtype_configs: List[DTypeConfig]) -> BackendPatternConfig: + return BackendPatternConfig(torch.cat) \ + .set_observation_type( + ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .set_dtype_configs(dtype_configs) + + +def _get_ln_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + ln_configs = [] 
+ ln_configs.append( + BackendPatternConfig(torch.nn.LayerNorm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + ln_configs.append( + BackendPatternConfig( + torch.nn.functional.layer_norm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 2, + 'bias': 3 + })) + return ln_configs + + +def _get_default_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + configs = [] + default_ops = [ + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + for op in default_ops: + configs.append( + BackendPatternConfig(op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + configs.append( + BackendPatternConfig( + torch.nn.functional.group_norm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 2, + 'bias': 3 + })) + + configs.append( + BackendPatternConfig( + torch.nn.functional.instance_norm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 3, + 'bias': 4 + })) + return configs + + +def _get_fixed_qparams_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + fixed_qparams_op_configs = [] + op_to_obs = _FIXED_QPARAMS_OP_TO_OBSERVER.items() + for fixed_qparam_op, output_observer in op_to_obs: + fixed_qparams_op_configs.append( + # TODO: The _overwrite_output keys are temporary, since we don't + # want to put observer in the configs we expect that it's provided + # by user What we want to put here is the requirement on observers, + # in this case dtype, quant_min, quant_max etc., but we need to + # first move all configs to backend_config_dict to do that, we'll + # remove these keys after we fully migrated everything to use + # backend_config_dict + BackendPatternConfig(fixed_qparam_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs). + _set_overwrite_output_fake_quantize( + FixedQParamsFakeQuantize.with_args(observer=output_observer) + )._set_overwrite_output_observer(output_observer)) + return fixed_qparams_op_configs + + +def _get_share_qparams_op_configs(dtype_configs): + """Get the operator config for the operators that works for both float and + quantized input if input is quantized, the output Tensor shares the same + quantization parameter with input. 
Example operator: avgpool2d, reshape, + transpose, maxpool2d Example observed operator: + + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + + def _get_share_qprams_op_backend_config(op): + return BackendPatternConfig(op) \ + .set_observation_type( + ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .set_dtype_configs(dtype_configs) + + share_qparams_ops = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.Identity, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.ReLU, + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + torch.repeat_interleave, + torch.transpose, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.floordiv, + 'contiguous', + 'clamp', + 'detach', + 'detach_', + 'mean', + 'permute', + 'repeat', + 'repeat_interleave', + 'reshape', + 'resize_', + 'relu', + 'relu_', + 'shape', + 'size', + 'squeeze', + 'squeeze_', + 'transpose', + 'unsqueeze', + 'unsqueeze_', + 'view', + ] + return [ + _get_share_qprams_op_backend_config(op) for op in share_qparams_ops + ] + + +def _get_bn_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + """Get configs related to batchnorm.""" + bn_configs = [] + bn_to_fused_bn = { + torch.nn.BatchNorm2d: nni.BNReLU2d, + torch.nn.BatchNorm3d: nni.BNReLU3d, + } + for bn in bn_to_fused_bn.keys(): + fused_bn = bn_to_fused_bn[bn] + # bn module + relu module fusion config + bn_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + bn)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + fused_bn)).set_fused_module(fused_bn)) + # bn module + F.relu fusion config + bn_configs.append( + BackendPatternConfig( + (torch.nn.functional.relu, + bn)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + bn_to_fused_bn[bn])).set_fused_module(fused_bn)) + bn_configs.append( + BackendPatternConfig(bn).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # fused bn configs + for fused_bn in bn_to_fused_bn.values(): + bn_configs.append( + BackendPatternConfig(fused_bn).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + return bn_configs + + +def _get_rnn_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + rnn_op_configs = [] + for rnn_op, ref_rnn_op in [(nn.GRUCell, nnqr.GRUCell), + (nn.LSTMCell, nnqr.LSTMCell), + (nn.RNNCell, nnqr.RNNCell), + (nn.LSTM, nnqr.LSTM)]: + rnn_op_configs.append( + BackendPatternConfig(rnn_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + rnn_op).set_reference_quantized_module(ref_rnn_op)) + return rnn_op_configs + + +def _get_embedding_op_configs( + dtype_configs: List[DTypeConfig]) -> 
List[BackendPatternConfig]: + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs).set_qat_module(qat_embedding_op). + set_root_module(embedding_op).set_reference_quantized_module( + ref_embedding_op)._set_input_output_observed( + False)) # This is temporary, and will be removed soon + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + embedding_op).set_reference_quantized_module( + ref_embedding_op)._set_input_output_observed( + False)) # This is temporary, and will be removed soon + return embedding_op_configs + + +__all__ = [ + '_get_binary_op_configs', + '_get_linear_configs', + '_get_conv_configs', + '_get_share_qparams_op_configs', +] diff --git a/mmrazor/structures/quantization/backend_config/mapping.py b/mmrazor/structures/quantization/backend_config/mapping.py new file mode 100644 index 000000000..b9cc5372b --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/mapping.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor import digit_version +from .academic import get_academic_backend_config +from .native import get_native_backend_config +from .openvino import get_openvino_backend_config +from .tensorrt import get_tensorrt_backend_config + +if digit_version(torch.__version__) >= digit_version('1.13.0'): + BackendConfigs = { + 'academic': get_academic_backend_config(), + 'native': get_native_backend_config(), + 'tensorrt': get_tensorrt_backend_config(), + 'openvino': get_openvino_backend_config() + } +else: + BackendConfigs = { + 'academic': None, + 'native': None, + 'tensorrt': None, + 'openvino': None + } diff --git a/mmrazor/structures/quantization/backend_config/native.py b/mmrazor/structures/quantization/backend_config/native.py new file mode 100644 index 000000000..59085a56a --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/native.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +try: + from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + +from .common_operator_config_utils import ( # noqa: F401,F403 + _get_binary_op_configs, _get_bn_configs, _get_cat_config, + _get_conv_configs, _get_default_op_configs, _get_embedding_op_configs, + _get_fixed_qparams_op_configs, _get_linear_configs, _get_ln_configs, + _get_rnn_op_configs, _get_share_qparams_op_configs) + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_native_backend_config() -> BackendConfig: + """Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack). 
+ + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ + # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK + # BackendConfigs + + # =================== + # | DTYPE CONFIGS | + # =================== + # weighted op int8 dtype config + # this is config for ops that has quantized weights, like linear, conv + weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + + default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + ) + + default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, + # we will enable it a bit later after we moved everything to + # backend_config_dict + is_dynamic=True, + ) + + default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, we will enable it a bit + # later after we moved everything to backend_config_dict + is_dynamic=True, + ) + + # Needed for LayerNorm and f.layer_norm, since currently the kernel only + # supports float weights + input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, + ) + + weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, + ) + + weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, + ) + + conv_dtype_configs = [weighted_op_int8_dtype_config] + linear_dtype_configs = [ + weighted_op_int8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [weighted_op_int8_dtype_config] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [weighted_op_int8_dtype_config] + share_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + + return BackendConfig('native') \ + .set_backend_pattern_configs( + _get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs( + _get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs( + _get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config( + _get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs( + 
_get_ln_configs(layer_norm_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs)) + + +def get_native_backend_config_dict(): + """Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) + in dictionary form.""" + return get_native_backend_config().to_dict() + + +__all__ = [ + 'get_native_backend_config', + 'get_native_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/backend_config/openvino.py b/mmrazor/structures/quantization/backend_config/openvino.py new file mode 100644 index 000000000..5e3051f75 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/openvino.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +try: + from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType) +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') + +from .common_operator_config_utils import (_get_binary_op_configs, + _get_conv_configs, + _get_linear_configs, + _get_share_qparams_op_configs) + + +def get_openvino_backend_config() -> BackendConfig: + """Return the `BackendConfig` for the OpenVINO backend. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + non_weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + ) + + addmm_config = BackendPatternConfig(torch.addmm) \ + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_op_qint8_dtype_config) \ + ._set_input_type_to_index({ + 'bias': 0, + 'input': 1, + 'weight': 2, + }) + cat_config = BackendPatternConfig(torch.cat) \ + .set_observation_type( + ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .add_dtype_config(non_weighted_op_qint8_dtype_config) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + binary_op_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return BackendConfig('openvino') \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_config(addmm_config) \ + .set_backend_pattern_config(cat_config) \ + .set_backend_pattern_configs( + _get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs( + _get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs)) + + +def get_openvino_backend_config_dict(): + """Return the `BackendConfig` for the OpenVINO backend in dictionary + form.""" + return get_openvino_backend_config().to_dict() + + +__all__ = [ + 'get_openvino_backend_config', + 'get_openvino_backend_config_dict', +] diff 
--git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py new file mode 100644 index 000000000..8dddbac91 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +try: + from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType) +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') + +from .common_operator_config_utils import (_get_conv_configs, + _get_linear_configs) + + +def get_tensorrt_backend_config() -> BackendConfig: + """Return the `BackendConfig` for the TensorRT backend. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + + addmm_config = BackendPatternConfig(torch.addmm) \ + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_op_qint8_dtype_config) \ + ._set_input_type_to_index({ + 'bias': 0, + 'input': 1, + 'weight': 2, + }) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return BackendConfig('tensorrt') \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_config(addmm_config) \ + .set_backend_pattern_configs( + _get_linear_configs(linear_dtype_configs)) + + +def get_tensorrt_backend_config_dict(): + """Return the `BackendConfig` for the TensorRT backend in dictionary + form.""" + return get_tensorrt_backend_config().to_dict() + + +__all__ = [ + 'get_tensorrt_backend_config', + 'get_tensorrt_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py new file mode 100644 index 000000000..ab682be39 --- /dev/null +++ b/mmrazor/structures/quantization/qconfig.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Union + +import torch +from mmengine.config import Config + +try: + from torch.ao.quantization import FakeQuantize, QConfig + from torch.ao.quantization.utils import is_per_tensor +except ImportError: + from mmrazor.utils import get_placeholder + QConfig = get_placeholder('torch>=1.13') + FakeQuantize = get_placeholder('torch>=1.13') + is_per_tensor = get_placeholder('torch>=1.13') + +from mmrazor.registry import MODELS + +RequiredArgs = [ + 'w_qscheme', 'a_qscheme', 'w_fake_quant', 'a_fake_quant', 'w_observer', + 'a_observer' +] + +RetainArgsPerTensor = [ + 'dtype', 'qscheme', 'quant_min', 'quant_max', 'reduce_range' +] +RetainArgsPerChannel = RetainArgsPerTensor + ['ch_axis'] + + +class QSchemeHandler(object): + """Convert the qscheme of custom user-friendly qconfig to args needed in + observers. + + Args: + qdtype (str): Quantization dtype. 
It should be 'quint8' or 'qint8',
+            and should be supported by the deploy backend. Defaults to 'quint8'.
+        bit (int): Quantization bit number. Defaults to 8.
+        is_symmetry (bool): Whether the quantization is symmetric.
+            Defaults to True.
+        is_per_channel (bool): Whether the quantization is per-channel.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 qdtype: str = 'quint8',
+                 bit: int = 8,
+                 is_symmetry: bool = True,
+                 is_per_channel: bool = False,
+                 **kwargs):
+        assert qdtype in ('quint8', 'qint8'), \
+            'qdtype is incorrect, it should be quint8 or qint8.'
+        self.qdtype = qdtype
+        self.bit = bit
+        self.is_symmetry = is_symmetry
+        self.is_per_channel = is_per_channel
+
+        if self.is_per_channel:
+            self.torch_qscheme = torch.per_channel_symmetric \
+                if self.is_symmetry else torch.per_channel_affine
+        else:
+            self.torch_qscheme = torch.per_tensor_symmetric \
+                if self.is_symmetry else torch.per_tensor_affine
+        if 'is_symmetric_range' in kwargs:
+            self.is_symmetric_range = kwargs['is_symmetric_range']
+            del kwargs['is_symmetric_range']
+        else:
+            self.is_symmetric_range = False
+        self.kwargs = kwargs
+
+    def to_observer_params(self):
+        """Generate the args needed in observers."""
+        if self.qdtype == 'quint8':
+            quant_min = 0
+            quant_max = 2**self.bit - 1
+        else:
+            quant_max = 2**(self.bit - 1) - 1
+            if self.is_symmetric_range:
+                quant_min = -2**(self.bit - 1) + 1
+            else:
+                quant_min = -2**(self.bit - 1)
+
+        # `dtype` will be the same as the BackendConfig's
+        naive_para = {
+            'dtype': torch.quint8 if self.qdtype == 'quint8' else torch.qint8,
+            'quant_min': quant_min,
+            'quant_max': quant_max,
+            'qscheme': self.torch_qscheme,
+            'reduce_range': False
+        }
+        if self.is_per_channel:
+            naive_para['ch_axis'] = 0
+        all_para = self.kwargs.copy()
+        all_para.update(naive_para)
+        return all_para
+
+    def __str__(self):
+        """Return a string describing the generated observer args."""
+        return f'dtype: {self.qdtype} / bit: {self.bit} / is_symmetry: {self.is_symmetry} / \
+            is_per_channel: {self.is_per_channel} \
+            / extra_kwargs: {self.kwargs}'
+
+
+class QConfigHandler():
+    """Convert custom user-friendly qconfig format to torch's QConfig.
+
+    Args:
+        qconfig (Dict | Config): custom user-friendly qconfig format,
+            including setting observers, fakequants and quantization schemes
+            for weights and activations.
+    Note:
+        Whether the quantization scheme is per-channel or not depends on
+        the observer used: if an observer supports per-channel quantization,
+        its name should contain 'PerChannel'.
+ """ + + def __init__(self, qconfig: Union[Dict, Config]): + if not self.check_qconfig(qconfig): + raise ValueError('The format of qconfig is incorrect.') + else: + w_observer = MODELS.get(qconfig['w_observer']['type']) + a_observer = MODELS.get(qconfig['a_observer']['type']) + w_is_per_channel = False + a_is_per_channel = False + # import pdb;pdb.set_trace() + if 'PerChannel' in w_observer.__name__: + w_is_per_channel = True + if 'PerChannel' in a_observer.__name__: + a_is_per_channel = True + self.w_qscheme = QSchemeHandler( + is_per_channel=w_is_per_channel, **qconfig['w_qscheme']) + self.a_qscheme = QSchemeHandler( + is_per_channel=a_is_per_channel, **qconfig['a_qscheme']) + + w_fake_quant = MODELS.get(qconfig['w_fake_quant']['type']) + w_observer_kwargs = self.w_qscheme.to_observer_params() + a_fake_quant = MODELS.get(qconfig['a_fake_quant']['type']) + a_observer_kwargs = self.a_qscheme.to_observer_params() + + self.w_fake_quant = w_fake_quant.with_args( + observer=w_observer, **w_observer_kwargs) + self.a_fake_quant = a_fake_quant.with_args( + observer=a_observer, **a_observer_kwargs) + + @staticmethod + def check_qconfig(qconfig: Union[Dict, Config]): + """Check whether the passed qconfig's format meets requirement.""" + is_pass = True + for arg in RequiredArgs: + val = qconfig.get(arg, None) + if isinstance(val, dict) and arg in qconfig.keys(): + continue + else: + is_pass = False + break + return is_pass + + def convert(self): + """Generate torch's QConfig with built fake_quants.""" + torch_qconfig = QConfig( + weight=self.w_fake_quant, activation=self.a_fake_quant) + return torch_qconfig + + @staticmethod + def replace_fakequant(fake_quant_org: FakeQuantize, + qscheme_org: QSchemeHandler, + update_qparams: bool = True): + """Replace origin fakequants in model with the specified fakequant, + which is in favor of deploying the quantized model.""" + assert isinstance(qscheme_org, QSchemeHandler) + observer_kwargs = qscheme_org.to_observer_params() + if is_per_tensor(observer_kwargs['qscheme']): + observer = MODELS.get('MinMaxObserver') + retain_args = RetainArgsPerTensor + else: + observer = MODELS.get('PerChannelMinMaxObserver') + retain_args = RetainArgsPerChannel + pop_keys = [] + for k in observer_kwargs.keys(): + if k not in retain_args: + pop_keys.append(k) + for k in pop_keys: + observer_kwargs.pop(k) + fake_quant = MODELS.get('FakeQuantize') + fake_quant_wrapper = fake_quant.with_args( + observer=observer, **observer_kwargs) + if update_qparams: + device = fake_quant_org.scale.device + fake_quant_ins = fake_quant_wrapper().to(device) + fake_quant_ins.scale.copy_(fake_quant_org.scale) + fake_quant_ins.zero_point.copy_(fake_quant_org.zero_point) + obs = fake_quant_ins.activation_post_process + obs_org = fake_quant_org.activation_post_process + obs.min_val.resize_(obs_org.min_val.shape).copy_(obs_org.min_val) + obs.max_val.resize_(obs_org.max_val.shape).copy_(obs_org.max_val) + return fake_quant_ins + else: + return fake_quant_wrapper + + def fixed_w_fakequant(self): + """Make `self.w_fake_quant` fixed as the consistent fakequant.""" + self.w_fake_quant = self.replace_fakequant( + self.w_fake_quant(), self.w_qscheme, update_qparams=False) diff --git a/mmrazor/testing/__init__.py b/mmrazor/testing/__init__.py index 009dd844d..54dfd30ed 100644 --- a/mmrazor/testing/__init__.py +++ b/mmrazor/testing/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from ._fast_stop_training_hook import FastStopTrainingHook # noqa: F401,F403 +from ._fx_models import * # noqa: F401, F403 diff --git a/mmrazor/testing/_fx_models.py b/mmrazor/testing/_fx_models.py new file mode 100644 index 000000000..6bf42e16a --- /dev/null +++ b/mmrazor/testing/_fx_models.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class ConvBNReLU(nn.Module): + + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: Union[int, Tuple[int, int]] = 1, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: Union[str, bool] = 'auto', + conv_cfg: Optional[Dict] = None, + norm_cfg: Optional[Dict] = None, + act_cfg: Dict = dict(type='ReLU'), + inplace: bool = True, + with_spectral_norm: bool = False, + padding_mode: str = 'zeros', + order: tuple = ('conv', 'norm', 'act'), + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__() + self.conv_module = ConvModule(in_channel, out_channel, kernel_size, + stride, padding, dilation, groups, bias, + conv_cfg, norm_cfg, act_cfg, inplace, + with_spectral_norm, padding_mode, order) + self.toy_attr1 = 1 + self.toy_attr2 = 2 + + def forward(self, x): + x = self.conv_module.conv(x) + x = self.conv_module.norm(x) + x = self.conv_module.activate(x) + return x diff --git a/mmrazor/utils/__init__.py b/mmrazor/utils/__init__.py index a69480e94..7d23ca632 100644 --- a/mmrazor/utils/__init__.py +++ b/mmrazor/utils/__init__.py @@ -2,7 +2,7 @@ from .index_dict import IndexDict from .log_tools import get_level, print_log from .misc import find_latest_checkpoint -from .placeholder import get_placeholder +from .placeholder import get_package_placeholder, get_placeholder from .runtime_info import RuntimeInfo from .setup_env import register_all_modules, setup_multi_processes from .typing import (FixMutable, MultiMutatorsRandomSubnet, @@ -13,5 +13,6 @@ 'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules', 'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet', 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder', - 'IndexDict', 'get_level', 'print_log', 'RuntimeInfo' + 'IndexDict', 'get_level', 'print_log', 'RuntimeInfo', + 'get_package_placeholder' ] diff --git a/mmrazor/utils/placeholder.py b/mmrazor/utils/placeholder.py index 553223b20..9af35f7a4 100644 --- a/mmrazor/utils/placeholder.py +++ b/mmrazor/utils/placeholder.py @@ -23,3 +23,35 @@ def __init__(self) -> None: raise_import_error(string) return PlaceHolder + + +def get_package_placeholder(string: str) -> object: + """Get placeholder instance which can avoid raising errors when down-stream + dependency is not installed properly. + + Args: + string (str): the dependency's name, i.e. `mmcls` + + Raises: + ImportError: raise it when the dependency is not installed properly. + + Returns: + object: PlaceHolder instance. 
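+
+    Example:
+        The intended use is to guard an optional import so that the
+        ImportError is only raised once the placeholder is actually used.
+        The pattern below mirrors how the tests added in this PR guard the
+        optional ``mmdeploy`` dependency:
+
+        >>> try:
+        ...     import mmdeploy
+        ... except ImportError:
+        ...     from mmrazor.utils import get_package_placeholder
+        ...     mmdeploy = get_package_placeholder('mmdeploy')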
+ """ + + def raise_import_error(package_name): + raise ImportError( + f'`{package_name}` is not installed properly, plz check.') + + class PlaceHolderMetaclass(type): + """Used to support usage of PlaceHolder.xxxx.""" + + def __getattr__(self, name): + raise_import_error(string) + + class PlaceHolder(metaclass=PlaceHolderMetaclass): + + def __init__(self) -> None: + raise_import_error(string) + + return PlaceHolder diff --git a/mmrazor/version.py b/mmrazor/version.py index 68cc1a1f1..6a60b40f3 100644 --- a/mmrazor/version.py +++ b/mmrazor/version.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved -__version__ = '1.0.0rc2' +__version__ = '1.0.0' def parse_version_info(version_str): diff --git a/model-index.yml b/model-index.yml index 15e4595cb..6204bceae 100644 --- a/model-index.yml +++ b/model-index.yml @@ -3,7 +3,6 @@ Import: - configs/distill/mmdet/cwd/metafile.yml - configs/distill/mmcls/wsld/metafile.yml - configs/distill/mmcls/rkd/metafile.yml - - configs/nas/mmcls/spos/metafile.yml - configs/distill/mmcls/abloss/metafile.yml - configs/distill/mmcls/byot/metafile.yml - configs/distill/mmcls/dafl/metafile.yml @@ -15,6 +14,7 @@ Import: - configs/distill/mmdet/fbkd/metafile.yml - configs/distill/mmcls/factor_transfer/metafile.yml - configs/distill/mmcls/ofd/metafile.yml + - configs/nas/mmcls/spos/metafile.yml - configs/nas/mmcls/autoslim/metafile.yml - configs/nas/mmcls/darts/metafile.yml - configs/nas/mmdet/detnas/metafile.yml @@ -25,4 +25,7 @@ Import: - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml - configs/pruning/mmcls/l1-norm/metafile.yml - - configs/pruning/mmcls/dmcp/metafile.yml + # - configs/pruning/mmcls/dmcp/metafile.yml + - configs/quantization/ptq/base/metafile.yml + - configs/quantization/qat/base/metafile.yml + - configs/quantization/qat/lsq/metafile.yml diff --git a/requirements/optional.txt b/requirements/optional.txt index f9b68e8dd..bb7173848 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,4 +1,4 @@ pydacefit pySOT==0.2.3 scipy -# timm +timm diff --git a/requirements/tests.txt b/requirements/tests.txt index 8763670ef..5980dc303 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,10 +1,11 @@ -codecov +coverage flake8 interrogate isort==4.3.21 nbconvert nbformat numpy < 1.24.0 # A temporary solution for tests with mmdet. 
+onnx pytest xdoctest >= 0.10.0 yapf diff --git a/tests/data/MBV2_slimmable_config.json b/tests/data/MBV2_slimmable_config.json index 5b9a5573a..9010b83e2 100644 --- a/tests/data/MBV2_slimmable_config.json +++ b/tests/data/MBV2_slimmable_config.json @@ -1,396 +1,377 @@ { - "type":"OneShotChannelMutator", - "channel_unit_cfg":{ - "type":"OneShotMutableChannelUnit", - "default_args":{ - "choice_mode":"number" + "backbone.conv1.conv_(0, 48)_48": { + "init_args": { + "num_channels": 48, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 8, + 8, + 32 + ], + "choice_mode": "number" }, - "units":{ - "backbone.conv1.conv_(0, 48)_48": { - "init_args": { - "num_channels": 48, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 8, - 8, - 32 - ], - "choice_mode": "number" - }, - "choice": 32 - }, - "backbone.layer1.0.conv.1.conv_(0, 24)_24": { - "init_args": { - "num_channels": 24, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 8, - 8, - 16 - ], - "choice_mode": "number" - }, - "choice": 16 - }, - "backbone.layer2.0.conv.0.conv_(0, 144)_144": { - "init_args": { - "num_channels": 144, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 96, - 96, - 144 - ], - "choice_mode": "number" - }, - "choice": 144 - }, - "backbone.layer2.0.conv.2.conv_(0, 40)_40": { - "init_args": { - "num_channels": 40, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 16, - 16, - 24 - ], - "choice_mode": "number" - }, - "choice": 24 - }, - "backbone.layer2.1.conv.0.conv_(0, 240)_240": { - "init_args": { - "num_channels": 240, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 96, - 96, - 176 - ], - "choice_mode": "number" - }, - "choice": 176 - }, - "backbone.layer3.0.conv.0.conv_(0, 240)_240": { - "init_args": { - "num_channels": 240, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 96, - 96, - 192 - ], - "choice_mode": "number" - }, - "choice": 192 - }, - "backbone.layer3.0.conv.2.conv_(0, 48)_48": { - "init_args": { - "num_channels": 48, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 24, - 24, - 48 - ], - "choice_mode": "number" - }, - "choice": 48 - }, - "backbone.layer3.1.conv.0.conv_(0, 288)_288": { - "init_args": { - "num_channels": 288, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 144, - 144, - 240 - ], - "choice_mode": "number" - }, - "choice": 240 - }, - "backbone.layer3.2.conv.0.conv_(0, 288)_288": { - "init_args": { - "num_channels": 288, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 144, - 144, - 144 - ], - "choice_mode": "number" - }, - "choice": 144 - }, - "backbone.layer4.0.conv.0.conv_(0, 288)_288": { - "init_args": { - "num_channels": 288, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 144, - 144, - 264 - ], - "choice_mode": "number" - }, - "choice": 264 - }, - "backbone.layer4.0.conv.2.conv_(0, 96)_96": { - "init_args": { - "num_channels": 96, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 48, - 56, - 88 - ], - "choice_mode": "number" - }, - "choice": 88 - }, - "backbone.layer4.1.conv.0.conv_(0, 576)_576": { - "init_args": { - "num_channels": 576, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 288, - 288, - 288 - ], - "choice_mode": "number" - }, - "choice": 288 - }, - 
"backbone.layer4.2.conv.0.conv_(0, 576)_576": { - "init_args": { - "num_channels": 576, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 288, - 288, - 336 - ], - "choice_mode": "number" - }, - "choice": 336 - }, - "backbone.layer4.3.conv.0.conv_(0, 576)_576": { - "init_args": { - "num_channels": 576, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 288, - 288, - 432 - ], - "choice_mode": "number" - }, - "choice": 432 - }, - "backbone.layer5.0.conv.0.conv_(0, 576)_576": { - "init_args": { - "num_channels": 576, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 288, - 288, - 576 - ], - "choice_mode": "number" - }, - "choice": 576 - }, - "backbone.layer5.0.conv.2.conv_(0, 144)_144": { - "init_args": { - "num_channels": 144, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 64, - 96, - 144 - ], - "choice_mode": "number" - }, - "choice": 144 - }, - "backbone.layer5.1.conv.0.conv_(0, 864)_864": { - "init_args": { - "num_channels": 864, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 432, - 432, - 576 - ], - "choice_mode": "number" - }, - "choice": 576 - }, - "backbone.layer5.2.conv.0.conv_(0, 864)_864": { - "init_args": { - "num_channels": 864, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 432, - 432, - 648 - ], - "choice_mode": "number" - }, - "choice": 648 - }, - "backbone.layer6.0.conv.0.conv_(0, 864)_864": { - "init_args": { - "num_channels": 864, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 648, - 864, - 864 - ], - "choice_mode": "number" - }, - "choice": 864 - }, - "backbone.layer6.0.conv.2.conv_(0, 240)_240": { - "init_args": { - "num_channels": 240, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 176, - 240, - 240 - ], - "choice_mode": "number" - }, - "choice": 240 - }, - "backbone.layer6.1.conv.0.conv_(0, 1440)_1440": { - "init_args": { - "num_channels": 1440, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 720, - 1440, - 1440 - ], - "choice_mode": "number" - }, - "choice": 1440 - }, - "backbone.layer6.2.conv.0.conv_(0, 1440)_1440": { - "init_args": { - "num_channels": 1440, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 720, - 960, - 1440 - ], - "choice_mode": "number" - }, - "choice": 1440 - }, - "backbone.layer7.0.conv.0.conv_(0, 1440)_1440": { - "init_args": { - "num_channels": 1440, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 1440, - 1440, - 1440 - ], - "choice_mode": "number" - }, - "choice": 1440 - }, - "backbone.layer7.0.conv.2.conv_(0, 480)_480": { - "init_args": { - "num_channels": 480, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 280, - 480, - 480 - ], - "choice_mode": "number" - }, - "choice": 480 - }, - "backbone.conv2.conv_(0, 1920)_1920": { - "init_args": { - "num_channels": 1920, - "divisor": 1, - "min_value": 1, - "min_ratio": 0.9, - "candidate_choices": [ - 1920, - 1920, - 1920 - ], - "choice_mode": "number" - }, - "choice": 1920 - } - } + "choice": 32 }, - "parse_cfg":{ - "type":"ChannelAnalyzer", - "demo_input":[ - 1, - 3, - 224, - 224 - ], - "tracer_type":"BackwardTracer" + "backbone.layer1.0.conv.1.conv_(0, 24)_24": { + "init_args": { + "num_channels": 24, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 8, + 8, + 16 + ], + "choice_mode": 
"number" + }, + "choice": 16 + }, + "backbone.layer2.0.conv.0.conv_(0, 144)_144": { + "init_args": { + "num_channels": 144, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 96, + 96, + 144 + ], + "choice_mode": "number" + }, + "choice": 144 + }, + "backbone.layer2.0.conv.2.conv_(0, 40)_40": { + "init_args": { + "num_channels": 40, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 16, + 16, + 24 + ], + "choice_mode": "number" + }, + "choice": 24 + }, + "backbone.layer2.1.conv.0.conv_(0, 240)_240": { + "init_args": { + "num_channels": 240, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 96, + 96, + 176 + ], + "choice_mode": "number" + }, + "choice": 176 + }, + "backbone.layer3.0.conv.0.conv_(0, 240)_240": { + "init_args": { + "num_channels": 240, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 96, + 96, + 192 + ], + "choice_mode": "number" + }, + "choice": 192 + }, + "backbone.layer3.0.conv.2.conv_(0, 48)_48": { + "init_args": { + "num_channels": 48, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 24, + 24, + 48 + ], + "choice_mode": "number" + }, + "choice": 48 + }, + "backbone.layer3.1.conv.0.conv_(0, 288)_288": { + "init_args": { + "num_channels": 288, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 144, + 144, + 240 + ], + "choice_mode": "number" + }, + "choice": 240 + }, + "backbone.layer3.2.conv.0.conv_(0, 288)_288": { + "init_args": { + "num_channels": 288, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 144, + 144, + 144 + ], + "choice_mode": "number" + }, + "choice": 144 + }, + "backbone.layer4.0.conv.0.conv_(0, 288)_288": { + "init_args": { + "num_channels": 288, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 144, + 144, + 264 + ], + "choice_mode": "number" + }, + "choice": 264 + }, + "backbone.layer4.0.conv.2.conv_(0, 96)_96": { + "init_args": { + "num_channels": 96, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 48, + 56, + 88 + ], + "choice_mode": "number" + }, + "choice": 88 + }, + "backbone.layer4.1.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 288 + ], + "choice_mode": "number" + }, + "choice": 288 + }, + "backbone.layer4.2.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 336 + ], + "choice_mode": "number" + }, + "choice": 336 + }, + "backbone.layer4.3.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 432 + ], + "choice_mode": "number" + }, + "choice": 432 + }, + "backbone.layer5.0.conv.0.conv_(0, 576)_576": { + "init_args": { + "num_channels": 576, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 288, + 288, + 576 + ], + "choice_mode": "number" + }, + "choice": 576 + }, + "backbone.layer5.0.conv.2.conv_(0, 144)_144": { + "init_args": { + "num_channels": 144, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 64, + 96, + 144 + ], + "choice_mode": "number" + }, + "choice": 144 + }, + "backbone.layer5.1.conv.0.conv_(0, 864)_864": { + "init_args": { + "num_channels": 864, + "divisor": 1, 
+ "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 432, + 432, + 576 + ], + "choice_mode": "number" + }, + "choice": 576 + }, + "backbone.layer5.2.conv.0.conv_(0, 864)_864": { + "init_args": { + "num_channels": 864, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 432, + 432, + 648 + ], + "choice_mode": "number" + }, + "choice": 648 + }, + "backbone.layer6.0.conv.0.conv_(0, 864)_864": { + "init_args": { + "num_channels": 864, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 648, + 864, + 864 + ], + "choice_mode": "number" + }, + "choice": 864 + }, + "backbone.layer6.0.conv.2.conv_(0, 240)_240": { + "init_args": { + "num_channels": 240, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 176, + 240, + 240 + ], + "choice_mode": "number" + }, + "choice": 240 + }, + "backbone.layer6.1.conv.0.conv_(0, 1440)_1440": { + "init_args": { + "num_channels": 1440, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 720, + 1440, + 1440 + ], + "choice_mode": "number" + }, + "choice": 1440 + }, + "backbone.layer6.2.conv.0.conv_(0, 1440)_1440": { + "init_args": { + "num_channels": 1440, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 720, + 960, + 1440 + ], + "choice_mode": "number" + }, + "choice": 1440 + }, + "backbone.layer7.0.conv.0.conv_(0, 1440)_1440": { + "init_args": { + "num_channels": 1440, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 1440, + 1440, + 1440 + ], + "choice_mode": "number" + }, + "choice": 1440 + }, + "backbone.layer7.0.conv.2.conv_(0, 480)_480": { + "init_args": { + "num_channels": 480, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 280, + 480, + 480 + ], + "choice_mode": "number" + }, + "choice": 480 + }, + "backbone.conv2.conv_(0, 1920)_1920": { + "init_args": { + "num_channels": 1920, + "divisor": 1, + "min_value": 1, + "min_ratio": 0.9, + "candidate_choices": [ + 1920, + 1920, + 1920 + ], + "choice_mode": "number" + }, + "choice": 1920 } } \ No newline at end of file diff --git a/tests/data/models.py b/tests/data/models.py index 33fb0c624..0347b9147 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -78,7 +78,6 @@ def untracable_method(self, x): x = x * -2 return x - @MODELS.register_module() class UntracableBackBone(nn.Module): @@ -123,7 +122,6 @@ def forward(self, x): x_last = self.conv2(x_attn) return self.head(x_last) - @MODELS.register_module() class LinearHeadForTest(Module): @@ -704,7 +702,6 @@ def current_choice(self): def current_choice(self, choice): super().current_choice(choice) - class DynamicLinearModel(nn.Module): """ x diff --git a/tests/data/test_models/test_task_modules/mmcls_cfg.py b/tests/data/test_models/test_task_modules/mmcls_cfg.py new file mode 100644 index 000000000..117b9383e --- /dev/null +++ b/tests/data/test_models/test_task_modules/mmcls_cfg.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] \ No newline at end of file diff --git a/tests/test_models/test_algorithms/test_general_quant.py b/tests/test_models/test_algorithms/test_general_quant.py new file mode 100644 index 000000000..94a2485bc --- /dev/null +++ b/tests/test_models/test_algorithms/test_general_quant.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import torch.nn as nn + + +class ToyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + # TODO + + +class TestGeneralQuant(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def test_init(self): + pass + + def test_prepare(self): + pass + + def test_convert(self): + pass + + def test_states(self): + pass + + def test_forward(self): + pass diff --git a/tests/test_models/test_algorithms/test_mm_architecture.py b/tests/test_models/test_algorithms/test_mm_architecture.py new file mode 100644 index 000000000..310d42f5e --- /dev/null +++ b/tests/test_models/test_algorithms/test_mm_architecture.py @@ -0,0 +1,225 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +import shutil +import tempfile +from unittest import TestCase, skipIf + +import torch +import torch.nn as nn + +try: + from torch.fx import GraphModule +except ImportError: + from mmrazor.utils import get_placeholder + GraphModule = get_placeholder('torch>=1.13') + +from mmengine import ConfigDict +from mmengine.model import BaseModel + +try: + import mmdeploy +except ImportError: + from mmrazor.utils import get_package_placeholder + mmdeploy = get_package_placeholder('mmdeploy') + +from mmrazor import digit_version +from mmrazor.models.algorithms import MMArchitectureQuant +from mmrazor.registry import MODELS + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyModel(nn.Module): + + def __init__(self): + super(ToyModel, self).__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +class ToyQuantModel(BaseModel): + + def __init__(self): + super().__init__() + self.architecture = ToyModel() + + def loss(self, outputs, data_samples): + return dict(loss=outputs.sum() - data_samples.sum()) + + def forward(self, inputs, data_samples, mode: str = 'tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + outputs = self.architecture(inputs) + + return outputs + + +DEPLOY_CFG = ConfigDict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=None, + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'output': { + 0: 'batch' + } + }), + 
backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]), + codebase_config=dict(type='mmcls', task='Classification'), + function_record_to_pop=[ + 'mmcls.models.classifiers.ImageClassifier.forward', + 'mmcls.models.classifiers.BaseClassifier.forward' + ], +) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') +class TestMMArchitectureQuant(TestCase): + + def setUp(self): + + MODELS.register_module(module=ToyQuantModel, force=True) + + self.temp_dir = tempfile.mkdtemp() + filename = 'fp_model.pth' + filename = os.path.join(self.temp_dir, filename) + toymodel = ToyQuantModel() + torch.save(toymodel.state_dict(), filename) + + global_qconfig = ConfigDict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', + bit=8, + is_symmetry=True, + is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', + bit=8, + is_symmetry=True, + averaging_constant=0.1), + ) + alg_kwargs = ConfigDict( + type='mmrazor.MMArchitectureQuant', + architecture=dict(type='ToyQuantModel'), + float_checkpoint=filename, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict(type='mmrazor.CustomTracer'))) + self.alg_kwargs = alg_kwargs + + def tearDown(self): + MODELS.module_dict.pop('ToyQuantModel') + shutil.rmtree(self.temp_dir) + + def test_init(self): + self.toy_model = MODELS.build(self.alg_kwargs) + assert isinstance(self.toy_model, MMArchitectureQuant) + assert hasattr(self.toy_model, 'quantizer') + + alg_kwargs = copy.deepcopy(self.alg_kwargs) + alg_kwargs.deploy_cfg = DEPLOY_CFG + assert isinstance(self.toy_model, MMArchitectureQuant) + assert hasattr(self.toy_model, 'quantizer') + + def test_sync_qparams(self): + self.toy_model = MODELS.build(self.alg_kwargs) + mode = self.toy_model.forward_modes[0] + self.toy_model.sync_qparams(mode) + w_loss = self.toy_model.qmodels[ + 'loss'].architecture.block.conv1.state_dict()['weight'] + w_tensor = self.toy_model.qmodels[ + 'tensor'].architecture.block.conv1.state_dict()['weight'] + w_pred = self.toy_model.qmodels[ + 'predict'].architecture.block.conv1.state_dict()['weight'] + assert w_loss.equal(w_pred) + assert w_loss.equal(w_tensor) + + def test_build_qmodels(self): + self.toy_model = MODELS.build(self.alg_kwargs) + for forward_modes in self.toy_model.forward_modes: + qmodels = self.toy_model.qmodels[forward_modes] + assert isinstance(qmodels, GraphModule) + + def test_get_deploy_model(self): + self.toy_model = MODELS.build(self.alg_kwargs) + deploy_model = self.toy_model.get_deploy_model() + self.assertIsInstance(deploy_model, torch.fx.graph_module.GraphModule) + + def test_calibrate_step(self): + # TODO + pass diff --git a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py new file mode 100644 index 000000000..dcbda5d40 --- /dev/null +++ b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import torch +from torch.nn.parameter import Parameter + +from mmrazor import digit_version +from mmrazor.models import LearnableFakeQuantize + +try: + from torch.ao.quantization import (MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver) +except ImportError: + from mmrazor.utils import get_placeholder + MovingAverageMinMaxObserver = get_placeholder('torch>=1.13') + MovingAveragePerChannelMinMaxObserver = get_placeholder('torch>=1.13') + + +class TestLearnableFakeQuantize(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.zero_point_trainable_fakequant = LearnableFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, + zero_point_trainable=True) + + self.zero_point_untrainable_fakequant = \ + LearnableFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, + zero_point_trainable=False) + + self.zero_point_untrainable_per_channel_fakequant = \ + LearnableFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=True, + zero_point_trainable=False) + + def test_repr(self): + fq_module = self.zero_point_untrainable_fakequant() + repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += f'fake_quant_enabled=' \ + f'{torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += 'quant_min=0, ' + repr_str += 'quant_max=127, ' + repr_str += f'dtype={torch.quint8}, ' + repr_str += f'qscheme={torch.per_tensor_affine}, ' + repr_str += f'scale={Parameter(torch.tensor([1.0]))}, ' + repr_str += f'zero_point={torch.tensor([0.])}, ' + repr_str += 'zero_point_trainable=False' + self.assertEqual(fq_module.extra_repr(), repr_str) + + fq_module = self.zero_point_trainable_fakequant() + repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += f'fake_quant_enabled=' \ + f'{torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += 'quant_min=0, ' + repr_str += 'quant_max=127, ' + repr_str += f'dtype={torch.quint8}, ' + repr_str += f'qscheme={torch.per_tensor_affine}, ' + repr_str += f'scale={Parameter(torch.tensor([1.0]))}, ' + repr_str += f'zero_point={Parameter(torch.tensor([0.]))}, ' + repr_str += 'zero_point_trainable=True' + self.assertEqual(fq_module.extra_repr(), repr_str) + + def test_calculate_qparams(self): + fq_module = self.zero_point_untrainable_fakequant() + scale, zero_point = fq_module.calculate_qparams() + self.assertEqual(scale, 1.) + self.assertEqual(zero_point, 0.) + + fq_module = self.zero_point_trainable_fakequant() + scale, zero_point = fq_module.calculate_qparams() + self.assertEqual(scale, 1.) + self.assertEqual(zero_point, 0.) 
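+        # No data has been observed at this point, so the learnable scale and
+        # zero_point above still hold their initialized values.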
+
+    def test_forward(self):
+        fq_module = self.zero_point_untrainable_fakequant()
+        torch.manual_seed(42)
+        X = torch.rand(20, 10, dtype=torch.float32)
+        # Output of fake quant is not identical to input
+        Y = fq_module(X)
+        self.assertFalse(torch.equal(Y, X))
+        fq_module.toggle_fake_quant(False)
+        X = torch.rand(20, 10, dtype=torch.float32)
+        Y = fq_module(X)
+        # Fake quant is disabled, output is identical to input
+        self.assertTrue(torch.equal(Y, X))
+
+        # Explicit copy at this point in time, because FakeQuant keeps internal
+        # state in mutable buffers.
+        scale = fq_module.scale.clone().detach()
+        zero_point = fq_module.zero_point.clone().detach()
+
+        fq_module.toggle_observer_update(False)
+        fq_module.toggle_fake_quant(True)
+        X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
+        Y = fq_module(X)
+        self.assertFalse(torch.equal(Y, X))
+        # Observer is disabled, scale and zero-point do not change
+        self.assertEqual(fq_module.scale, scale)
+        self.assertEqual(fq_module.zero_point, zero_point)
+
+        fq_module.toggle_observer_update(True)
+        Y = fq_module(X)
+        self.assertFalse(torch.equal(Y, X))
+        # Observer is enabled, scale and zero-point are different
+        self.assertNotEqual(fq_module.scale, scale)
+        self.assertNotEqual(fq_module.zero_point, zero_point)
+
+        fq_module = self.zero_point_trainable_fakequant()
+        torch.manual_seed(42)
+        X = torch.rand(20, 10, dtype=torch.float32)
+        # Output of fake quant is not identical to input
+        Y = fq_module(X)
+        self.assertFalse(torch.equal(Y, X))
+        fq_module.toggle_fake_quant(False)
+        X = torch.rand(20, 10, dtype=torch.float32)
+        Y = fq_module(X)
+        # Fake quant is disabled, output is identical to input
+        self.assertTrue(torch.equal(Y, X))
+
+        # Explicit copy at this point in time, because FakeQuant keeps internal
+        # state in mutable buffers.
+ scale = fq_module.scale.clone().detach() + zero_point = fq_module.zero_point.clone().detach() + + fq_module.toggle_observer_update(False) + fq_module.toggle_fake_quant(True) + X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0 + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is disabled, scale and zero-point do not change + self.assertEqual(fq_module.scale, scale) + self.assertEqual(fq_module.zero_point, zero_point) + + fq_module.toggle_observer_update(True) + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is enabled, scale and zero-point are different + self.assertNotEqual(fq_module.scale, scale) + self.assertNotEqual(fq_module.zero_point, zero_point) + + def test_state(self): + fq_module = self.zero_point_untrainable_fakequant() + + fq_module.enable_param_learning() + self.assertEqual(fq_module.learning_enabled[0], 1) + self.assertEqual(fq_module.scale.requires_grad, 1) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) + + fq_module.enable_static_estimate() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 1) + + fq_module.enable_val() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) + + fq_module.enable_static_observation() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 0) + self.assertEqual(fq_module.static_enabled[0], 1) + + fq_module = self.zero_point_trainable_fakequant() + + fq_module.enable_param_learning() + self.assertEqual(fq_module.learning_enabled[0], 1) + self.assertEqual(fq_module.scale.requires_grad, 1) + self.assertEqual(fq_module.zero_point.requires_grad, 1) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) + + def test_load_state_dict(self): + fq_module = self.zero_point_untrainable_per_channel_fakequant() + state_dict = fq_module.state_dict() + X = torch.rand(32, 16, 3, 3, dtype=torch.float32) + # After forwarding, the shape of `scale` and `zero_point` in + # `fq_module` will be in shape (32, ), while the shape of those in + # `state_dict` are in shape (1, ). + _ = fq_module(X) + fq_module.load_state_dict(state_dict) diff --git a/tests/test_models/test_fake_quants/test_torch_fake_quants.py b/tests/test_models/test_fake_quants/test_torch_fake_quants.py new file mode 100644 index 000000000..485113e90 --- /dev/null +++ b/tests/test_models/test_fake_quants/test_torch_fake_quants.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import pytest +import torch + +from mmrazor import digit_version +from mmrazor.models.fake_quants import register_torch_fake_quants +from mmrazor.registry import MODELS + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_register_torch_fake_quants(): + + TORCH_fake_quants = register_torch_fake_quants() + assert isinstance(TORCH_fake_quants, list) + for fake_quant in TORCH_fake_quants: + assert MODELS.get(fake_quant) diff --git a/tests/test_models/test_observers/test_lsq_observer.py b/tests/test_models/test_observers/test_lsq_observer.py new file mode 100644 index 000000000..a61f95d7f --- /dev/null +++ b/tests/test_models/test_observers/test_lsq_observer.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmrazor import digit_version +from mmrazor.models import LSQObserver, LSQPerChannelObserver + + +class TestLSQObserver(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.lsq = LSQObserver.with_args( + dtype=torch.quint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=False, + quant_min=0, + quant_max=255) + + def test_forward(self): + lsq_observer = self.lsq() + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + X = torch.rand(0, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + def test_calculate_qparams(self): + lsq_observer = self.lsq() + X = torch.ones(10, dtype=torch.float32) + _ = lsq_observer(X) + scale, zero_point = lsq_observer.calculate_qparams() + # tensor_norm = 1, quant_max = 255 + self.assertEqual(scale, 2 * torch.tensor([1.]) / (255**0.5)) + self.assertEqual(zero_point, 127) + + +class TestLSQPerChannelObserver(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.lsq = LSQPerChannelObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=-127, + quant_max=127) + + def test_forward(self): + lsq_observer = self.lsq() + torch.manual_seed(42) + X = torch.rand(2, 10, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + X = torch.rand(0, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + def test_calculate_qparams(self): + lsq_observer = self.lsq() + X = torch.ones(2, 10, dtype=torch.float32) + X[0] -= 1 + _ = lsq_observer(X) + scale, zero_point = lsq_observer.calculate_qparams() + self.assertEqual(scale[0], 2 * torch.tensor([0.]) / (127**0.5)) + self.assertEqual(scale[1], 2 * torch.tensor([1.]) / (127**0.5)) diff --git a/tests/test_models/test_observers/test_torch_observers.py b/tests/test_models/test_observers/test_torch_observers.py new file mode 100644 index 000000000..cc32e69d8 --- /dev/null +++ b/tests/test_models/test_observers/test_torch_observers.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import pytest
+import torch
+
+from mmrazor import digit_version
+from mmrazor.models.observers import register_torch_observers
+from mmrazor.registry import MODELS
+
+
+@pytest.mark.skipif(
+    digit_version(torch.__version__) < digit_version('1.13.0'),
+    reason='version of torch < 1.13.0')
+def test_register_torch_observers():
+
+    TORCH_observers = register_torch_observers()
+    assert isinstance(TORCH_observers, list)
+    for observer in TORCH_observers:
+        assert MODELS.get(observer)
diff --git a/tests/test_models/test_quantizers/test_academic_quantizer.py b/tests/test_models/test_quantizers/test_academic_quantizer.py
new file mode 100644
index 000000000..c95060a00
--- /dev/null
+++ b/tests/test_models/test_quantizers/test_academic_quantizer.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import copy
+from unittest import TestCase
+
+import torch
+from mmengine.model import BaseModule
+
+try:
+    from torch.ao.nn.intrinsic import ConvBnReLU2d
+    from torch.ao.quantization.backend_config import BackendConfig
+    from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
+    from torch.ao.quantization.fx.graph_module import ObservedGraphModule
+    from torch.ao.quantization.qconfig_mapping import QConfigMapping
+    from torch.ao.quantization.quant_type import QuantType
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    ConvBnReLU2d = get_placeholder('torch>=1.13')
+    BackendConfig = get_placeholder('torch>=1.13')
+    PrepareCustomConfig = get_placeholder('torch>=1.13')
+    ObservedGraphModule = get_placeholder('torch>=1.13')
+    QConfigMapping = get_placeholder('torch>=1.13')
+    QuantType = get_placeholder('torch>=1.13')
+
+from mmrazor import digit_version
+from mmrazor.models.quantizers import AcademicQuantizer
+from mmrazor.models.quantizers.academic_quantizer import (
+    FLOAT_TO_OBSERVED_DICT_KEY, GLOBAL_DICT_KEY, MODULE_NAME_DICT_KEY,
+    OBJECT_TYPE_DICT_KEY, PRESERVED_ATTRIBUTES_DICT_KEY)
+from mmrazor.registry import MODELS
+from mmrazor.testing import ConvBNReLU
+
+
+@MODELS.register_module()
+class ToyFloatModel(BaseModule):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+@MODELS.register_module()
+class ToyObservedModel(BaseModule):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+class TestAcademicQuantizer(TestCase):
+
+    def setUp(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        self.global_qconfig = dict(
+            w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+            a_observer=dict(type='mmrazor.MinMaxObserver'),
+            w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+            a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+            w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True),
+            a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
+        )
+        self.qconfig = dict(
+            w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+            a_observer=dict(type='mmrazor.MinMaxObserver'),
+            w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+            a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+            w_qscheme=dict(qdtype='qint8', bit=4, is_symmetry=True),
+            a_qscheme=dict(qdtype='quint8', bit=4, is_symmetry=True),
+        )
+        self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN'))
+
+    def test_gen_qconfig_mapping(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        # test set GLOBAL_DICT_KEY by QConfigMapping
+        global_qconfig = copy(self.global_qconfig)
+        qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig}
+
quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.global_qconfig + + # test set OBJECT_TYPE_DICT_KEY by QConfigMapping + qconfig = copy(self.qconfig) + qconfig_mapping = { + OBJECT_TYPE_DICT_KEY: + [('torch.ao.nn.intrinsic.ConvBnReLU2d', qconfig)] + } + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.object_type_qconfigs.get(ConvBnReLU2d) + + # test set MODULE_NAME_DICT_KEY by QConfigMapping + qconfig = copy(self.qconfig) + qconfig_mapping = { + MODULE_NAME_DICT_KEY: [('conv_module.conv', qconfig)] + } + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.module_name_qconfigs.get( + 'conv_module.conv') + + def test_gen_prepare_custom_config(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # test prepare_custom_config is None + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'prepare_custom_config') + assert isinstance(quantizer.prepare_custom_config, PrepareCustomConfig) + + # test set FLOAT_TO_OBSERVED_DICT_KEY and PRESERVED_ATTRIBUTES_DICT_KEY + # by PrepareCustomConfig + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + flop_to_observed_list = [('ToyFloatModel', 'ToyObservedModel')] + preserved_attributes_list = ['toy_attr1', 'toy_attr2'] + prepare_custom_config = { + FLOAT_TO_OBSERVED_DICT_KEY: flop_to_observed_list, + PRESERVED_ATTRIBUTES_DICT_KEY: preserved_attributes_list + } + quantizer = AcademicQuantizer( + qconfig_mapping=qconfig_mapping, + prepare_custom_config=prepare_custom_config) + + assert hasattr(quantizer, 'prepare_custom_config') + assert isinstance(quantizer.prepare_custom_config, PrepareCustomConfig) + mapping = quantizer.prepare_custom_config.float_to_observed_mapping[ + QuantType.STATIC] + assert mapping.get(ToyFloatModel) + assert mapping[ToyFloatModel] == ToyObservedModel + + attributes = quantizer.prepare_custom_config.preserved_attributes + assert attributes == preserved_attributes_list + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'backend_config') + assert isinstance(quantizer.backend_config, BackendConfig) + + def test_prepare(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + preserved_attributes_list = ['toy_attr1', 'toy_attr2'] + prepare_custom_config = { + PRESERVED_ATTRIBUTES_DICT_KEY: preserved_attributes_list + } + quantizer = AcademicQuantizer( + qconfig_mapping=qconfig_mapping, + prepare_custom_config=prepare_custom_config) + model = copy(self.model) + prepared = 
quantizer.prepare(model) + assert isinstance(prepared, ObservedGraphModule) + assert hasattr(prepared, 'toy_attr1') + assert hasattr(prepared, 'toy_attr2') diff --git a/tests/test_models/test_quantizers/test_exporter.py b/tests/test_models/test_quantizers/test_exporter.py new file mode 100644 index 000000000..04bd8a671 --- /dev/null +++ b/tests/test_models/test_quantizers/test_exporter.py @@ -0,0 +1,348 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +import shutil +import tempfile +from unittest import TestCase, skipIf + +import torch +import torch.nn as nn + +try: + import onnx + from onnx import helper + from torch.fx import GraphModule +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + GraphModule = get_placeholder('torch>=1.13') + onnx = get_package_placeholder('No module named onnx') + helper = get_package_placeholder('No module named onnx.helper') + +from mmengine import ConfigDict +from mmengine.model import BaseModel + +try: + import mmdeploy +except ImportError: + from mmrazor.utils import get_package_placeholder + mmdeploy = get_package_placeholder('mmdeploy') + +from mmrazor import digit_version +from mmrazor.models.quantizers.exporters import (OpenVinoQuantizeExportor, + TensorRTExplicitExporter) +from mmrazor.models.quantizers.exporters.optim_utils import ONNXOptimUtils +from mmrazor.registry import MODELS + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyModel(nn.Module): + + def __init__(self): + super(ToyModel, self).__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +class ToyQuantModel(BaseModel): + + def __init__(self): + super().__init__() + self.architecture = ToyModel() + + def loss(self, outputs, data_samples): + return dict(loss=outputs.sum() - data_samples.sum()) + + def forward(self, inputs, data_samples, mode: str = 'tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + outputs = self.architecture(inputs) + + return outputs + + +OpenVINO_GLOBAL_QCONFIG = ConfigDict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + 
a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +OpenVINO_ALG_CONFIG = ConfigDict( + type='mmrazor.MMArchitectureQuant', + architecture=dict(type='ToyQuantModel'), + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=OpenVINO_GLOBAL_QCONFIG, + tracer=dict(type='mmrazor.CustomTracer'))) + +TensorRT_GLOBAL_QCONFIG = ConfigDict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +TensorRT_ALG_CONFIG = ConfigDict( + type='mmrazor.MMArchitectureQuant', + architecture=dict(type='ToyQuantModel'), + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=OpenVINO_GLOBAL_QCONFIG, + tracer=dict(type='mmrazor.CustomTracer'))) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') +class TestONNXOptimUtils(TestCase): + + def setUp(self): + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() + filename = 'symbolic.onnx' + filename = os.path.join(self.temp_dir, filename) + toy_model = MODELS.build(OpenVINO_ALG_CONFIG) + observed_model = toy_model.get_deploy_model() + torch.onnx.export( + observed_model, + torch.rand(2, 3, 16, 16), + filename, + opset_version=11) + self.onnx_model = onnx.load(filename) + self.optimizer = ONNXOptimUtils + + def tearDown(self): + MODELS.module_dict.pop('ToyQuantModel') + shutil.rmtree(self.temp_dir) + + def test_map_name_and_data(self): + params = self.optimizer.map_name_and_data(self.onnx_model) + params_keys = [ + 'activation_post_process_0.scale', + 'activation_post_process_0.zero_point', + 'architecture.stem_layer.0.weight', + 'architecture.stem_layer.0.bias', + 'architecture.stem_layer.0.weight_fake_quant.scale', + 'architecture.stem_layer.0.weight_fake_quant.zero_point', + 'architecture.block.conv1.weight', 'architecture.block.conv1.bias', + 'architecture.block.conv1.weight_fake_quant.scale', + 'architecture.block.conv2.bias', + 'architecture.block2.conv1.weight', + 'architecture.block2.conv1.bias', + 'architecture.block2.conv1.weight_fake_quant.scale', + 'architecture.block2.conv2.weight', + 'architecture.block2.conv2.bias', + 'architecture.block2.conv2.weight_fake_quant.scale', + 'architecture.fc.weight', 'architecture.fc.bias', + 'architecture.fc.weight_fake_quant.scale', + 'architecture.fc.weight_fake_quant.zero_point', + 'activation_post_process_15.zero_point', + 'activation_post_process_15.scale', + 'activation_post_process_14.zero_point', + 'activation_post_process_14.scale', + 'activation_post_process_12.zero_point', + 'activation_post_process_12.scale', + 'activation_post_process_10.zero_point', + 'activation_post_process_10.scale', + 'activation_post_process_8.zero_point', + 'activation_post_process_8.scale', + 'activation_post_process_6.zero_point', + 'activation_post_process_6.scale', + 'activation_post_process_4.zero_point', + 'activation_post_process_4.scale', + 'activation_post_process_1.zero_point', + 'activation_post_process_1.scale', + 'architecture.block2.conv2.weight_fake_quant.zero_point', + 'architecture.block2.conv1.weight_fake_quant.zero_point', + 'architecture.block.conv2.weight_fake_quant.zero_point', + 'architecture.block.conv2.weight_fake_quant.scale', + 
'architecture.block.conv2.weight', + 'architecture.block.conv1.weight_fake_quant.zero_point', + '/activation_post_process_0/Constant_output_0', + '/activation_post_process_0/Constant_1_output_0', + '/stem_layer.0/weight_fake_quant/Constant_output_0', + '/stem_layer.0/weight_fake_quant/Constant_1_output_0', + '/relu/Constant_output_0', '/relu/Constant_1_output_0', + '/relu_dup1/Constant_output_0', '/relu_dup1/Constant_1_output_0', + '/relu_1/Constant_output_0', '/relu_1/Constant_1_output_0', + '/relu_dup1_1/Constant_output_0', + '/relu_dup1_1/Constant_1_output_0' + ] + self.assertEqual(set(params.keys()), set(params_keys)) + + def test_map_name_and_initializer(self): + initializers = self.optimizer.map_name_and_initializer(self.onnx_model) + for init in self.onnx_model.graph.initializer: + self.assertIn(init.name, initializers.keys()) + # self.assertEqual(set(initializers.keys()), set(initializers_keys)) + + def test_map_output_and_node(self): + _ = self.optimizer.map_output_and_node(self.onnx_model) + + def test_map_input_and_node(self): + _ = self.optimizer.map_input_and_node(self.onnx_model) + + def test_remove_node_from_onnx(self): + onnx_model = copy.deepcopy(self.onnx_model) + node_to_remove = next(iter(onnx_model.graph.node)) + self.optimizer.remove_node_from_onnx(node_to_remove, onnx_model) + for node in onnx_model.graph.node: + self.assertNotEqual(node, node_to_remove) + + def test_remove_initializer_from_onnx(self): + onnx_model = copy.deepcopy(self.onnx_model) + initializer_to_remove = next(iter(onnx_model.graph.initializer)) + self.optimizer.remove_initializer_from_onnx(initializer_to_remove, + onnx_model) + for initializer in onnx_model.graph.initializer: + self.assertNotEqual(initializer, initializer_to_remove) + + def test_find_standalone_nodes(self): + standalone_nodes = self.optimizer.find_standalone_nodes( + self.onnx_model) + self.assertEqual(standalone_nodes, []) + + def test_find_redundant_initializers(self): + redundant_initializers = self.optimizer.find_redundant_initializers( + self.onnx_model) + self.assertEqual(redundant_initializers, []) + + def test_topo_sort(self): + onnx_model = copy.deepcopy(self.onnx_model) + onnx_model_topo_sort = self.optimizer.topo_sort(onnx_model) + self.assertEqual( + len(onnx_model_topo_sort.graph.node), + len(self.onnx_model.graph.node)) + + def test_optimize(self): + onnx_model = copy.deepcopy(self.onnx_model) + fake_node = helper.make_node('fake_node', [], [], mode='constant') + self.optimizer.insert_node_to_onnx(fake_node, onnx_model) + self.optimizer.optimize(onnx_model) + for node in onnx_model.graph.node: + self.assertNotEqual(node, fake_node) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') +class TestOpenVinoQuantizeExportor(TestCase): + + def setUp(self): + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() + filename = 'toy_model_symbolic.onnx' + filename = os.path.join(self.temp_dir, filename) + toy_model = MODELS.build(OpenVINO_ALG_CONFIG) + observed_model = toy_model.get_deploy_model() + torch.onnx.export( + observed_model, + torch.rand(2, 3, 16, 16), + filename, + opset_version=11) + self.onnx_model = onnx.load(filename) + self.export_path = os.path.join(self.temp_dir, 'toy_model.onnx') + + def tearDown(self): + MODELS.module_dict.pop('ToyQuantModel') + shutil.rmtree(self.temp_dir) + + def test_export(self): + exporter = OpenVinoQuantizeExportor(self.onnx_model, self.export_path) + 
exporter.export() + self.assertTrue(os.path.exists(self.export_path)) + onnx_model = onnx.load(self.export_path) + self.assertIsInstance(onnx_model, onnx.ModelProto) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') +class TestTensorRTExplicitExporter(TestCase): + + def setUp(self): + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() + filename = 'toy_model_symbolic.onnx' + filename = os.path.join(self.temp_dir, filename) + toy_model = MODELS.build(TensorRT_ALG_CONFIG) + observed_model = toy_model.get_deploy_model() + torch.onnx.export( + observed_model, + torch.rand(2, 3, 16, 16), + filename, + opset_version=11) + self.onnx_model = onnx.load(filename) + self.export_path = os.path.join(self.temp_dir, 'toy_model.onnx') + + def tearDown(self): + MODELS.module_dict.pop('ToyQuantModel') + shutil.rmtree(self.temp_dir) + + def test_export(self): + exporter = TensorRTExplicitExporter(self.onnx_model, self.export_path) + exporter.export() + self.assertTrue(os.path.exists(self.export_path)) + onnx_model = onnx.load(self.export_path) + self.assertIsInstance(onnx_model, onnx.ModelProto) diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py new file mode 100644 index 000000000..8f982c139 --- /dev/null +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +import torch.nn as nn + +from mmrazor import digit_version +from mmrazor.models.quantizers import TorchNativeQuantizer +from mmrazor.models.quantizers.native_quantizer import SUPPORT_QAT_MODULES +from mmrazor.models.task_modules.tracer import CustomTracer +from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ + build_graphmodule +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler + +try: + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.graph_module import ObservedGraphModule + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.fx import GraphModule +except ImportError: + from mmrazor.utils import get_placeholder + GraphModule = get_placeholder('torch>=1.13') + ObservedGraphModule = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyQuantModel(nn.Module): + + def __init__(self): + super().__init__() + 
        self.stem_layer = nn.Sequential(
+            nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU())
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.block = BasicBlock(3, 3)
+        self.block2 = BasicBlock(3, 3)
+        self.gap = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(3, 4)
+
+    def forward(self, x):
+        x = self.stem_layer(x)
+        x = self.maxpool(x)
+        x = self.block(x)
+        x = self.block2(x)
+        x = self.gap(x)
+        x = x.flatten(1)
+        x = self.fc(x)
+        return x
+
+
+global_qconfig = dict(
+    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
+    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    w_qscheme=dict(
+        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
+    a_qscheme=dict(
+        qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1))
+
+no_observer_modules = [
+    'torch.nn.Conv2d',
+]
+
+q_kwargs = dict(
+    type='mmrazor.TorchNativeQuantizer',
+    global_qconfig=global_qconfig,
+    no_observer_modules=no_observer_modules,
+    tracer=dict(type='CustomTracer'),
+)
+
+
+class TestTorchNativeQuantizer(TestCase):
+    """Test TorchNativeQuantizer with the custom fx tracer and the native
+    backend config."""
+
+    def setUp(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        self.q_kwargs = q_kwargs
+        self.tracer = CustomTracer()
+        self.backend_config = BackendConfigs['native']
+        self.qconfig = QConfigHandler(global_qconfig)
+        self.qconfig_mapping = QConfigMapping().set_global(
+            self.qconfig.convert())
+        self.example_inputs = (torch.randn(1, 3, 224, 224), )
+        self.native_quantizer = MODELS.build(self.q_kwargs)
+
+    def tearDown(self):
+        pass
+
+    def swap_ff_with_fxff(self, model):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        modules_to_swap = []
+        for name, module in model.named_children():
+            if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
+                modules_to_swap.append(name)
+            else:
+                self.swap_ff_with_fxff(module)
+
+        for name in modules_to_swap:
+            del model._modules[name]
+            model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
+
+    def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        native_quantizer = MODELS.build(self.q_kwargs)
+        self.assertIsInstance(native_quantizer, TorchNativeQuantizer)
+
+    def test_prepare(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        toy_model = ToyQuantModel()
+        toy_model.eval()
+
+        self.swap_ff_with_fxff(toy_model)
+        traced_graph = self.tracer.trace(toy_model)
+        graph_module = build_graphmodule(toy_model, traced_graph)
+
+        graph_module = _fuse_fx(
+            graph_module=graph_module,
+            is_qat=True,
+            backend_config=self.backend_config)
+        assert isinstance(graph_module, GraphModule)
+        prepared = prepare(
+            model=graph_module,
+            qconfig_mapping=self.qconfig_mapping,
+            is_qat=True,
+            node_name_to_scope=self.tracer.node_name_to_scope,
+            example_inputs=self.example_inputs,
+            backend_config=self.backend_config)
+        assert isinstance(prepared, ObservedGraphModule)
+
+        prepared = self.native_quantizer.del_redundant_fakequant(prepared)
+        assert isinstance(prepared, GraphModule)
+
+    def post_process_for_deploy(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        toy_model = ToyQuantModel()
+        toy_model.eval()
+
+
self.swap_ff_with_fxff(toy_model) + traced_graph = self.tracer.trace(toy_model) + graph_module = build_graphmodule(toy_model, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + assert isinstance(graph_module, GraphModule) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + assert isinstance(prepared, ObservedGraphModule) + + prepared = self.native_quantizer.del_redundant_fakequant(prepared) + assert isinstance(prepared, GraphModule) + + prepared_no_fq = prepared + + self.native_quantizer.post_process_weight_fakequant(prepared) + for name, child in prepared.named_children(): + if isinstance(child, SUPPORT_QAT_MODULES): + raise ValueError + self.native_quantizer.post_process_weight_fakequant( + prepared_no_fq, True) + for name, child in prepared_no_fq.named_children(): + if isinstance(child, SUPPORT_QAT_MODULES): + raise ValueError diff --git a/tests/test_models/test_quantizers/test_openvino_quantizer.py b/tests/test_models/test_quantizers/test_openvino_quantizer.py new file mode 100644 index 000000000..7b60dc4a3 --- /dev/null +++ b/tests/test_models/test_quantizers/test_openvino_quantizer.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import shutil +import tempfile +from copy import copy +from unittest import TestCase + +import torch + +try: + from torch.ao.quantization.fx.graph_module import ObservedGraphModule +except ImportError: + from mmrazor.utils import get_placeholder + ObservedGraphModule = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import OpenVINOQuantizer +from mmrazor.testing import ConvBNReLU + + +class TestOpenVINOQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.temp_dir = tempfile.mkdtemp() + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def tearDown(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + shutil.rmtree(self.temp_dir) + + def test_property(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = OpenVINOQuantizer(global_qconfig=global_qconfig) + assert quantizer.backend == 'openvino' + assert quantizer.support_w_modes == ('per_tensor', 'per_channel') + assert quantizer.support_a_modes == ('per_tensor') + assert quantizer.module_prev_wo_fakequant + assert quantizer.module_next_wo_fakequant + assert quantizer.method_next_wo_fakequant + assert quantizer.op_prev_wo_fakequant diff --git a/tests/test_models/test_quantizers/test_tensorrt_quantizer.py b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py new file mode 100644 index 000000000..f5433a0f9 --- /dev/null +++ b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py @@ -0,0 +1,51 @@ +# Copyright (c) 
OpenMMLab. All rights reserved. +import shutil +import tempfile +from copy import copy +from unittest import TestCase + +import torch + +try: + from torch.ao.quantization.fx.graph_module import ObservedGraphModule +except ImportError: + from mmrazor.utils import get_placeholder + ObservedGraphModule = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import TensorRTQuantizer +from mmrazor.testing import ConvBNReLU + + +class TestTensorRTQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.temp_dir = tempfile.mkdtemp() + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def tearDown(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + shutil.rmtree(self.temp_dir) + + def test_property(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = TensorRTQuantizer(global_qconfig=global_qconfig) + assert quantizer.backend == 'tensorrt' + assert quantizer.support_w_modes == ('per_tensor', 'per_channel') + assert quantizer.support_a_modes == ('per_tensor') diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py new file mode 100644 index 000000000..2d01ea496 --- /dev/null +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase
+
+import pytest
+import torch
+from mmcls.models.backbones.resnet import ResLayer
+from mmengine.config import Config
+from mmengine.registry import MODELS
+
+try:
+    from torch.fx import GraphModule
+    from torch.fx._symbolic_trace import Graph
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    GraphModule = get_placeholder('torch>=1.13')
+    Graph = get_placeholder('torch>=1.13')
+
+from mmrazor import digit_version
+from mmrazor.models.task_modules.tracer import (CustomTracer,
+                                                UntracedMethodRegistry,
+                                                build_graphmodule,
+                                                custom_symbolic_trace)
+from mmrazor.models.task_modules.tracer.fx.custom_tracer import \
+    _prepare_module_dict
+
+
+class ToyModel(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def get_loss(self, x):
+        return x * 0.1
+
+    def extract_feature(self, x):
+        return x * 2
+
+    def forward(self, x):
+        x = self.extract_feature(x)
+        x = self.get_loss(x)
+        return x
+
+
+class TestUntracedMethodRegistry(TestCase):
+
+    def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        method = ToyModel.get_loss
+        method_registry = UntracedMethodRegistry(method)
+        assert hasattr(method_registry, 'method')
+        assert hasattr(method_registry, 'method_dict')
+
+    def test_registry_method(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        model = ToyModel
+        method = ToyModel.get_loss
+        method_registry = UntracedMethodRegistry(method)
+        method_registry.__set_name__(model, 'get_loss')
+        assert 'get_loss' in method_registry.method_dict.keys()
+        assert method_registry.method_dict['get_loss']['mod'] == model
+
+
+class TestCustomTracer(TestCase):
+
+    def setUp(self):
+        self.cfg = Config.fromfile(
+            'tests/data/test_models/test_task_modules/mmcls_cfg.py')
+        self.skipped_methods = [
+            'mmcls.models.heads.ClsHead._get_loss',
+            'mmcls.models.heads.ClsHead._get_predictions'
+        ]
+        self.skipped_module_names = ['backbone.layer4.0']
+        self.skipped_module_classes = [ResLayer]
+
+    def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        # init without skipped_methods
+        tracer = CustomTracer()
+        assert hasattr(tracer, 'skipped_methods')
+        assert len(tracer.skipped_methods) == 0
+        # init with skipped_methods(list)
+        UntracedMethodRegistry.method_dict = dict()
+        tracer = CustomTracer(skipped_methods=self.skipped_methods)
+        assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()
+        assert '_get_predictions' in UntracedMethodRegistry.method_dict.keys()
+        # init with skipped_methods(str)
+        UntracedMethodRegistry.method_dict = dict()
+        tracer = CustomTracer(skipped_methods=self.skipped_methods[0])
+        assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()
+        # init with skipped_methods(int, error)
+        with self.assertRaises(TypeError):
+            CustomTracer(skipped_methods=123)
+        # init with skipped_methods(str, error)
+        with self.assertRaises(AssertionError):
+            CustomTracer(skipped_methods='_get_loss')
+
+    def test_trace(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        # test trace with skipped_methods
+        model = MODELS.build(self.cfg.model)
+        UntracedMethodRegistry.method_dict = dict()
+        tracer = CustomTracer(skipped_methods=self.skipped_methods)
+        graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'})
+        graph_loss = tracer.trace(model,
concrete_args={'mode': 'loss'}) + graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'}) + assert isinstance(graph_tensor, Graph) + assert isinstance(graph_loss, Graph) + skip_flag_loss = False + for node in graph_loss.nodes: + if node.op == 'call_method' and node.target == '_get_loss': + skip_flag_loss = True + assert isinstance(graph_predict, Graph) + skip_flag_predict = False + for node in graph_predict.nodes: + if node.op == 'call_method' and node.target == '_get_predictions': + skip_flag_predict = True + assert skip_flag_loss and skip_flag_predict + + # test trace with skipped_module_names + model = MODELS.build(self.cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_module_names=self.skipped_module_names) + graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'}) + skip_flag = False + for node in graph_tensor.nodes: + skipped_module_name = self.skipped_module_names[0] + if node.op == 'call_module' and node.target == skipped_module_name: + skip_flag = True + assert skip_flag + + # test trace with skipped_module_classes + model = MODELS.build(self.cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer( + skipped_module_classes=self.skipped_module_classes) + graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'}) + skip_flag = False + for node in graph_tensor.nodes: + if node.op == 'call_module' and node.target == 'backbone.layer1': + skip_flag = True + assert skip_flag + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_custom_symbolic_trace(): + cfg = Config.fromfile( + 'tests/data/test_models/test_task_modules/mmcls_cfg.py') + model = MODELS.build(cfg.model) + UntracedMethodRegistry.method_dict = dict() + graph_module = custom_symbolic_trace( + model, concrete_args={'mode': 'tensor'}) + assert isinstance(graph_module, GraphModule) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_build_graphmodule(): + skipped_methods = ['mmcls.models.heads.ClsHead._get_predictions'] + cfg = Config.fromfile( + 'tests/data/test_models/test_task_modules/mmcls_cfg.py') + model = MODELS.build(cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_methods=skipped_methods) + graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'}) + graph_module = build_graphmodule(model, graph_predict) + assert isinstance(graph_module, GraphModule) + + # test _prepare_module_dict + modules = dict(model.named_modules()) + module_dict = _prepare_module_dict(model, graph_predict) + for k, v in module_dict.items(): + assert isinstance(v, torch.nn.Module) + assert not isinstance(v, modules[k].__class__) diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py new file mode 100644 index 000000000..ea7f90565 --- /dev/null +++ b/tests/test_models/test_task_modules/test_graph_utils.py @@ -0,0 +1,536 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import operator +from unittest import TestCase + +import torch +import torch.nn as nn + +try: + from torch.ao.quantization import QConfigMapping + from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + QConfigMapping = get_placeholder('torch>=1.13') + FakeQuantizeBase = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.task_modules.tracer import CustomTracer, build_graphmodule +from mmrazor.models.task_modules.tracer.fx import ( + del_fakequant_after_function, del_fakequant_after_method, + del_fakequant_after_module, del_fakequant_after_op, + del_fakequant_before_function, del_fakequant_before_method, + del_fakequant_before_module, del_fakequant_before_op) +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler + + +def _get_attrs(target, attrs): + attrs = attrs.split('.') + + for att in attrs: + target = getattr(target, att, None) + return target + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + + +class TestGraphUtils(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.tracer = CustomTracer() + self.backend_config = BackendConfigs['native'] + self.qconfig = QConfigHandler(global_qconfig) + self.qconfig_mapping = QConfigMapping().set_global( + self.qconfig.convert()) + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + def swap_ff_with_fxff(self, model): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of 
torch < 1.13.0') + + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self.swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + def test_del_fakequant_before_op(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + op_del_prev_fakequant = ('output', ) + + prepared_after_del = del_fakequant_before_op( + prepared, op_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op in op_del_prev_fakequant: + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op in op_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_op( + prepared, op_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op in op_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + def test_del_fakequant_after_op(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + op_del_next_fakequant = ('placeholder', ) + + prepared_after_del = del_fakequant_after_op( + prepared, op_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op in op_del_next_fakequant: + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op in op_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_op( + prepared, op_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op in op_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + def test_del_fakequant_before_method(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + 
model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + method_del_prev_fakequant = ('flatten', ) + + prepared_after_del = del_fakequant_before_method( + prepared, method_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_prev_fakequant: + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_method( + prepared, method_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + def test_del_fakequant_after_method(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + method_del_next_fakequant = ('flatten', ) + + prepared_after_del = del_fakequant_after_method( + prepared, method_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_next_fakequant: + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_method( + prepared, method_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + def test_del_fakequant_before_function(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, 
+ backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + function_del_prev_fakequant = (operator.add, ) + + prepared_after_del = del_fakequant_before_function( + prepared, function_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_prev_fakequant: + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_prev_fakequant: + args = node.args + self.assertEqual(len(args), 2) + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + self.assertNotIsInstance( + _get_attrs(prepared, args[1].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_function( + prepared, function_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_prev_fakequant: + args = node.args + self.assertEqual(len(args), 2) + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + self.assertNotIsInstance( + _get_attrs(prepared, args[1].target), FakeQuantizeBase) + + def test_del_fakequant_after_function(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + function_del_next_fakequant = (operator.add, ) + + prepared_after_del = del_fakequant_after_function( + prepared, function_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_next_fakequant: + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_function( + prepared, function_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + def test_del_fakequant_before_module(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + 
graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + module_del_prev_fakequant = (torch.nn.ReLU6, torch.nn.Identity) + + prepared_after_del = del_fakequant_before_module( + prepared, module_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_prev_fakequant): + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_prev_fakequant): + args = node.args + if args[0].op == 'call_module': + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_module( + prepared, module_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_prev_fakequant): + args = node.args + if args[0].op == 'call_module': + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + def test_del_fakequant_after_module(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + module_del_next_fakequant = (torch.nn.MaxPool2d, ) + + prepared_after_del = del_fakequant_after_module( + prepared, module_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_next_fakequant): + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_next_fakequant): + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_module( + prepared, module_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_next_fakequant): + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 6652cb943..c8340f352 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -83,6 +83,24 @@ def test_build_razor_from_cfg(self): model = MODELS.build(cfg.model) self.assertTrue(isinstance(model, BaseModel)) + def test_build_subnet_prune_from_cfg(self): + 
mutator_cfg = fileio.load('tests/data/test_registry/subnet.json') + init_cfg = dict( + type='Pretrained', + checkpoint='tests/data/test_registry/subnet_weight.pth') + # test fix subnet + model_cfg = dict( + # use mmrazor's build_func + type='mmrazor.sub_model', + cfg=dict( + cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', + pretrained=False), + fix_subnet=mutator_cfg, + mode='mutator', + init_cfg=init_cfg) + model = MODELS.build(model_cfg) + self.assertTrue(isinstance(model, BaseModel)) + def test_build_subnet_prune_from_cfg_by_mutator(self): mutator_cfg = fileio.load('tests/data/test_registry/subnet.json') init_cfg = dict( diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py new file mode 100644 index 000000000..6a300fb91 --- /dev/null +++ b/tests/test_runners/test_quantization_loop.py @@ -0,0 +1,413 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging +import shutil +import tempfile +from unittest import TestCase + +import torch +import torch.nn as nn +from mmengine.config import Config, ConfigDict +from mmengine.evaluator import BaseMetric +from mmengine.hooks import Hook +from mmengine.logging import MMLogger +from mmengine.model import BaseModel +from mmengine.optim import OptimWrapper +from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS, OPTIM_WRAPPERS +from mmengine.runner import Runner +from torch.nn.intrinsic.qat import ConvBnReLU2d +from torch.utils.data import Dataset + +from mmrazor import digit_version +from mmrazor.engine import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, + QATValLoop) + +try: + from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional + from torch.ao.quantization import QConfigMapping + from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.qconfig_mapping import \ + get_default_qconfig_mapping + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + QConfigMapping = get_placeholder('torch>=1.13') + FakeQuantizeBase = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + get_default_qconfig_mapping = get_placeholder('torch>=1.13') + FloatFunctional = get_placeholder('torch>=1.13') + FXFloatFunctional = get_placeholder('torch>=1.13') + + +class ToyDataset(Dataset): + METAINFO = dict() # type: ignore + data = torch.randn(12, 3, 4, 4) + label = torch.ones(12) + + @property + def metainfo(self): + return self.METAINFO + + def __len__(self): + return self.data.size(0) + + def __getitem__(self, index): + return dict(inputs=self.data[index], data_sample=self.label[index]) + + +class MMArchitectureQuant(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + self.architecture = ToyModel() + + def calibrate_step(self, data): + data = self.data_preprocessor(data, False) + return self.architecture(**data) + + def sync_qparams(self, src_mode): + pass + + def forward(self, inputs, data_sample, mode='tensor'): + return self.architecture(inputs, data_sample, mode) + + +class ToyModel(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + qconfig = get_default_qconfig_mapping().to_dict()[''] + self.architecture = nn.Sequential( + ConvBnReLU2d(3, 3, 1, qconfig=qconfig)) + + def forward(self, inputs, data_sample, mode='tensor'): + if 
isinstance(inputs, list): + inputs = torch.stack(inputs) + if isinstance(data_sample, list): + data_sample = torch.stack(data_sample) + outputs = self.architecture(inputs) + + if mode == 'tensor': + return outputs + elif mode == 'loss': + loss = data_sample.sum() - outputs.sum() + outputs = dict(loss=loss) + return outputs + elif mode == 'predict': + return outputs + + +class ToyOptimWrapper(OptimWrapper): + ... + + +class ToyMetric1(BaseMetric): + + def __init__(self, collect_device='cpu', dummy_metrics=None): + super().__init__(collect_device=collect_device) + self.dummy_metrics = dummy_metrics + + def process(self, data_batch, predictions): + result = {'acc': 1} + self.results.append(result) + + def compute_metrics(self, results): + return dict(acc=1) + + +DEFAULT_CFG = ConfigDict( + model=dict(type='MMArchitectureQuant'), + train_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + test_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + optim_wrapper=dict( + type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), + val_evaluator=dict(type='ToyMetric1'), + test_evaluator=dict(type='ToyMetric1'), + train_cfg=dict(), + val_cfg=dict(), + test_cfg=dict(), + custom_hooks=[], + data_preprocessor=None, + launcher='none', + env_cfg=dict(dist_cfg=dict(backend='nccl')), +) + + +class TestQATEpochBasedLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.train_cfg = ConfigDict( + type='mmrazor.QATEpochBasedLoop', + max_epochs=4, + val_begin=1, + val_interval=1, + disable_observer_begin=-1, + freeze_bn_begin=-1, + dynamic_intervals=None) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_qat_train_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.train_loop, QATEpochBasedLoop) + + def test_run_epoch(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_train' + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestFreezeBNHook(Hook): + + def __init__(self, freeze_bn_begin): + self.freeze_bn_begin = freeze_bn_begin + + def after_train_epoch(self, runner): + + def check_bn_stats(mod): + if isinstance(mod, ConvBnReLU2d): + assert mod.freeze_bn + assert not mod.bn.training + + if runner.train_loop._epoch + 1 >= self.freeze_bn_begin: + runner.model.apply(check_bn_stats) + + cfg = 
copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_freeze_bn' + cfg.custom_hooks = [ + dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1) + ] + cfg.train_cfg.freeze_bn_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestDisableObserverHook(Hook): + + def __init__(self, disable_observer_begin): + self.disable_observer_begin = disable_observer_begin + + def after_train_epoch(self, runner): + + def check_observer_stats(mod): + if isinstance(mod, FakeQuantizeBase): + assert mod.fake_quant_enabled[0] == 0 + + if runner.train_loop._epoch + 1 >= self.disable_observer_begin: + runner.model.apply(check_observer_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_disable_observer' + cfg.custom_hooks = [ + dict( + type='TestDisableObserverHook', + priority=50, + disable_observer_begin=1) + ] + cfg.train_cfg.disable_observer_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + +class TestLSQEpochBasedLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.train_cfg = ConfigDict( + type='mmrazor.LSQEpochBasedLoop', + max_epochs=4, + val_begin=1, + val_interval=1, + freeze_bn_begin=-1, + dynamic_intervals=None) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_lsq_train_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.train_loop, LSQEpochBasedLoop) + + def test_run_epoch(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_train' + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestFreezeBNHook(Hook): + + def __init__(self, freeze_bn_begin): + self.freeze_bn_begin = freeze_bn_begin + + def after_train_epoch(self, runner): + + def check_bn_stats(mod): + if isinstance(mod, ConvBnReLU2d): + assert mod.freeze_bn + assert not mod.bn.training + + if runner.train_loop._epoch + 1 >= self.freeze_bn_begin: + runner.model.apply(check_bn_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_freeze_bn' + cfg.custom_hooks = [ + dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1) + ] + cfg.train_cfg.freeze_bn_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + +class TestQATValLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + 
OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.val_cfg = ConfigDict(type='mmrazor.QATValLoop') + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_qat_val_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.val_loop, QATValLoop) + + def test_run(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_qat_val' + cfg.pop('train_dataloader') + cfg.pop('train_cfg') + cfg.pop('optim_wrapper') + cfg.pop('test_dataloader') + cfg.pop('test_cfg') + cfg.pop('test_evaluator') + runner = Runner.from_cfg(cfg) + runner.val() + + +class TestPTQLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + # save_checkpoint in PTQLoop need train_dataloader + default_cfg.train_cfg = ConfigDict(by_epoch=True, max_epochs=3) + default_cfg.test_cfg = ConfigDict( + type='mmrazor.PTQLoop', + calibrate_dataloader=default_cfg.train_dataloader, + calibrate_steps=32) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_ptq_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.test_loop, PTQLoop) + + def test_run(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_ptq_run' + runner = Runner.from_cfg(cfg) + runner.test() diff --git a/tests/test_structures/test_backendconfig.py b/tests/test_structures/test_backendconfig.py new file mode 100644 index 000000000..24295e391 --- /dev/null +++ b/tests/test_structures/test_backendconfig.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+try: + from torch.ao.quantization.backend_config import BackendConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + +import pytest +import torch + +from mmrazor import digit_version +from mmrazor.structures.quantization.backend_config import ( + BackendConfigs, get_academic_backend_config, + get_academic_backend_config_dict, get_native_backend_config, + get_native_backend_config_dict, get_openvino_backend_config, + get_openvino_backend_config_dict, get_tensorrt_backend_config, + get_tensorrt_backend_config_dict) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_get_backend_config(): + + # test get_native_backend_config + native_backend_config = get_native_backend_config() + assert isinstance(native_backend_config, BackendConfig) + assert native_backend_config.name == 'native' + native_backend_config_dict = get_native_backend_config_dict() + assert isinstance(native_backend_config_dict, dict) + + # test get_academic_backend_config + academic_backend_config = get_academic_backend_config() + assert isinstance(academic_backend_config, BackendConfig) + assert academic_backend_config.name == 'academic' + academic_backend_config_dict = get_academic_backend_config_dict() + assert isinstance(academic_backend_config_dict, dict) + + # test get_openvino_backend_config + openvino_backend_config = get_openvino_backend_config() + assert isinstance(openvino_backend_config, BackendConfig) + assert openvino_backend_config.name == 'openvino' + openvino_backend_config_dict = get_openvino_backend_config_dict() + assert isinstance(openvino_backend_config_dict, dict) + + # test get_tensorrt_backend_config + tensorrt_backend_config = get_tensorrt_backend_config() + assert isinstance(tensorrt_backend_config, BackendConfig) + assert tensorrt_backend_config.name == 'tensorrt' + tensorrt_backend_config_dict = get_tensorrt_backend_config_dict() + assert isinstance(tensorrt_backend_config_dict, dict) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_backendconfigs_mapping(): + + mapping = BackendConfigs + assert isinstance(mapping, dict) + assert 'academic' in mapping.keys() + assert isinstance(mapping['academic'], BackendConfig) diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py new file mode 100644 index 000000000..7ab78243d --- /dev/null +++ b/tests/test_structures/test_qconfig.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from unittest import TestCase + +import torch +from mmengine.config import Config + +try: + from torch.ao.quantization import FakeQuantize, QConfig +except ImportError: + from mmrazor.utils import get_placeholder + QConfig = get_placeholder('torch>=1.13') + FakeQuantize = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.fake_quants import register_torch_fake_quants +from mmrazor.models.observers import register_torch_observers +from mmrazor.registry import MODELS +from mmrazor.structures import QConfigHandler, QSchemeHandler + +register_torch_observers() +register_torch_fake_quants() + + +class TestQSchemeHandler(TestCase): + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # per_channel + qscheme = QSchemeHandler(is_symmetry=True, is_per_channel=True) + assert qscheme.torch_qscheme is torch.per_channel_symmetric + + # per_tensor + qscheme = QSchemeHandler(is_symmetry=True, is_per_channel=False) + assert qscheme.torch_qscheme is torch.per_tensor_symmetric + + # qdtype is incorrect + self.assertRaises(AssertionError, QSchemeHandler, 'float') + + # is_symmetric_range + kwargs = {'is_symmetric_range': True} + qscheme = QSchemeHandler(**kwargs) + assert qscheme.is_symmetric_range is True + + def test_to_observer_params(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # qdtype = quint8 + ret_params = QSchemeHandler(qdtype='quint8').to_observer_params() + assert ret_params['dtype'] == torch.quint8 + assert ret_params['quant_min'] == 0 and ret_params['quant_max'] == 255 + + # qdtype = qint8, is_symmetric_range=False + ret_params = QSchemeHandler(qdtype='qint8').to_observer_params() + assert ret_params['dtype'] == torch.qint8 + assert ret_params['quant_min'] == -128 and ret_params[ + 'quant_max'] == 127 + + # qdtype = qint8, is_symmetric_range=True + ret_params = QSchemeHandler( + qdtype='qint8', is_symmetric_range=True).to_observer_params() + assert ret_params['quant_min'] == -127 and ret_params[ + 'quant_max'] == 127 + + # per_channel + ret_params = QSchemeHandler(is_per_channel=True).to_observer_params() + assert ret_params['ch_axis'] == 0 + + # per_tensor + ret_params = QSchemeHandler(is_per_channel=False).to_observer_params() + assert 'ch_axis' not in ret_params.keys() + + +class TestQConfigHandler(TestCase): + + def setUp(self): + self.qconfig_dict = dict( + w_observer=dict(type='MovingAveragePerChannelMinMaxObserver'), + a_observer=dict(type='MovingAveragePerChannelMinMaxObserver'), + w_fake_quant=dict(type='FakeQuantize'), + a_fake_quant=dict(type='FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', + bit=8, + is_symmetry=True, + is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.qconfig = Config(self.qconfig_dict) + + def test_check_qconfig(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + assert QConfigHandler.check_qconfig(self.qconfig_dict) is True + assert QConfigHandler.check_qconfig(self.qconfig) is True + qconfig_dict = copy.copy(self.qconfig_dict) + print(qconfig_dict) + qconfig_dict.pop('w_observer') + assert QConfigHandler.check_qconfig(qconfig_dict) is False + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # test dict init + qconfig = 
QConfigHandler(self.qconfig_dict) + assert hasattr(qconfig, 'w_qscheme') + assert hasattr(qconfig, 'a_qscheme') + assert hasattr(qconfig, 'w_fake_quant') + assert hasattr(qconfig, 'a_fake_quant') + + # test mmengine's Config init + qconfig = QConfigHandler(self.qconfig) + assert hasattr(qconfig, 'w_qscheme') + assert hasattr(qconfig, 'a_qscheme') + assert hasattr(qconfig, 'w_fake_quant') + assert hasattr(qconfig, 'a_fake_quant') + + # per_channel + assert qconfig.w_qscheme.is_per_channel is True + assert qconfig.a_qscheme.is_per_channel is True + + def test_convert(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + qconfig = QConfigHandler(self.qconfig) + torch_qconfig = qconfig.convert() + assert isinstance(torch_qconfig, QConfig) + + def test_replace_fakequant(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # update_qparams is False + qconfig = QConfigHandler(self.qconfig) + org_fakequant_ins = qconfig.w_fake_quant() + new_fakequant = qconfig.replace_fakequant( + org_fakequant_ins, qconfig.w_qscheme, update_qparams=False) + new_fakequant_ins = new_fakequant() + assert isinstance(new_fakequant_ins, FakeQuantize) + assert isinstance(new_fakequant_ins.activation_post_process, + MODELS.get('PerChannelMinMaxObserver')) + + # update_qparams is True + qconfig = QConfigHandler(self.qconfig) + org_fakequant_ins = qconfig.w_fake_quant() + org_fakequant_ins.scale = torch.Tensor([2]) + org_fakequant_ins.activation_post_process.min_val = torch.Tensor([1]) + new_fakequant_ins = qconfig.replace_fakequant( + org_fakequant_ins, qconfig.w_qscheme, update_qparams=True) + assert isinstance(new_fakequant_ins, FakeQuantize) + assert isinstance(new_fakequant_ins.activation_post_process, + MODELS.get('PerChannelMinMaxObserver')) + assert new_fakequant_ins.scale == org_fakequant_ins.scale + assert new_fakequant_ins.activation_post_process.min_val == \ + org_fakequant_ins.activation_post_process.min_val + + def test_fixed_w_fakequant(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + qconfig = QConfigHandler(self.qconfig) + qconfig.fixed_w_fakequant() + new_fakequant_ins = qconfig.w_fake_quant() + assert isinstance(new_fakequant_ins, FakeQuantize) + assert isinstance(new_fakequant_ins.activation_post_process, + MODELS.get('PerChannelMinMaxObserver')) diff --git a/tools/model_converters/convert_quant_ckpt.py b/tools/model_converters/convert_quant_ckpt.py new file mode 100644 index 000000000..9fbb06125 --- /dev/null +++ b/tools/model_converters/convert_quant_ckpt.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +from pathlib import Path + +import torch + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert quantized checkpoint to deploy') + parser.add_argument('checkpoint', help='input checkpoint filename') + parser.add_argument('--out-path', help='save checkpoint path') + parser.add_argument( + '--inplace', action='store_true', help='replace origin ckpt') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + checkpoint = torch.load(args.checkpoint, map_location='cpu') + new_state_dict = dict() + new_meta = checkpoint['meta'] + + for key, value in checkpoint['state_dict'].items(): + if key.startswith('qmodels.predict.'): + new_key = key.replace('qmodels.predict.', '') + if '_val' in new_key and 'weight_fake_quant' in new_key: + new_key = new_key.replace('_val', '_vals') + new_state_dict[new_key] = value + # if key.startswith('architecture.'): + # new_key = key.replace('architecture.', '') + # new_state_dict[new_key] = value + + checkpoint = dict() + checkpoint['meta'] = new_meta + checkpoint['state_dict'] = new_state_dict + + if args.inplace: + torch.save(checkpoint, args.checkpoint) + else: + ckpt_path = Path(args.checkpoint) + ckpt_name = ckpt_path.stem + if args.out_path: + ckpt_dir = Path(args.out_path) + else: + ckpt_dir = ckpt_path.parent + new_ckpt_path = ckpt_dir / f'{ckpt_name}_deploy.pth' + torch.save(checkpoint, new_ckpt_path) + + +if __name__ == '__main__': + main() diff --git a/tools/ptq.py b/tools/ptq.py new file mode 100644 index 000000000..2c00c5b11 --- /dev/null +++ b/tools/ptq.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + +from mmrazor.utils import register_all_modules + + +# TODO: support fuse_conv_bn, visualization, and format_only +def parse_args(): + parser = argparse.ArgumentParser( + description='MMRazor test (and eval) a model') + parser.add_argument('config', help='test config file path') + # parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation metrics') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + register_all_modules(False) + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + # cfg.load_from = args.checkpoint + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main() diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh index 6dd67e574..3c74ec6ec 100644 --- a/tools/slurm_test.sh +++ b/tools/slurm_test.sh @@ -1,24 +1,10 @@ #!/usr/bin/env bash -set -x - -PARTITION=$1 -JOB_NAME=$2 -CONFIG=$3 -CHECKPOINT=$4 -GPUS=${GPUS:-8} -GPUS_PER_NODE=${GPUS_PER_NODE:-8} -CPUS_PER_TASK=${CPUS_PER_TASK:-5} -PY_ARGS=${@:5} -SRUN_ARGS=${SRUN_ARGS:-""} +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +PORT=${PORT:-29500} PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ -srun -p ${PARTITION} \ - --job-name=${JOB_NAME} \ - --gres=gpu:${GPUS_PER_NODE} \ - --ntasks=${GPUS} \ - --ntasks-per-node=${GPUS_PER_NODE} \ - --cpus-per-task=${CPUS_PER_TASK} \ - --kill-on-bad-exit=1 \ - ${SRUN_ARGS} \ - python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} diff --git a/tools/test.py b/tools/test.py index fb6b00b86..a69133158 100644 --- a/tools/test.py +++ b/tools/test.py @@ -66,6 +66,8 @@ def main(): cfg.load_from = None else: cfg.load_from = args.checkpoint + if 'type' in cfg.test_cfg and cfg.test_cfg.type.endswith('PTQLoop'): + cfg.test_cfg.only_val = True # build the runner from config runner = Runner.from_cfg(cfg)