Skip to content

Commit

Permalink
add ut about custom tracer
Browse files Browse the repository at this point in the history
  • Loading branch information
humu789 committed Jan 16, 2023
1 parent a38af46 commit c5f54c1
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 24 deletions.
48 changes: 24 additions & 24 deletions mmrazor/models/task_modules/tracer/fx/custom_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,6 @@ def method(*args, **kwargs):
return wrapped_method


def custom_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
"""Modified `symbolic_trace` function in pytorch. Given an ``nn.Module`` or
function instance ``root``, this function will return a ``GraphModule``
constructed by recording operations seen while tracing through ``root``.
Args:
root (torch.nn.Module): Module or function to be
traced and converted into a Graph representation.
concrete_args (Optional[Dict[str, any]]): Inputs to be partially
specialized.
Returns:
GraphModule: a Module created from the recorded operations from
``root``.
"""
tracer = CustomTracer()
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(root,
nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)


def _prepare_module_dict(model: torch.nn.Module, fx_graph: torch.fx.Graph):
"""If there is a class method that can not be traced by the symbolic
tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in
Expand Down Expand Up @@ -444,3 +420,27 @@ def is_leaf_module(self, m: torch.nn.Module,
"""
leaf = super().is_leaf_module(m, module_qualified_name)
return leaf


def custom_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
"""Modified `symbolic_trace` function in pytorch. Given an ``nn.Module`` or
function instance ``root``, this function will return a ``GraphModule``
constructed by recording operations seen while tracing through ``root``.
Args:
root (torch.nn.Module): Module or function to be
traced and converted into a Graph representation.
concrete_args (Optional[Dict[str, any]]): Inputs to be partially
specialized.
Returns:
GraphModule: a Module created from the recorded operations from
``root``.
"""
tracer = CustomTracer()
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(
root, torch.nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)
2 changes: 2 additions & 0 deletions tests/data/test_models/test_task_modules/mmcls_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
158 changes: 158 additions & 0 deletions tests/test_models/test_task_modules/test_custom_tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch.nn as nn
from mmcls.models.backbones.resnet import ResLayer
from mmengine.config import Config
from mmengine.registry import MODELS
from torch.fx import GraphModule
from torch.fx._symbolic_trace import Graph

from mmrazor.models.task_modules import (CustomTracer, UntracedMethodRegistry,
build_graphmodule,
custom_symbolic_trace)
from mmrazor.models.task_modules.tracer.fx.custom_tracer import \
_prepare_module_dict


class ToyModel(nn.Module):

def __init__(self):
super.__init__()

def get_loss(self, x):
return x * 0.1

def extrac_feature(self, x):
return x * 2

def forward(self, x):
x = self.extrac_feature(x)
x = self.get_loss(x)
return x


class testUntracedMethodRgistry(TestCase):

def test_init(self):
method = ToyModel.get_loss
method_registry = UntracedMethodRegistry(method)
assert hasattr(method_registry, 'method')
assert hasattr(method_registry, 'method_dict')
assert len(method_registry.method_dict) == 0

def test_registry_method(self):
model = ToyModel
method = ToyModel.get_loss
method_registry = UntracedMethodRegistry(method)
method_registry.__set_name__(model, 'get_loss')
assert 'get_loss' in method_registry.method_dict.keys()
assert method_registry.method_dict['get_loss']['mod'] == model


class testCustomTracer(TestCase):

def setUp(self):
self.cfg = Config.fromfile(
'tests/data/test_models/test_task_modules/mmcls_cfg.py')
self.skipped_methods = [
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
]
self.skipped_module_names = ['backbone.layer4.0']
self.skipped_module_classes = [ResLayer]

def test_init(self):
# init without skipped_methods
tracer = CustomTracer()
assert hasattr(tracer, 'skipped_methods')
assert len(tracer.skipped_methods) == 0
# init with skipped_methods(list)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_methods=self.skipped_methods)
assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()
assert '_get_predictions' in UntracedMethodRegistry.method_dict.keys()
# init with skipped_methods(str)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_methods=self.skipped_methods[0])
assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()
# init with skipped_methods(int, error)
with self.assertRaises(TypeError):
CustomTracer(skipped_methods=123)
# init with skipped_methods(str, error)
with self.assertRaises(AssertionError):
CustomTracer(skipped_methods='_get_loss')

def test_trace(self):
# test trace with skipped_methods
model = MODELS.build(self.cfg.model)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_methods=self.skipped_methods)
graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'})
graph_loss = tracer.trace(model, concrete_args={'mode': 'loss'})
graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'})
assert isinstance(graph_tensor, Graph)
assert isinstance(graph_loss, Graph)
skip_flag_loss = False
for node in graph_loss.nodes:
if node.op == 'call_method' and node.target == '_get_loss':
skip_flag_loss = True
assert isinstance(graph_predict, Graph)
skip_flag_predict = False
for node in graph_predict.nodes:
if node.op == 'call_method' and node.target == '_get_predictions':
skip_flag_predict = True
assert skip_flag_loss and skip_flag_predict

# test trace with skipped_module_names
model = MODELS.build(self.cfg.model)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_module_names=self.skipped_module_names)
graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'})
skip_flag = False
for node in graph_tensor.nodes:
skipped_module_name = self.skipped_module_names[0]
if node.op == 'call_module' and node.target == skipped_module_name:
skip_flag = True
assert skip_flag

# test trace with skipped_module_classes
model = MODELS.build(self.cfg.model)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(
skipped_module_classes=self.skipped_module_classes)
graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'})
skip_flag = False
for node in graph_tensor.nodes:
if node.op == 'call_module' and node.target == 'backbone.layer1':
skip_flag = True
assert skip_flag


def test_custom_symbolic_trace():
cfg = Config.fromfile(
'tests/data/test_models/test_task_modules/mmcls_cfg.py')
model = MODELS.build(cfg.model)
UntracedMethodRegistry.method_dict = dict()
graph_module = custom_symbolic_trace(
model, concrete_args={'mode': 'tensor'})
assert isinstance(graph_module, GraphModule)


def test_build_graphmodule():
skipped_methods = ['mmcls.models.heads.ClsHead._get_predictions']
cfg = Config.fromfile(
'tests/data/test_models/test_task_modules/mmcls_cfg.py')
model = MODELS.build(cfg.model)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_methods=skipped_methods)
graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'})
graph_module = build_graphmodule(model, graph_predict)
assert isinstance(graph_module, GraphModule)

# test _prepare_module_dict
modules = dict(model.named_modules())
module_dict = _prepare_module_dict(model, graph_predict)
for k, v in module_dict.items():
assert isinstance(v, nn.Module)
assert not isinstance(v, modules[k].__class__)

0 comments on commit c5f54c1

Please sign in to comment.