Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update exp-pruning #352

Merged
merged 9 commits into from
Nov 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/zh_cn/视客营指南.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# 如何适配trace新模型

首先我们提供一个单元测试:tests/test_core/test_graph/test_prune_tracer_model.py。
在这个文件中,我们自动地对模型进行测试,判断模型是否能够顺利通过PruneTracer。

## 添加新模型

所有定义的新模型在tests/data/tracer_passed_models.py中定义,当我们需要添加新的repo时,我们需要实现相应的ModelLibrary。再将ModelLibrary放在tracer_passed_models中。

模型定义完成后,我们的单元测试便可对这些模型进行测试。

## 简单的模型适配

当我们发现一些模型无法通过tracer时,我们有一些简单的适配方法。

Where to configure PruneTracer:

- How to configure PruneTracer with hard-coded settings
- fxtracer
- [default_demo_inputs.py](/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py)
- leaf module
- [prune_tracer.py](/mmrazor/models/task_modules/tracer/prune_tracer.py)::default_leaf_modules
- ChannelNode
- [channel_nodes.py](./mmrazor/structures/graph/channel_nodes.py)
- DynamicOp
[dynamic_conv.py](/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py)

## 文档

./docs/en/user_guides/pruning_user_guide.md
48 changes: 39 additions & 9 deletions mmrazor/engine/runner/prune_evolution_search_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def __init__(self,
self.min_flops = self._min_flops()
assert self.min_flops < self.flops_range[0], 'Cannot reach flop targe.'

if self.runner.distributed:
self.search_wrapper = runner.model.module
else:
self.search_wrapper = runner.model

def _min_flops(self):
subnet = self.model.sample_subnet()
for key in subnet:
Expand Down Expand Up @@ -194,20 +199,25 @@ def _save_best_fix_subnet(self):
@torch.no_grad()
def _val_candidate(self) -> Dict:
# bn rescale
len_img = 0
self.runner.model.train()
max_iter = 100
iter = 0
self.search_wrapper.train()
for _, data_batch in enumerate(self.bn_dataloader):
data = self.runner.model.data_preprocessor(data_batch, True)
self.runner.model._run_forward(data, mode='tensor') # type: ignore
len_img += len(data_batch['data_samples'])
if len_img > 1000:
data = self.search_wrapper.data_preprocessor(data_batch, True)
self.search_wrapper._run_forward(
data, mode='tensor') # type: ignore
iter += 1
if iter > max_iter:
break
return super()._val_candidate()
if self.score_key == 'loss':
return self._val_by_loss()
else:
return super()._val_candidate()

def _scale_and_check_subnet_constraints(
self,
random_subnet: SupportRandomSubnet,
auto_scale_times=5) -> Tuple[bool, SupportRandomSubnet]:
auto_scale_times=20) -> Tuple[bool, SupportRandomSubnet]:
"""Check whether is beyond constraints.

Returns:
Expand All @@ -225,7 +235,14 @@ def _scale_and_check_subnet_constraints(
random_subnet,
(self.flops_range[1] + self.flops_range[0]) / 2, flops)
continue

if is_pass:
from mmengine import MMLogger
MMLogger.get_current_instance().info(
f'sample a net,{flops},{self.flops_range}')
else:
from mmengine import MMLogger
MMLogger.get_current_instance().info(
f'sample a net failed,{flops},{self.flops_range}')
return is_pass, random_subnet

def _update_flop_range(self):
Expand All @@ -237,3 +254,16 @@ def _update_flop_range(self):
def check_subnet_flops(self, flops):
    """Return True if ``flops`` falls inside ``self.flops_range`` (inclusive on both ends)."""
    low = self.flops_range[0]  # type: ignore
    high = self.flops_range[1]
    return low <= flops <= high

def _val_by_loss(self):
    """Score a candidate subnet by its accumulated loss over ``self.dataloader``.

    Returns:
        dict: ``{'loss': v}`` where ``v`` is the negated total loss, so
        that a smaller loss yields a larger score. The total is
        all-reduced across ranks before being reported.
    """
    from mmengine.dist import all_reduce
    self.runner.model.eval()
    total_loss = 0
    for batch in self.dataloader:
        # NOTE(review): data_preprocessor is called without an explicit
        # training flag here — presumably eval-mode preprocessing is
        # intended; confirm against the preprocessor's default.
        inputs = self.search_wrapper.data_preprocessor(batch)
        loss_dict = self.search_wrapper.forward(**inputs, mode='loss')
        batch_loss, _ = self.search_wrapper.parse_losses(
            loss_dict)  # type: ignore
        total_loss = total_loss + batch_loss
    # Aggregate across distributed workers so each rank sees the same score.
    all_reduce(total_loss)
    return {'loss': total_loss.item() * -1}
26 changes: 17 additions & 9 deletions mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from mmrazor.utils import get_placeholder
from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput,
DefaultMMDemoInput, DefaultMMDetDemoInput,
DefaultMMRotateDemoInput, DefaultMMSegDemoInput)
DefaultMMRotateDemoInput, DefaultMMSegDemoInput,
DefaultMMYoloDemoInput)

try:
from mmdet.models import BaseDetector
Expand Down Expand Up @@ -35,32 +36,39 @@
'mmdet': DefaultMMDetDemoInput,
'mmseg': DefaultMMSegDemoInput,
'mmrotate': DefaultMMRotateDemoInput,
'mmyolo': DefaultMMYoloDemoInput,
'torchvision': BaseDemoInput,
}


def defaul_demo_inputs(model, input_shape, training=False, scope=None):
def get_default_demo_input_class(model, scope):

if scope is not None:
for scope_name, demo_input in default_demo_input_class_for_scope.items(
):
if scope == scope_name:
return demo_input(input_shape, training).get_data(model)
return demo_input

for module_type, demo_input in default_demo_input_class.items( # noqa
): # noqa
if isinstance(model, module_type):
return demo_input(input_shape, training).get_data(model)
return demo_input
# default
return BaseDemoInput(input_shape, training).get_data(model)
return BaseDemoInput


def defaul_demo_inputs(model, input_shape, training=False, scope=None):
    """Build demo input data for ``model``.

    Picks the demo-input class registered for ``scope`` (or inferred from
    the model's type) and delegates data construction to it.
    NOTE(review): the historical spelling ``defaul`` is kept because
    external callers reference this name.
    """
    input_cls = get_default_demo_input_class(model, scope)
    return input_cls().get_data(model, input_shape, training)


@TASK_UTILS.register_module()
class DefaultDemoInput(BaseDemoInput):

def __init__(self,
input_shape=[1, 3, 224, 224],
training=False,
scope=None) -> None:
def __init__(self, input_shape=None, training=False, scope=None) -> None:
default_demo_input_class = get_default_demo_input_class(None, scope)
if input_shape is None:
input_shape = default_demo_input_class.default_shape
super().__init__(input_shape, training)
self.scope = scope

Expand Down
15 changes: 9 additions & 6 deletions mmrazor/models/task_modules/demo_inputs/demo_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

@TASK_UTILS.register_module()
class BaseDemoInput():
default_shape = (1, 3, 224, 224)

def __init__(self, input_shape=[1, 3, 224, 224], training=False) -> None:
def __init__(self, input_shape=default_shape, training=None) -> None:
    """Store the demo-input configuration.

    Args:
        input_shape: shape of the demo tensor to generate.
            NOTE(review): this default binds ``BaseDemoInput.default_shape``
            at class-definition time, so subclasses that override
            ``default_shape`` do NOT change this default — confirm that
            is intended.
        training: training-mode flag; ``None`` means unspecified.
    """
    self.input_shape = input_shape
    self.training = training

Expand All @@ -33,10 +34,6 @@ def __call__(self,
class DefaultMMDemoInput(BaseDemoInput):

def _get_data(self, model, input_shape=None, training=None):
if input_shape is None:
input_shape = self.input_shape
if training is None:
training = self.training

data = self._get_mm_data(model, input_shape, training)
data['mode'] = 'tensor'
Expand All @@ -58,6 +55,7 @@ def _get_mm_data(self, model, input_shape, training=False):
'data_samples':
[ClsDataSample().set_gt_label(1) for _ in range(input_shape[0])],
}
mm_inputs = model.data_preprocessor(mm_inputs, training)
return mm_inputs


Expand All @@ -69,7 +67,7 @@ def _get_mm_data(self, model, input_shape, training=False):
from mmdet.testing._utils import demo_mm_inputs
assert isinstance(model, BaseDetector)

data = demo_mm_inputs(1, [input_shape[1:]])
data = demo_mm_inputs(1, [input_shape[1:]], with_mask=True)
data = model.data_preprocessor(data, training)
return data

Expand All @@ -94,3 +92,8 @@ def _get_mm_data(self, model, input_shape, training=False):
data = demo_mm_inputs(1, [input_shape[1:]], use_box_type=True)
data = model.data_preprocessor(data, training)
return data


@TASK_UTILS.register_module()
class DefaultMMYoloDemoInput(DefaultMMDetDemoInput):
    """Demo-input generator for mmyolo detectors.

    Reuses the mmdet demo-input construction, only overriding the default
    input shape.
    """

    # NOTE(review): a height of 125 is unusual (not a multiple of 32,
    # which YOLO-style detectors commonly require) — confirm this is not
    # a typo for 128.
    default_shape = (1, 3, 125, 320)
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ def get_model_flops_params(model,
if isinstance(input_constructor, dict):
input_constructor = TASK_UTILS.build(input_constructor)
input = input_constructor(model, input_shape)
_ = flops_params_model(**input)
if isinstance(input, dict):
_ = flops_params_model(**input)
else:
flops_params_model(input)
else:
try:
batch = torch.ones(()).new_empty(
Expand Down
1 change: 1 addition & 0 deletions mmrazor/models/task_modules/tracer/fx_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self,
}
self.warp_fn = {
torch: torch.arange,
torch: torch.linspace,
}

def trace(self,
Expand Down
14 changes: 8 additions & 6 deletions mmrazor/models/task_modules/tracer/prune_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
"""
- How to config PruneTracer using hard code
- fxtracer
- concrete args
- demo_inputs
- demo_inputs
./mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py
- leaf module
- PruneTracer.default_leaf_modules
- method
- None
- ./mmrazor/models/task_modules/tracer/fx_tracer.py
- ChannelNode
- channel_nodes.py
- ./mmrazor/structures/graph/channel_nodes.py
- DynamicOp
ChannelUnits
./mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py
"""

# concrete args
Expand Down Expand Up @@ -76,7 +76,8 @@ def __init__(self,
self.tracer_type = tracer_type
if tracer_type == 'BackwardTracer':
self.tracer = BackwardTracer(
loss_calculator=SumPseudoLoss(input_shape=demo_input))
loss_calculator=SumPseudoLoss(
input_shape=self.demo_input.input_shape))
elif tracer_type == 'FxTracer':
self.tracer = CustomFxTracer(leaf_module=self.default_leaf_modules)
else:
Expand Down Expand Up @@ -119,6 +120,7 @@ def _fx_trace(self, model):
args = self.demo_input.get_data(model)
if isinstance(args, dict):
args.pop('inputs')
args['mode'] = 'tensor'
return self.tracer.trace(model, concrete_args=args)
else:
return self.tracer.trace(model)
Expand Down
16 changes: 11 additions & 5 deletions tests/data/model_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __call__(self, *args, **kwargs):

def init_model(self):
    """Instantiate the wrapped model, cache it on ``self`` and return it."""
    model = self.model_src()
    self._model = model
    return model

def forward(self, x):
assert self._model is not None
Expand Down Expand Up @@ -92,6 +93,10 @@ def get_short_name(cls, name: str):
def short_name(self):
return self.__class__.get_short_name(self.name)

@property
def scope(self):
    """Repo scope of the model: the part of ``self.name`` before the first dot."""
    return self.name.partition('.')[0]


class MMModelGenerator(ModelGenerator):

Expand Down Expand Up @@ -378,14 +383,15 @@ class MMClsModelLibrary(MMModelLibrary):
'seresnet',
'repvgg',
'seresnext',
'deit'
'deit',
]
base_config_path = '_base_/models/'
repo = 'mmcls'

def __init__(self,
include=default_includes,
exclude=['cutmix', 'cifar', 'gem']) -> None:
def __init__(self, include=default_includes, exclude=None) -> None:
    """Model library over mmcls configs.

    Args:
        include: config-name fragments to include. Defaults to the
            class-level ``default_includes``.
        exclude: config-name fragments to exclude. Defaults to
            ``['cutmix', 'cifar', 'gem', 'efficientformer']``.
    """
    # Use a None sentinel instead of a mutable list default: a shared
    # default list is a classic Python pitfall (it is created once at
    # definition time and shared by every call).
    if exclude is None:
        exclude = ['cutmix', 'cifar', 'gem', 'efficientformer']
    super().__init__(include=include, exclude=exclude)


Expand Down Expand Up @@ -506,7 +512,7 @@ def _config_process(cls, config: Dict):

@classmethod
def generator_type(cls):
return MMDetModelGenerator
return MMModelGenerator


class MMSegModelLibrary(MMModelLibrary):
Expand Down
8 changes: 4 additions & 4 deletions tests/data/tracer_passed_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,9 @@ def mmdet_library(cls):
'fcos',
'yolo',
'gfl',
'simple',
'lvis',
'selfsup',
'solo',
'soft',
'instaboost',
'point',
'pafpn',
Expand All @@ -177,14 +175,16 @@ def mmdet_library(cls):
'foveabox',
'resnet',
'cityscapes',
'timm',
'atss',
'dynamic',
'panoptic',
'solov2',
'fsaf',
'double',
'cornernet',
# 'panoptic',
# 'simple',
# 'timm',
# 'soft',
# 'vfnet', # error
# 'carafe', # error
# 'sparse', # error
Expand Down
Loading