Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features]Support MethodInputsRecorder and FunctionInputsRecorder #320

Merged
merged 6 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmrazor/models/task_modules/delivery/distill_delivery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from queue import Queue
from collections import deque
from typing import Callable


Expand Down Expand Up @@ -33,7 +33,7 @@ class DistillDelivery(metaclass=ABCMeta):
def __init__(self, max_keep_data: int = 1) -> None:

self._override_data = False
self.data_queue: Queue = Queue(maxsize=max_keep_data)
self.data_queue: deque = deque([], maxlen=max_keep_data)
self.max_keep_data = max_keep_data

@property
Expand Down
50 changes: 29 additions & 21 deletions mmrazor/models/task_modules/delivery/function_outputs_delivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,7 @@ def __init__(self, func_path: str, max_keep_data: int):
super().__init__(max_keep_data)

self._check_valid_path(func_path)
module_path = self._get_module_path(func_path)
try:
module = import_modules_from_strings(module_path)
except ImportError:
raise ImportError(f'{module_path} is not imported correctly.')
self.module = module

func_name = self._get_func_name(func_path)
assert hasattr(module, func_name), \
f'{func_name} is not in {module_path}.'
self.func_name = func_name

origin_func = getattr(module, func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')
self.origin_func = origin_func
self.func_path = func_path

@staticmethod
def _check_valid_path(func_path: str) -> None:
Expand All @@ -121,6 +105,24 @@ def __enter__(self) -> None:

Wrap the origin function.
"""
module_path = self._get_module_path(self.func_path)
try:
module = import_modules_from_strings(module_path)
except ImportError:
raise ImportError(f'{module_path} is not imported correctly.')
self.module = module

func_name = self._get_func_name(self.func_path)
assert hasattr(module, func_name), \
f'{func_name} is not in {module_path}.'
self.func_name = func_name

origin_func = getattr(module, func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')
self.origin_func = origin_func

wrapped_func = self.deliver_wrapper(self.origin_func)
setattr(self.module, self.func_name, wrapped_func)

Expand All @@ -131,6 +133,11 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
setattr(self.module, self.func_name, self.origin_func)

# self.module and self.origin_func can not be pickled.
# Delete these two attributes to avoid errors when ema model is used.
del self.module
del self.origin_func

def deliver_wrapper(self, origin_func: Callable) -> Callable:
"""Wrap the specific function to make the intermediate results of the
model can be delivered."""
Expand All @@ -139,12 +146,13 @@ def deliver_wrapper(self, origin_func: Callable) -> Callable:
def wrap_func(*args, **kwargs):

if self.override_data:
assert not self.data_queue.empty(), 'pop from an empty queue'
outputs = self.data_queue.get()
assert len(self.data_queue) > 0, 'pop from an empty queue'
outputs = self.data_queue.popleft()
else:
assert not self.data_queue.full(), 'push into an full queue'
assert len(self.data_queue) < self.data_queue.maxlen,\
'push into an full queue'
outputs = origin_func(*args, **kwargs)
self.data_queue.put(outputs)
self.data_queue.append(outputs)
return outputs

return wrap_func
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,13 @@ def deliver_wrapper(self, origin_method: Callable) -> Callable:
def wrap_method(*args, **kwargs):

if self.override_data:
assert not self.data_queue.empty(), 'pop from an empty queue'
outputs = self.data_queue.get()
assert len(self.data_queue) > 0, 'pop from an empty queue'
outputs = self.data_queue.popleft()
else:
assert not self.data_queue.full(), 'push into an full queue'
assert len(self.data_queue) < self.data_queue.maxlen,\
'push into an full queue'
outputs = origin_method(*args, **kwargs)
self.data_queue.put(outputs)
self.data_queue.append(outputs)
return outputs

return wrap_method
4 changes: 3 additions & 1 deletion mmrazor/models/task_modules/recorder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .function_inputs_recorder import FunctionInputsRecorder
from .function_outputs_recorder import FunctionOutputsRecorder
from .method_inputs_recorder import MethodInputsRecorder
from .method_outputs_recorder import MethodOutputsRecorder
from .module_inputs_recorder import ModuleInputsRecorder
from .module_outputs_recorder import ModuleOutputsRecorder
Expand All @@ -9,5 +11,5 @@
__all__ = [
'FunctionOutputsRecorder', 'MethodOutputsRecorder',
'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
'ModuleInputsRecorder'
'ModuleInputsRecorder', 'MethodInputsRecorder', 'FunctionInputsRecorder'
]
71 changes: 71 additions & 0 deletions mmrazor/models/task_modules/recorder/function_inputs_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from inspect import signature
from typing import Callable, List

from mmrazor.registry import TASK_UTILS
from .function_outputs_recorder import FunctionOutputsRecorder


@TASK_UTILS.register_module()
class FunctionInputsRecorder(FunctionOutputsRecorder):
    """Recorder that captures the inputs of a ``FunctionType`` object.

    Every call of the wrapped function appends one list to the data buffer.
    That list holds the positional arguments in call order, followed by the
    keyword arguments whose names appear in the function's signature,
    ordered as they are declared in that signature.

    Notes:
        Pay attention to the form of `source`. For example,
        `anchor_inside_flags` is a function in mmdetection that checks
        whether the anchors are inside the border. It is defined in
        `mmdet/core/anchor/utils.py` and used in
        `mmdet/models/dense_heads/anchor_head.py`. The source should then be
        `mmdet.models.dense_heads.anchor_head.anchor_inside_flags` and not
        `mmdet.core.anchor.utils.anchor_inside_flags`.

    Examples:
        >>> # Below code in toy_module.py
        >>> import random
        >>> def toy_func(a, b):
        ...     return a, b
        >>> def execute_toy_func(a, b):
        ...     toy_func(a, b)

        >>> # Below code in main.py
        >>> # Now, we want to get teacher's inputs by recorder.

        >>> from toy_module import execute_toy_func
        >>> r1 = FunctionInputsRecorder('toy_module.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     execute_toy_func(1, 2)
        ...     execute_toy_func(1, b=2)
        ...     execute_toy_func(b=2, a=1)

        >>> r1.data_buffer
        [[1, 2], [1, 2], [1, 2]]
    """

    def func_record_wrapper(self, origin_func: Callable,
                            data_buffer: List) -> Callable:
        """Build a wrapper that records the inputs of ``origin_func``.

        Args:
            origin_func (FunctionType): The function whose inputs need to be
                recorded.
            data_buffer (list): The list that receives one entry per call.

        Returns:
            Callable: A function that calls ``origin_func``, stores the
            inputs of that call in ``data_buffer`` and returns the original
            outputs unchanged.
        """

        # Parameter names in declaration order; resolved once per wrapper.
        param_names = tuple(signature(origin_func).parameters)

        @functools.wraps(origin_func)
        def wrap_func(*args, **kwargs):
            # Call first: nothing is recorded when the function raises.
            result = origin_func(*args, **kwargs)

            # Positional arguments, then matching keyword arguments in
            # signature order.
            recorded = [*args]
            recorded.extend(
                kwargs[name] for name in param_names if name in kwargs)

            # One entry per call, so N calls produce N recorded inputs.
            data_buffer.append(recorded)
            return result

        return wrap_func
49 changes: 25 additions & 24 deletions mmrazor/models/task_modules/recorder/function_outputs_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,8 @@ class FunctionOutputsRecorder(BaseRecorder):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self._check_valid_source(self.source)

# import the module that contains the function
try:
mod = import_modules_from_strings(self.module_string)
except ImportError:
raise ImportError(
f'{self.module_string} is not imported correctly.')

self.imported_module: ModuleType = mod

assert hasattr(mod, self.func_name), \
f'{self.func_name} is not in {self.module_string}.'

origin_func = getattr(mod, self.func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{self.func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')

self.origin_func: Callable = origin_func

@staticmethod
def _check_valid_source(source):
"""Check if the source's format is valid."""
Expand Down Expand Up @@ -118,8 +98,7 @@ def func_record_wrapper(self, origin_func: Callable,
Args:
origin_func (FunctionType): The function whose outputs need to be
recorded.
buffer_key (str): The key of the function's outputs saved in
``data_buffer``.
data_buffer (list): A list of data.
"""

@functools.wraps(origin_func)
Expand All @@ -136,8 +115,25 @@ def __enter__(self):
"""Enter the context manager."""
super().__enter__()

mod = self.imported_module
origin_func = self.origin_func
# import the module that contains the function
try:
mod = import_modules_from_strings(self.module_string)
except ImportError:
raise ImportError(
f'{self.module_string} is not imported correctly.')

self.imported_module: ModuleType = mod

assert hasattr(mod, self.func_name), \
f'{self.func_name} is not in {self.module_string}.'

origin_func = getattr(mod, self.func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{self.func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')

self.origin_func: Callable = origin_func

# add record wrapper to origin function.
record_func = self.func_record_wrapper(origin_func, self.data_buffer)

Expand All @@ -159,3 +155,8 @@ def __exit__(self, exc_type, exc_value, traceback):

# restore the origin function
setattr(mod, self.func_name, origin_func)

# self.imported_module and self.origin_func can not be pickled.
# Delete these two attributes to avoid errors when ema model is used.
del self.imported_module
del self.origin_func
83 changes: 83 additions & 0 deletions mmrazor/models/task_modules/recorder/method_inputs_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from inspect import signature
from typing import Callable, List

from mmrazor.registry import TASK_UTILS
from .method_outputs_recorder import MethodOutputsRecorder


@TASK_UTILS.register_module()
class MethodInputsRecorder(MethodOutputsRecorder):
    """Recorder that captures the inputs of a ``MethodType`` object.

    Every call of the wrapped method appends one list to the data buffer.
    That list holds the positional arguments after the first one, followed
    by the keyword arguments whose names appear in the method's signature,
    ordered as they are declared in that signature.

    Note:
        Different from ``FunctionType``, ``MethodType`` is the type of
        methods of class instances.

    Examples:
        >>> # Below code in toy_module.py
        >>> import random
        >>> class Toy():
        ...     def toy_func(self, x, y=0):
        ...         return x + y

        >>> # Below code in main.py
        >>> # Now, we want to get teacher's inputs by recorder.

        >>> from toy_module import Toy
        >>> toy = Toy()
        >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     _ = toy.toy_func(1, 2)

        >>> r1.data_buffer
        [[1, 2]]
        >>> r1.get_record_data(record_idx=0, data_idx=0)
        1
        >>> r1.get_record_data(record_idx=0, data_idx=1)
        2

        >>> from toy_module import Toy
        >>> toy = Toy()
        >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     _ = toy.toy_func(1, 2)
        ...     _ = toy.toy_func(y=2, x=1)

        >>> r1.data_buffer
        [[1, 2], [1, 2]]
        >>> r1.get_record_data(record_idx=1, data_idx=0)
        1
        >>> r1.get_record_data(record_idx=1, data_idx=1)
        2
    """

    def method_record_wrapper(self, orgin_method: Callable,
                              data_buffer: List) -> Callable:
        """Build a wrapper that records the inputs of ``orgin_method``.

        Args:
            orgin_method (MethodType): The method whose inputs need to be
                recorded. The parameter name keeps the spelling used by the
                parent class's signature.
            data_buffer (list): The list that receives one entry per call.

        Returns:
            Callable: A function that calls ``orgin_method``, stores the
            inputs of that call in ``data_buffer`` and returns the original
            outputs unchanged.
        """

        # Parameter names in declaration order; resolved once per wrapper.
        param_names = tuple(signature(orgin_method).parameters)

        @functools.wraps(orgin_method)
        def wrap_method(*args, **kwargs):
            # Call first: nothing is recorded when the method raises.
            result = orgin_method(*args, **kwargs)

            # The first positional argument is the instance the method is
            # called on (``self``), so it is left out of the record.
            recorded = [*args[1:]]
            recorded.extend(
                kwargs[name] for name in param_names if name in kwargs)

            # One entry per call, so N calls produce N recorded inputs.
            data_buffer.append(recorded)
            return result

        return wrap_method
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def method_record_wrapper(self, orgin_method: Callable,
Args:
origin_method (MethodType): The method whose outputs need to be
recorded.
buffer_key (str): The key of the method's outputs saved in
``data_buffer``.
data_buffer (list): A list of data.
"""

@functools.wraps(orgin_method)
Expand Down
Loading