From bc2a309a0036e6c4433c9858cb59eccf1ccf2747 Mon Sep 17 00:00:00 2001 From: liukai Date: Wed, 16 Nov 2022 10:37:08 +0800 Subject: [PATCH 1/9] add mmyolo demo input and fix bug --- .../demo_inputs/default_demo_inputs.py | 26 ++++++++++++------- .../task_modules/demo_inputs/demo_inputs.py | 12 +++++---- .../task_modules/tracer/prune_tracer.py | 3 ++- tools/tests/test_tools.py | 24 ++++++++++++----- 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py index 1d6cf49a4..7a9d856e5 100644 --- a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py @@ -6,7 +6,8 @@ from mmrazor.utils import get_placeholder from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput, DefaultMMDemoInput, DefaultMMDetDemoInput, - DefaultMMRotateDemoInput, DefaultMMSegDemoInput) + DefaultMMRotateDemoInput, DefaultMMSegDemoInput, + DefaultMMYoloDemoInput) try: from mmdet.models import BaseDetector @@ -35,32 +36,39 @@ 'mmdet': DefaultMMDetDemoInput, 'mmseg': DefaultMMSegDemoInput, 'mmrotate': DefaultMMRotateDemoInput, + 'mmyolo': DefaultMMYoloDemoInput, 'torchvision': BaseDemoInput, } -def defaul_demo_inputs(model, input_shape, training=False, scope=None): +def get_default_demo_input_class(model, scope): + if scope is not None: for scope_name, demo_input in default_demo_input_class_for_scope.items( ): if scope == scope_name: - return demo_input(input_shape, training).get_data(model) + return demo_input for module_type, demo_input in default_demo_input_class.items( # noqa ): # noqa if isinstance(model, module_type): - return demo_input(input_shape, training).get_data(model) + return demo_input # default - return BaseDemoInput(input_shape, training).get_data(model) + return BaseDemoInput + + +def defaul_demo_inputs(model, input_shape, training=False, scope=None): + demo_input = get_default_demo_input_class(model, scope) + return demo_input().get_data(model, input_shape, training) @TASK_UTILS.register_module() class DefaultDemoInput(BaseDemoInput): - def __init__(self, - input_shape=[1, 3, 224, 224], - training=False, - scope=None) -> None: + def __init__(self, input_shape=None, training=False, scope=None) -> None: + default_demo_input_class = get_default_demo_input_class(None, scope) + if input_shape is None: + input_shape = default_demo_input_class.default_shape super().__init__(input_shape, training) self.scope = scope diff --git a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py index fee5343bf..f0f7f685f 100644 --- a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py @@ -6,8 +6,9 @@ @TASK_UTILS.register_module() class BaseDemoInput(): + default_shape = (1, 3, 224, 224) - def __init__(self, input_shape=[1, 3, 224, 224], training=False) -> None: + def __init__(self, input_shape=default_shape, training=None) -> None: self.input_shape = input_shape self.training = training @@ -33,10 +34,6 @@ def __call__(self, class DefaultMMDemoInput(BaseDemoInput): def _get_data(self, model, input_shape=None, training=None): - if input_shape is None: - input_shape = self.input_shape - if training is None: - training = self.training data = self._get_mm_data(model, input_shape, training) data['mode'] = 'tensor' @@ -94,3 +91,8 @@ def _get_mm_data(self, model, input_shape, training=False): data = 
demo_mm_inputs(1, [input_shape[1:]], use_box_type=True) data = model.data_preprocessor(data, training) return data + + +@TASK_UTILS.register_module() +class DefaultMMYoloDemoInput(DefaultMMDetDemoInput): + default_shape = (1, 3, 125, 320) diff --git a/mmrazor/models/task_modules/tracer/prune_tracer.py b/mmrazor/models/task_modules/tracer/prune_tracer.py index 7cb1abf97..f55dedf25 100644 --- a/mmrazor/models/task_modules/tracer/prune_tracer.py +++ b/mmrazor/models/task_modules/tracer/prune_tracer.py @@ -76,7 +76,8 @@ def __init__(self, self.tracer_type = tracer_type if tracer_type == 'BackwardTracer': self.tracer = BackwardTracer( - loss_calculator=SumPseudoLoss(input_shape=demo_input)) + loss_calculator=SumPseudoLoss( + input_shape=self.demo_input.input_shape)) elif tracer_type == 'FxTracer': self.tracer = CustomFxTracer(leaf_module=self.default_leaf_modules) else: diff --git a/tools/tests/test_tools.py b/tools/tests/test_tools.py index 50ca79de9..44cebb8ee 100644 --- a/tools/tests/test_tools.py +++ b/tools/tests/test_tools.py @@ -2,14 +2,26 @@ import os import shutil import subprocess -from typing import List from unittest import TestCase -config_paths: List = [ - 'mmcls::resnet/resnet34_8xb32_in1k.py', - 'mmdet::retinanet/retinanet_r18_fpn_1x_coco.py', - 'mmseg::deeplabv3plus/deeplabv3plus_r50-d8_4xb4-20k_voc12aug-512x512.py', -] +config_paths = [] + + +def add_config_path(repo_name, path): + try: + __import__(repo_name) + config_paths.append(path) + except Exception: + pass + + +add_config_path('mmcls', 'mmcls::resnet/resnet34_8xb32_in1k.py') +add_config_path('mmdet', 'mmdet::retinanet/retinanet_r18_fpn_1x_coco.py') +add_config_path( + 'mmseg', + 'mmseg::deeplabv3plus/deeplabv3plus_r50-d8_4xb4-20k_voc12aug-512x512.py') +add_config_path( + 'mmyolo', 'mmyolo::yolov5/yolov5_m-p6-v62_syncbn_fast_8xb16-300e_coco.py') class TestTools(TestCase): From d06261ceb20d5e1e592c6b9d3a58ca1bb2f6add5 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 17 Nov 2022 10:22:47 +0800 Subject: [PATCH 2/9] rebase exp-rpuning From 79660a308b76f3562ee773352ead79f96829527a Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 17 Nov 2022 12:44:23 +0800 Subject: [PATCH 3/9] fix bug --- mmrazor/models/task_modules/demo_inputs/demo_inputs.py | 1 + .../task_modules/estimators/counters/flops_params_counter.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py index f0f7f685f..a3c498b2d 100644 --- a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py @@ -55,6 +55,7 @@ def _get_mm_data(self, model, input_shape, training=False): 'data_samples': [ClsDataSample().set_gt_label(1) for _ in range(input_shape[0])], } + mm_inputs = model.data_preprocessor(mm_inputs, training) return mm_inputs diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index 26c3af5e8..b8a522441 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -85,7 +85,10 @@ def get_model_flops_params(model, if isinstance(input_constructor, dict): input_constructor = TASK_UTILS.build(input_constructor) input = input_constructor(model, input_shape) - _ = flops_params_model(**input) + if isinstance(input, dict): + _ = flops_params_model(**input) + else: + 
flops_params_model(input) else: try: batch = torch.ones(()).new_empty( From 1f88ac463ad6a448607b7cca36d022dedbcc4975 Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 18 Nov 2022 09:35:14 +0800 Subject: [PATCH 4/9] fix bug in prune_evolution_searcher --- .../runner/prune_evolution_search_loop.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/mmrazor/engine/runner/prune_evolution_search_loop.py b/mmrazor/engine/runner/prune_evolution_search_loop.py index 1d6b60448..b9ac97d6d 100644 --- a/mmrazor/engine/runner/prune_evolution_search_loop.py +++ b/mmrazor/engine/runner/prune_evolution_search_loop.py @@ -116,6 +116,11 @@ def __init__(self, self.min_flops = self._min_flops() assert self.min_flops < self.flops_range[0], 'Cannot reach flop targe.' + if self.runner.distributed: + self.search_wrapper = runner.model.module + else: + self.search_wrapper = runner.model + def _min_flops(self): subnet = self.model.sample_subnet() for key in subnet: @@ -194,20 +199,22 @@ def _save_best_fix_subnet(self): @torch.no_grad() def _val_candidate(self) -> Dict: # bn rescale - len_img = 0 - self.runner.model.train() + max_iter = 100 + iter = 0 + self.search_wrapper.train() for _, data_batch in enumerate(self.bn_dataloader): - data = self.runner.model.data_preprocessor(data_batch, True) - self.runner.model._run_forward(data, mode='tensor') # type: ignore - len_img += len(data_batch['data_samples']) - if len_img > 1000: + data = self.search_wrapper.data_preprocessor(data_batch, True) + self.search_wrapper._run_forward( + data, mode='tensor') # type: ignore + iter += 1 + if iter > max_iter: break return super()._val_candidate() def _scale_and_check_subnet_constraints( self, random_subnet: SupportRandomSubnet, - auto_scale_times=5) -> Tuple[bool, SupportRandomSubnet]: + auto_scale_times=20) -> Tuple[bool, SupportRandomSubnet]: """Check whether is beyond constraints. 
Returns: @@ -225,7 +232,14 @@ def _scale_and_check_subnet_constraints( random_subnet, (self.flops_range[1] + self.flops_range[0]) / 2, flops) continue - + if is_pass: + from mmengine import MMLogger + MMLogger.get_current_instance().info( + f'sample a net,{flops},{self.flops_range}') + else: + from mmengine import MMLogger + MMLogger.get_current_instance().info( + f'sample a net failed,{flops},{self.flops_range}') return is_pass, random_subnet def _update_flop_range(self): From 9f28037dda88c2748d6afeed4d807d8758c72e8a Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 18 Nov 2022 10:29:36 +0800 Subject: [PATCH 5/9] add loss mode for prune search --- .../runner/prune_evolution_search_loop.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mmrazor/engine/runner/prune_evolution_search_loop.py b/mmrazor/engine/runner/prune_evolution_search_loop.py index b9ac97d6d..50e9a0aa3 100644 --- a/mmrazor/engine/runner/prune_evolution_search_loop.py +++ b/mmrazor/engine/runner/prune_evolution_search_loop.py @@ -209,7 +209,10 @@ def _val_candidate(self) -> Dict: iter += 1 if iter > max_iter: break - return super()._val_candidate() + if self.score_key == 'loss': + return self._val_by_loss() + else: + return super()._val_candidate() def _scale_and_check_subnet_constraints( self, @@ -251,3 +254,16 @@ def _update_flop_range(self): def check_subnet_flops(self, flops): return self.flops_range[0] <= flops <= self.flops_range[ 1] # type: ignore + + def _val_by_loss(self): + from mmengine.dist import all_reduce + self.runner.model.eval() + sum_loss = 0 + for data_batch in self.dataloader: + data = self.search_wrapper.data_preprocessor(data_batch) + losses = self.search_wrapper.forward(**data, mode='loss') + parsed_losses, _ = self.search_wrapper.parse_losses( + losses) # type: ignore + sum_loss = sum_loss + parsed_losses + all_reduce(sum_loss) + return {'loss': sum_loss.item()} From f71f5817e0ad989945d102f75538285426a67fc8 Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 18 Nov 2022 10:45:44 +0800 Subject: [PATCH 6/9] fix bug in search using loss --- mmrazor/engine/runner/prune_evolution_search_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmrazor/engine/runner/prune_evolution_search_loop.py b/mmrazor/engine/runner/prune_evolution_search_loop.py index 50e9a0aa3..6d3ff2754 100644 --- a/mmrazor/engine/runner/prune_evolution_search_loop.py +++ b/mmrazor/engine/runner/prune_evolution_search_loop.py @@ -266,4 +266,4 @@ def _val_by_loss(self): losses) # type: ignore sum_loss = sum_loss + parsed_losses all_reduce(sum_loss) - return {'loss': sum_loss.item()} + return {'loss': sum_loss.item() * -1} From 29bb9a194458c9cdc33913e28ac40e16b51f7789 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 18 Nov 2022 17:31:56 +0800 Subject: [PATCH 7/9] tmp --- .../task_modules/demo_inputs/demo_inputs.py | 2 +- .../models/task_modules/tracer/fx_tracer.py | 1 + .../task_modules/tracer/prune_tracer.py | 1 + tests/data/model_library.py | 16 +- tests/data/tracer_passed_models.py | 8 +- .../test_graph/test_prune_tracer_model.py | 161 ++++++++++++++++++ 6 files changed, 179 insertions(+), 10 deletions(-) create mode 100644 tests/test_core/test_graph/test_prune_tracer_model.py diff --git a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py index a3c498b2d..a9db12e40 100644 --- a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py @@ -67,7 +67,7 @@ def 
_get_mm_data(self, model, input_shape, training=False): from mmdet.testing._utils import demo_mm_inputs assert isinstance(model, BaseDetector) - data = demo_mm_inputs(1, [input_shape[1:]]) + data = demo_mm_inputs(1, [input_shape[1:]], with_mask=True) data = model.data_preprocessor(data, training) return data diff --git a/mmrazor/models/task_modules/tracer/fx_tracer.py b/mmrazor/models/task_modules/tracer/fx_tracer.py index 439f7ed4e..52bda0422 100644 --- a/mmrazor/models/task_modules/tracer/fx_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx_tracer.py @@ -36,6 +36,7 @@ def __init__(self, } self.warp_fn = { torch: torch.arange, + torch: torch.linspace, } def trace(self, diff --git a/mmrazor/models/task_modules/tracer/prune_tracer.py b/mmrazor/models/task_modules/tracer/prune_tracer.py index f55dedf25..58cd0ff80 100644 --- a/mmrazor/models/task_modules/tracer/prune_tracer.py +++ b/mmrazor/models/task_modules/tracer/prune_tracer.py @@ -120,6 +120,7 @@ def _fx_trace(self, model): args = self.demo_input.get_data(model) if isinstance(args, dict): args.pop('inputs') + args['mode'] = 'tensor' return self.tracer.trace(model, concrete_args=args) else: return self.tracer.trace(model) diff --git a/tests/data/model_library.py b/tests/data/model_library.py index 6119f0c0f..71979b2bb 100644 --- a/tests/data/model_library.py +++ b/tests/data/model_library.py @@ -61,6 +61,7 @@ def __call__(self, *args, **kwargs): def init_model(self): self._model = self.model_src() + return self._model def forward(self, x): assert self._model is not None @@ -92,6 +93,10 @@ def get_short_name(cls, name: str): def short_name(self): return self.__class__.get_short_name(self.name) + @property + def scope(self): + return self.name.split('.')[0] + class MMModelGenerator(ModelGenerator): @@ -378,14 +383,15 @@ class MMClsModelLibrary(MMModelLibrary): 'seresnet', 'repvgg', 'seresnext', - 'deit' + 'deit', ] base_config_path = '_base_/models/' repo = 'mmcls' - def __init__(self, - include=default_includes, - exclude=['cutmix', 'cifar', 'gem']) -> None: + def __init__( + self, + include=default_includes, + exclude=['cutmix', 'cifar', 'gem', 'efficientformer']) -> None: super().__init__(include=include, exclude=exclude) @@ -506,7 +512,7 @@ def _config_process(cls, config: Dict): @classmethod def generator_type(cls): - return MMDetModelGenerator + return MMModelGenerator class MMSegModelLibrary(MMModelLibrary): diff --git a/tests/data/tracer_passed_models.py b/tests/data/tracer_passed_models.py index c07584d19..f7ff4f893 100644 --- a/tests/data/tracer_passed_models.py +++ b/tests/data/tracer_passed_models.py @@ -154,11 +154,9 @@ def mmdet_library(cls): 'fcos', 'yolo', 'gfl', - 'simple', 'lvis', 'selfsup', 'solo', - 'soft', 'instaboost', 'point', 'pafpn', @@ -177,14 +175,16 @@ def mmdet_library(cls): 'foveabox', 'resnet', 'cityscapes', - 'timm', 'atss', 'dynamic', - 'panoptic', 'solov2', 'fsaf', 'double', 'cornernet', + # 'panoptic', + # 'simple', + # 'timm', + # 'soft', # 'vfnet', # error # 'carafe', # error # 'sparse', # error diff --git a/tests/test_core/test_graph/test_prune_tracer_model.py b/tests/test_core/test_graph/test_prune_tracer_model.py new file mode 100644 index 000000000..cd484538e --- /dev/null +++ b/tests/test_core/test_graph/test_prune_tracer_model.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
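+# This test module sweeps the model libraries under tests/data/ and checks
+# that every included model can be traced by PruneTracer with both the fx
+# and the backward tracer. FULL_TEST=true enables the full model list,
+# MP=true runs the trace in a multiprocessing pool, and DEBUG=true re-raises
+# tracing errors instead of recording them as failures.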
+import multiprocessing as mp +import os +import signal +import sys +import time +from contextlib import contextmanager +from functools import partial +from typing import List +from unittest import TestCase + +import torch +import torch.nn as nn + +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin +from mmrazor.models.mutables.mutable_channel.units import \ + SequentialMutableChannelUnit +from mmrazor.models.task_modules.tracer.backward_tracer import BackwardTracer +from mmrazor.models.task_modules.tracer.fx_tracer import CustomFxTracer +from mmrazor.models.task_modules.tracer.prune_tracer import PruneTracer +from mmrazor.models.task_modules.tracer.razor_tracer import (FxBaseNode, + RazorFxTracer) +from mmrazor.structures.graph import BaseGraph, ModuleGraph +from mmrazor.structures.graph.channel_graph import ( + ChannelGraph, default_channel_node_converter) +from mmrazor.structures.graph.module_graph import (FxTracerToGraphConverter, + PathToGraphConverter) +from ...data.model_library import ModelGenerator +from ...data.tracer_passed_models import (PassedModelManager, + backward_passed_library, + fx_passed_library) +from ...utils import SetTorchThread + +sys.setrecursionlimit(int(pow(2, 20))) +# test config +from mmrazor.models.task_modules.tracer import PruneTracer + +DEVICE = torch.device('cpu') +FULL_TEST = os.getenv('FULL_TEST') == 'true' +MP = os.getenv('MP') == 'true' + +DEBUG = os.getenv('DEBUG') == 'true' + +if MP: + POOL_SIZE = mp.cpu_count() + TORCH_THREAD_SIZE = 1 + # torch.set_num_interop_threads(1) +else: + POOL_SIZE = 1 + TORCH_THREAD_SIZE = -1 + +print(f'DEBUG: {DEBUG}') +print(f'FULL_TEST: {FULL_TEST}') +print(f'POOL_SIZE: {POOL_SIZE}') +print(f'TORCH_THREAD_SIZE: {TORCH_THREAD_SIZE}') + +# tools for tesing + +# test functions for mp + + +def _test_a_model(Model, tracer_type='fx'): + start = time.time() + + try: + model = Model.init_model() + model.eval() + if tracer_type == 'fx': + tracer_type = 'FxTracer' + elif tracer_type == 'backward': + tracer_type = 'BackwardTracer' + else: + raise NotImplementedError() + + unit_configs = PruneTracer( + tracer_type=tracer_type, + demo_input={ + 'type': 'DefaultDemoInput', + 'scope': Model.scope + }).trace(model) + out = len(unit_configs) + print(f'test {Model} successful.') + return Model.name, True, '', time.time() - start, out + except Exception as e: + if DEBUG: + raise e + else: + print(f'test {Model} failed.') + return Model.name, False, f'{e}', time.time() - start, -1 + + +# TestCase + + +class TestTraceModel(TestCase): + + def test_init_from_fx_tracer(self) -> None: + TestData = fx_passed_library.include_models(FULL_TEST) + with SetTorchThread(TORCH_THREAD_SIZE): + if POOL_SIZE != 1: + with mp.Pool(POOL_SIZE) as p: + result = p.map( + partial(_test_a_model, tracer_type='fx'), TestData) + else: + result = map( + partial(_test_a_model, tracer_type='fx'), TestData) + self.report(result, fx_passed_library, 'fx') + + def test_init_from_backward_tracer(self) -> None: + TestData = backward_passed_library.include_models(FULL_TEST) + with SetTorchThread(TORCH_THREAD_SIZE): + if POOL_SIZE != 1: + with mp.Pool(POOL_SIZE) as p: + result = p.map( + partial(_test_a_model, tracer_type='backward'), + TestData) + else: + result = map( + partial(_test_a_model, tracer_type='fx'), TestData) + self.report(result, backward_passed_library, 'backward') + + def report(self, result, model_manager: PassedModelManager, fx_type='fx'): + print() + print(f'Trace model summary using {fx_type} tracer.') + + passd_test = [res for 
res in result if res[1] is True] + unpassd_test = [res for res in result if res[1] is False] + + # long summary + + print(f'{len(passd_test)},{len(unpassd_test)},' + f'{len(model_manager.uninclude_models(full_test=FULL_TEST))}') + + print('Passed:') + print('\tmodel\ttime\tlen(mutable)') + for model, passed, msg, used_time, out in passd_test: + with self.subTest(model=model): + print(f'\t{model}\t{int(used_time)}s\t{out}') + self.assertTrue(passed, msg) + + print('UnPassed:') + for model, passed, msg, used_time, out in unpassd_test: + with self.subTest(model=model): + print(f'\t{model}\t{int(used_time)}s\t{out}') + print(f'\t\t{msg}') + self.assertTrue(passed, msg) + + print('UnTest:') + for model in model_manager.uninclude_models(full_test=FULL_TEST): + print(f'\t{model}') + + # short summary + print('Short Summary:') + short_passed = set( + [ModelGenerator.get_short_name(res[0]) for res in passd_test]) + + print('Passed\n', short_passed) + + short_unpassed = set( + [ModelGenerator.get_short_name(res[0]) for res in unpassd_test]) + print('Unpassed\n', short_unpassed) From 3fa1e252efa73e8099a1a1acb4482b9d99c5c6a2 Mon Sep 17 00:00:00 2001 From: jacky Date: Fri, 18 Nov 2022 17:43:22 +0800 Subject: [PATCH 8/9] add instruction for shikeying --- ...42\350\220\245\346\214\207\345\215\227.md" | 35 +++++++++++++++++++ .../task_modules/tracer/prune_tracer.py | 10 +++--- 2 files changed, 40 insertions(+), 5 deletions(-) create mode 100644 "docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" diff --git "a/docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" "b/docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" new file mode 100644 index 000000000..1b670d135 --- /dev/null +++ "b/docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" @@ -0,0 +1,35 @@ +# 如何适配trace新模型 + +首先我们提供一个单元测试:tests/test_core/test_graph/test_prune_tracer_model.py。 +在这个文件中,我们自动的对模型进行测试,判断是否能够顺利通过PruneTracer。 + +## 添加新模型 + +所有定义的新模型在tests/data/tracer_passed_models.py中定义,当我们需要添加新的repo时,我们需要实现相应的ModelLibrary。再将ModelLibrary放在tracer_passed_models中。 + +模型定义完成后,我们的单元测试便可对这些模型进行测试。 + +## 简单的模型适配 + +当我们发现一些模型无法通过tracer时,我们有一些简单的适配方法。 + +where to config prune tracer +""" + +- How to config PruneTracer using hard code + - fxtracer + - demo_inputs + ./mmrazor\\models\\task_modules\\demo_inputs\\default_demo_inputs.py + - leaf module + - PruneTracer.default_leaf_modules + - method + - .\\mmrazor\\models\\task_modules\\tracer\\fx_tracer.py + - ChannelNode + - .\\mmrazor\\structures\\graph\\channel_nodes.py + - DynamicOp + .\\mmrazor\\models\\architectures\\dynamic_ops\\bricks\\dynamic_conv.py + """ + +## 文档 + +./docs/en/user_guides/pruning_user_guide.md diff --git a/mmrazor/models/task_modules/tracer/prune_tracer.py b/mmrazor/models/task_modules/tracer/prune_tracer.py index 58cd0ff80..44dfcaece 100644 --- a/mmrazor/models/task_modules/tracer/prune_tracer.py +++ b/mmrazor/models/task_modules/tracer/prune_tracer.py @@ -27,16 +27,16 @@ """ - How to config PruneTracer using hard code - fxtracer - - concrete args - - demo_inputs + - demo_inputs + ./mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py - leaf module - PruneTracer.default_leaf_modules - method - - None + - ./mmrazor/models/task_modules/tracer/fx_tracer.py - ChannelNode - - channel_nodes.py + - ./mmrazor/structures/graph/channel_nodes.py - DynamicOp - ChannelUnits + ./mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py """ # concrete args From 
185116bb40f6d9efd2661ad2e636a93c17de8402 Mon Sep 17 00:00:00 2001 From: liukai Date: Sun, 20 Nov 2022 21:07:02 +0800 Subject: [PATCH 9/9] update md --- ...\256\242\350\220\245\346\214\207\345\215\227.md" | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git "a/docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" "b/docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" index 1b670d135..a475a3db1 100644 --- "a/docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" +++ "b/docs/zh_cn/\350\247\206\345\256\242\350\220\245\346\214\207\345\215\227.md" @@ -14,21 +14,16 @@ 当我们发现一些模型无法通过tracer时,我们有一些简单的适配方法。 where to config prune tracer -""" - How to config PruneTracer using hard code - fxtracer - - demo_inputs - ./mmrazor\\models\\task_modules\\demo_inputs\\default_demo_inputs.py + - [default_demo_inputs.py](/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py) - leaf module - - PruneTracer.default_leaf_modules - - method - - .\\mmrazor\\models\\task_modules\\tracer\\fx_tracer.py + - [prune_tracer.py](/mmrazor/models/task_modules/tracer/prune_tracer.py)::default_leaf_modules - ChannelNode - - .\\mmrazor\\structures\\graph\\channel_nodes.py + - [channel_nodes.py](./mmrazor/structures/graph/channel_nodes.py) - DynamicOp - .\\mmrazor\\models\\architectures\\dynamic_ops\\bricks\\dynamic_conv.py - """ + [dynamic_conv.py](/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py) ## 文档
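For reviewers who want to try the series end to end, here is a minimal usage sketch (not part of the patch) that ties the new pieces together. It assumes mmyolo is installed, that `model` is an already-built mmyolo detector, and that the import paths match the files touched above; names such as `tracer` and `unit_configs` are illustrative only.

```python
from mmrazor.models.task_modules.demo_inputs.default_demo_inputs import \
    DefaultDemoInput
from mmrazor.models.task_modules.tracer.prune_tracer import PruneTracer

# Scope-aware demo input: with input_shape=None, the shape falls back to the
# default of the scope's demo-input class, e.g. (1, 3, 125, 320) for mmyolo.
demo_input = DefaultDemoInput(input_shape=None, training=False, scope='mmyolo')
print(demo_input.input_shape)

# Trace the model the same way the new unit test does, then inspect the
# resulting prunable channel-unit configs.
tracer = PruneTracer(
    tracer_type='FxTracer',
    demo_input=dict(type='DefaultDemoInput', scope='mmyolo'))
unit_configs = tracer.trace(model)  # `model` is assumed to be built elsewhere
print(len(unit_configs))
```

The dict-style `demo_input` is the same config the new test passes to PruneTracer, so the sketch should stay in step with the tests as the series evolves.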