Skip to content

Commit

Permalink
[Fix]: Fix rewriter conflict when processing derived class (open-mmla…
Browse files Browse the repository at this point in the history
…b#289)

* Fix rewriter

* lint

* rename function and update docstring

* use is class

* Update docstring
  • Loading branch information
SingleZombie committed Dec 14, 2021
1 parent 0dea300 commit 78b37bb
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 19 deletions.
70 changes: 55 additions & 15 deletions mmdeploy/core/rewriters/function_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,45 @@
from typing import Callable, Dict

from mmdeploy.utils.constants import Backend
from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import
from .rewriter_utils import ContextCaller, RewriterRegistry, import_function


def _set_func(origin_func_name: str, rewrite_func: Callable):
"""Rewrite a function by executing a python statement."""
def _set_func(origin_func_path: str, rewrite_func: Callable):
"""Rewrite a function by executing a python statement.
Args:
origin_func_path (str): The path to origin function.
rewrite_func (Callable): The new function instance.
"""

# Import necessary module
split_path = origin_func_name.split('.')
split_path = origin_func_path.split('.')
for i in range(len(split_path), 0, -1):
try:
exec('import {}'.format('.'.join(split_path[:i])))
break
except Exception:
continue
# Assign function
exec(f'{origin_func_name} = rewrite_func')
exec(f'{origin_func_path} = rewrite_func')


def _del_func(path: str):
"""Delete a function that is denoted by a path.
Args:
path (str): The path to evaluate.
"""

split_path = path.split('.')
for i in range(len(split_path), 0, -1):
try:
exec('import {}'.format('.'.join(split_path[:i])))
break
except Exception:
continue

exec(f'del {path}')


class FunctionRewriter:
Expand Down Expand Up @@ -72,23 +95,38 @@ def enter(self,
functions_records = self._registry.get_records(backend)

self._origin_functions = list()
self._additional_functions = list()
new_functions = list()
for function_name, record_dict in functions_records:
for function_path, record_dict in functions_records:

# Check if the origin function exists
try:
origin_func = eval_with_import(function_name)
origin_func, origin_class = import_function(function_path)
except Exception:
origin_func = None
logging.warning(
f'Can not find {function_name}, function rewrite will '
f'Can not find {function_path}, function rewrite will '
'not be applied')

# Only rewrite functions that exist
if origin_func is not None:

# Save origin function
self._origin_functions.append((function_name, origin_func))
is_addition_function = False
if origin_class is not None:
function_name = function_path.split('.')[-1]
try:
origin_class.__getattribute__(origin_class,
function_name)
except Exception:
# The function is a method and it is derived from base
# class.
is_addition_function = True

if is_addition_function:
self._additional_functions.append(function_path)
else:
# Save origin function
self._origin_functions.append((function_path, origin_func))

# Create context_caller
rewrite_function = record_dict['_object']
Expand All @@ -99,13 +137,15 @@ def enter(self,
**extra_kwargs).get_wrapped_caller()

# Cache new the function to avoid homonymic bug
new_functions.append((function_name, context_caller))
new_functions.append((function_path, context_caller))

for function_name, new_function in new_functions:
for function_path, new_function in new_functions:
# Rewrite functions
_set_func(function_name, new_function)
_set_func(function_path, new_function)

def exit(self):
"""Recover the function rewrite."""
for func_name, func in self._origin_functions:
_set_func(func_name, func)
for func_path, func in self._origin_functions:
_set_func(func_path, func)
for func_path in self._additional_functions:
_del_func(func_path)
37 changes: 35 additions & 2 deletions mmdeploy/core/rewriters/rewriter_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Callable, Dict, List
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple

from mmdeploy.utils.constants import Backend

Expand All @@ -11,7 +12,7 @@ def eval_with_import(path: str) -> Any:
path (str): The path to evaluate.
Returns:
Any: The result of evaluate.
Any: The result of evaluation.
"""
split_path = path.split('.')
for i in range(len(split_path), 0, -1):
Expand All @@ -23,6 +24,38 @@ def eval_with_import(path: str) -> Any:
return eval(path)


def import_function(path: str) -> Tuple[Callable, Optional[type]]:
"""Import and evaluate a function. If the function is defined in a class,
evaluate the class additionally.
Args:
path (str): The path to evaluate.
Returns:
Callable: The function of evaluation.
type: The class of evaluation if the function is defined in a class, or
None.
"""
split_path = path.split('.')
for i in range(len(split_path), 0, -1):
try:
exec('import {}'.format('.'.join(split_path[:i])))
break
except Exception:
continue

obj = eval(path)

# The path that might be a class
previous_obj = eval('.'.join(split_path[:-1]))

# Check if the path leads to a class
if inspect.isclass(previous_obj):
return obj, previous_obj
else:
return obj, None


class RewriterRegistry:
"""A registry that recoreds rewrite objects.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core/package/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .module import C, func
from .module import C2, C, func

__all__ = ['func', 'C']
__all__ = ['func', 'C', 'C2']
4 changes: 4 additions & 0 deletions tests/test_core/package/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ class C:

def method(self):
return 1


class C2(C):
pass
46 changes: 46 additions & 0 deletions tests/test_core/test_function_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,49 @@ def func_5(ctx, self):
function_rewriter2.exit()

assert c.method() == 1


def test_rewrite_derived_methods():
import package
path1 = 'package.C.method'
path2 = 'package.C2.method'

base_obj = package.C()
derived_obj = package.C2()

assert base_obj.method() == 1
assert derived_obj.method() == 1

function_rewriter = FunctionRewriter()
function_rewriter.add_backend(Backend.NCNN.value)

@function_rewriter.register_rewriter(func_name=path1)
def func_2(ctx, self):
return 2

@function_rewriter.register_rewriter(
func_name=path2, backend=Backend.NCNN.value)
def func_3(ctx, self):
return 3

function_rewriter.enter()
assert base_obj.method() == 2
assert derived_obj.method() == 2
function_rewriter.exit()

function_rewriter.enter(backend=Backend.NCNN.value)
assert base_obj.method() == 2
assert derived_obj.method() == 3
function_rewriter.exit()

assert base_obj.method() == 1
assert derived_obj.method() == 1

# Check if the recovery is correct
function_rewriter.enter()
assert base_obj.method() == 2
assert derived_obj.method() == 2
function_rewriter.exit()

assert base_obj.method() == 1
assert derived_obj.method() == 1

0 comments on commit 78b37bb

Please sign in to comment.