diff --git a/mmrazor/models/task_modules/delivery/distill_delivery.py b/mmrazor/models/task_modules/delivery/distill_delivery.py
index dcd56f388..d8c335f00 100644
--- a/mmrazor/models/task_modules/delivery/distill_delivery.py
+++ b/mmrazor/models/task_modules/delivery/distill_delivery.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
-from queue import Queue
+from collections import deque
 from typing import Callable
 
 
@@ -33,7 +33,7 @@ class DistillDelivery(metaclass=ABCMeta):
     def __init__(self, max_keep_data: int = 1) -> None:
 
         self._override_data = False
-        self.data_queue: Queue = Queue(maxsize=max_keep_data)
+        self.data_queue: deque = deque([], maxlen=max_keep_data)
         self.max_keep_data = max_keep_data
 
     @property
diff --git a/mmrazor/models/task_modules/delivery/function_outputs_delivery.py b/mmrazor/models/task_modules/delivery/function_outputs_delivery.py
index 19aadc2e9..15c361e38 100644
--- a/mmrazor/models/task_modules/delivery/function_outputs_delivery.py
+++ b/mmrazor/models/task_modules/delivery/function_outputs_delivery.py
@@ -78,23 +78,7 @@ def __init__(self, func_path: str, max_keep_data: int):
         super().__init__(max_keep_data)
 
         self._check_valid_path(func_path)
-        module_path = self._get_module_path(func_path)
-        try:
-            module = import_modules_from_strings(module_path)
-        except ImportError:
-            raise ImportError(f'{module_path} is not imported correctly.')
-        self.module = module
-
-        func_name = self._get_func_name(func_path)
-        assert hasattr(module, func_name), \
-            f'{func_name} is not in {module_path}.'
-        self.func_name = func_name
-
-        origin_func = getattr(module, func_name)
-        if not isinstance(origin_func, FunctionType):
-            raise TypeError(f'{func_name} should be a FunctionType '
-                            f'instance, but got {type(origin_func)}')
-        self.origin_func = origin_func
+        self.func_path = func_path
 
     @staticmethod
     def _check_valid_path(func_path: str) -> None:
@@ -121,6 +105,24 @@ def __enter__(self) -> None:
 
         Wrap the origin function.
         """
+        module_path = self._get_module_path(self.func_path)
+        try:
+            module = import_modules_from_strings(module_path)
+        except ImportError:
+            raise ImportError(f'{module_path} is not imported correctly.')
+        self.module = module
+
+        func_name = self._get_func_name(self.func_path)
+        assert hasattr(module, func_name), \
+            f'{func_name} is not in {module_path}.'
+        self.func_name = func_name
+
+        origin_func = getattr(module, func_name)
+        if not isinstance(origin_func, FunctionType):
+            raise TypeError(f'{func_name} should be a FunctionType '
+                            f'instance, but got {type(origin_func)}')
+        self.origin_func = origin_func
+
         wrapped_func = self.deliver_wrapper(self.origin_func)
         setattr(self.module, self.func_name, wrapped_func)
 
@@ -131,6 +133,11 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
         """
         setattr(self.module, self.func_name, self.origin_func)
 
+        # self.module and self.origin_func can not be pickled.
+        # Delete these two attributes to avoid errors when ema model is used.
+        del self.module
+        del self.origin_func
+
     def deliver_wrapper(self, origin_func: Callable) -> Callable:
         """Wrap the specific function to make the intermediate results of the
         model can be delivered."""
@@ -139,12 +146,13 @@ def deliver_wrapper(self, origin_func: Callable) -> Callable:
 
         def wrap_func(*args, **kwargs):
             if self.override_data:
-                assert not self.data_queue.empty(), 'pop from an empty queue'
-                outputs = self.data_queue.get()
+                assert len(self.data_queue) > 0, 'pop from an empty queue'
+                outputs = self.data_queue.popleft()
             else:
-                assert not self.data_queue.full(), 'push into an full queue'
+                assert len(self.data_queue) < self.data_queue.maxlen,\
+                    'push into a full queue'
                 outputs = origin_func(*args, **kwargs)
-                self.data_queue.put(outputs)
+                self.data_queue.append(outputs)
             return outputs
 
         return wrap_func
diff --git a/mmrazor/models/task_modules/delivery/method_outputs_delivery.py b/mmrazor/models/task_modules/delivery/method_outputs_delivery.py
index dcaae2fd8..fa9f6c4a4 100644
--- a/mmrazor/models/task_modules/delivery/method_outputs_delivery.py
+++ b/mmrazor/models/task_modules/delivery/method_outputs_delivery.py
@@ -143,12 +143,13 @@ def deliver_wrapper(self, origin_method: Callable) -> Callable:
 
         def wrap_method(*args, **kwargs):
             if self.override_data:
-                assert not self.data_queue.empty(), 'pop from an empty queue'
-                outputs = self.data_queue.get()
+                assert len(self.data_queue) > 0, 'pop from an empty queue'
+                outputs = self.data_queue.popleft()
             else:
-                assert not self.data_queue.full(), 'push into an full queue'
+                assert len(self.data_queue) < self.data_queue.maxlen,\
+                    'push into a full queue'
                 outputs = origin_method(*args, **kwargs)
-                self.data_queue.put(outputs)
+                self.data_queue.append(outputs)
             return outputs
 
         return wrap_method
diff --git a/mmrazor/models/task_modules/recorder/__init__.py b/mmrazor/models/task_modules/recorder/__init__.py
index 6d1858f0b..8af399126 100644
--- a/mmrazor/models/task_modules/recorder/__init__.py
+++ b/mmrazor/models/task_modules/recorder/__init__.py
@@ -1,5 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .function_inputs_recorder import FunctionInputsRecorder
 from .function_outputs_recorder import FunctionOutputsRecorder
+from .method_inputs_recorder import MethodInputsRecorder
 from .method_outputs_recorder import MethodOutputsRecorder
 from .module_inputs_recorder import ModuleInputsRecorder
 from .module_outputs_recorder import ModuleOutputsRecorder
@@ -9,5 +11,5 @@ __all__ = [
 
     'FunctionOutputsRecorder', 'MethodOutputsRecorder',
     'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
-    'ModuleInputsRecorder'
+    'ModuleInputsRecorder', 'MethodInputsRecorder', 'FunctionInputsRecorder'
 ]
diff --git a/mmrazor/models/task_modules/recorder/function_inputs_recorder.py b/mmrazor/models/task_modules/recorder/function_inputs_recorder.py
new file mode 100644
index 000000000..e7bbdd896
--- /dev/null
+++ b/mmrazor/models/task_modules/recorder/function_inputs_recorder.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+from inspect import signature
+from typing import Callable, List
+
+from mmrazor.registry import TASK_UTILS
+from .function_outputs_recorder import FunctionOutputsRecorder
+
+
+@TASK_UTILS.register_module()
+class FunctionInputsRecorder(FunctionOutputsRecorder):
+    """Recorder for intermediate results which are ``FunctionType``'s inputs.
+
+    Notes:
+        The form of `source` needs special attention.
+        For example, `anchor_inside_flags` is a function in mmdetection to
+        check whether the anchors are inside the border. This function is in
+        `mmdet/core/anchor/utils.py` and used in
+        `mmdet/models/dense_heads/anchor_head.py`. Then the source should be
+        `mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not
+        `mmdet.core.anchor.utils.anchor_inside_flags`.
+
+
+    Examples:
+        >>> # Below code in toy_module.py
+        >>> import random
+        >>> def toy_func(a, b):
+        ...     return a, b
+        >>> def execute_toy_func(a, b):
+        ...     toy_func(a, b)
+
+        >>> # Below code in main.py
+        >>> # Now, we want to get teacher's inputs by recorder.
+
+        >>> from toy_module import execute_toy_func
+        >>> r1 = FunctionInputsRecorder('toy_module.toy_func')
+        >>> r1.initialize()
+        >>> with r1:
+        ...     execute_toy_func(1, 2)
+        ...     execute_toy_func(1, b=2)
+        ...     execute_toy_func(b=2, a=1)
+
+        >>> r1.data_buffer
+        [[1, 2], [1, 2], [1, 2]]
+    """
+
+    def func_record_wrapper(self, origin_func: Callable,
+                            data_buffer: List) -> Callable:
+        """Save the function's inputs.
+
+        Args:
+            origin_func (FunctionType): The method whose inputs need to be
+                recorded.
+            data_buffer (list): A list of data.
+        """
+
+        func_input_params = signature(origin_func).parameters.keys()
+
+        @functools.wraps(origin_func)
+        def wrap_func(*args, **kwargs):
+            outputs = origin_func(*args, **kwargs)
+            inputs = list(args)
+            for keyword in func_input_params:
+                if keyword in kwargs:
+                    inputs.append(kwargs[keyword])
+            # Assume a func executes N times; there will be N inputs that
+            # need to be saved.
+            data_buffer.append(inputs)
+            return outputs
+
+        return wrap_func
diff --git a/mmrazor/models/task_modules/recorder/function_outputs_recorder.py b/mmrazor/models/task_modules/recorder/function_outputs_recorder.py
index c6ab5228f..706c1a8f7 100644
--- a/mmrazor/models/task_modules/recorder/function_outputs_recorder.py
+++ b/mmrazor/models/task_modules/recorder/function_outputs_recorder.py
@@ -65,28 +65,8 @@ class FunctionOutputsRecorder(BaseRecorder):
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-
         self._check_valid_source(self.source)
 
-        # import the function corrosponding module
-        try:
-            mod = import_modules_from_strings(self.module_string)
-        except ImportError:
-            raise ImportError(
-                f'{self.module_string} is not imported correctly.')
-
-        self.imported_module: ModuleType = mod
-
-        assert hasattr(mod, self.func_name), \
-            f'{self.func_name} is not in {self.module_string}.'
-
-        origin_func = getattr(mod, self.func_name)
-        if not isinstance(origin_func, FunctionType):
-            raise TypeError(f'{self.func_name} should be a FunctionType '
-                            f'instance, but got {type(origin_func)}')
-
-        self.origin_func: Callable = origin_func
-
     @staticmethod
     def _check_valid_source(source):
         """Check if the source's format is valid."""
@@ -118,8 +98,7 @@ def func_record_wrapper(self, origin_func: Callable,
         Args:
             origin_func (FunctionType): The method whose outputs need to be
                 recorded.
-            buffer_key (str): The key of the function's outputs saved in
-                ``data_buffer``.
+            data_buffer (list): A list of data.
         """
 
         @functools.wraps(origin_func)
         def wrap_func(*args, **kwargs):
@@ -136,8 +115,25 @@ def __enter__(self):
         """Enter the context manager."""
         super().__enter__()
 
-        mod = self.imported_module
-        origin_func = self.origin_func
+        # import the function corresponding module
+        try:
+            mod = import_modules_from_strings(self.module_string)
+        except ImportError:
+            raise ImportError(
+                f'{self.module_string} is not imported correctly.')
+
+        self.imported_module: ModuleType = mod
+
+        assert hasattr(mod, self.func_name), \
+            f'{self.func_name} is not in {self.module_string}.'
+
+        origin_func = getattr(mod, self.func_name)
+        if not isinstance(origin_func, FunctionType):
+            raise TypeError(f'{self.func_name} should be a FunctionType '
+                            f'instance, but got {type(origin_func)}')
+
+        self.origin_func: Callable = origin_func
+
         # add record wrapper to origin function.
         record_func = self.func_record_wrapper(origin_func, self.data_buffer)
 
@@ -159,3 +155,8 @@ def __exit__(self, exc_type, exc_value, traceback):
 
         # restore the origin function
         setattr(mod, self.func_name, origin_func)
+
+        # self.imported_module and self.origin_func can not be pickled.
+        # Delete these two attributes to avoid errors when ema model is used.
+        del self.imported_module
+        del self.origin_func
diff --git a/mmrazor/models/task_modules/recorder/method_inputs_recorder.py b/mmrazor/models/task_modules/recorder/method_inputs_recorder.py
new file mode 100644
index 000000000..44cb41843
--- /dev/null
+++ b/mmrazor/models/task_modules/recorder/method_inputs_recorder.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+from inspect import signature
+from typing import Callable, List
+
+from mmrazor.registry import TASK_UTILS
+from .method_outputs_recorder import MethodOutputsRecorder
+
+
+@TASK_UTILS.register_module()
+class MethodInputsRecorder(MethodOutputsRecorder):
+    """Recorder for intermediate results which are ``MethodType``'s inputs.
+
+    Note:
+        Different from ``FunctionType``, ``MethodType`` is the type of methods
+        of class instances.
+
+    Examples:
+        >>> # Below code in toy_module.py
+        >>> import random
+        >>> class Toy():
+        ...     def toy_func(self, x, y=0):
+        ...         return x + y
+
+        >>> # Below code in main.py
+        >>> # Now, we want to get teacher's inputs by recorder.
+
+        >>> from toy_module import Toy
+        >>> toy = Toy()
+        >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
+        >>> r1.initialize()
+        >>> with r1:
+        ...     _ = toy.toy_func(1, 2)
+
+        >>> r1.data_buffer
+        [[1, 2]]
+        >>> r1.get_record_data(record_idx=0, data_idx=0)
+        1
+        >>> r1.get_record_data(record_idx=0, data_idx=1)
+        2
+
+        >>> from toy_module import Toy
+        >>> toy = Toy()
+        >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
+        >>> r1.initialize()
+        >>> with r1:
+        ...     _ = toy.toy_func(1, 2)
+        ...     _ = toy.toy_func(y=2, x=1)
+
+        >>> r1.data_buffer
+        [[1, 2], [1, 2]]
+        >>> r1.get_record_data(record_idx=1, data_idx=0)
+        1
+        >>> r1.get_record_data(record_idx=1, data_idx=1)
+        2
+    """
+
+    def method_record_wrapper(self, orgin_method: Callable,
+                              data_buffer: List) -> Callable:
+        """Save the method's inputs.
+
+        Args:
+            origin_method (MethodType): The method whose inputs need to be
+                recorded.
+            data_buffer (list): A list of data.
+        """
+
+        method_input_params = signature(orgin_method).parameters.keys()
+
+        @functools.wraps(orgin_method)
+        def wrap_method(*args, **kwargs):
+            outputs = orgin_method(*args, **kwargs)
+            # the first element of args is the instance itself
+            inputs = list(args[1:])
+            for keyword in method_input_params:
+                if keyword in kwargs:
+                    inputs.append(kwargs[keyword])
+            # Assume a func executes N times; there will be N inputs that
+            # need to be saved.
+            data_buffer.append(inputs)
+            return outputs
+
+        return wrap_method
diff --git a/mmrazor/models/task_modules/recorder/method_outputs_recorder.py b/mmrazor/models/task_modules/recorder/method_outputs_recorder.py
index 266750726..6d3fb6593 100644
--- a/mmrazor/models/task_modules/recorder/method_outputs_recorder.py
+++ b/mmrazor/models/task_modules/recorder/method_outputs_recorder.py
@@ -130,8 +130,7 @@ def method_record_wrapper(self, orgin_method: Callable,
         Args:
             origin_method (MethodType): The method whose outputs need to be
                 recorded.
-            buffer_key (str): The key of the method's outputs saved in
-                ``data_buffer``.
+            data_buffer (list): A list of data.
         """
 
         @functools.wraps(orgin_method)
diff --git a/tests/test_core/test_delivers/test_function_outputs_deliver.py b/tests/test_core/test_delivers/test_function_outputs_deliver.py
index 8115af411..531e59795 100644
--- a/tests/test_core/test_delivers/test_function_outputs_deliver.py
+++ b/tests/test_core/test_delivers/test_function_outputs_deliver.py
@@ -1,11 +1,75 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import os.path as osp
+import tempfile
 from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.nn as nn
+from mmengine.evaluator import Evaluator
+from mmengine.hooks import EMAHook
+from mmengine.logging import MMLogger
+from mmengine.model import BaseModel, ExponentialMovingAverage
+from mmengine.optim import OptimWrapper
+from mmengine.runner import Runner
+from torch.utils.data import Dataset
 
 from mmrazor.models.task_modules import FunctionOutputsDelivery
 
 
+class ToyModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(2, 1)
+        # test FunctionOutputsDelivery when ema_hook is used
+        self.deliver = FunctionOutputsDelivery(
+            max_keep_data=2, func_path='toy_module.toy_func')
+
+    def forward(self, inputs, data_sample, mode='tensor'):
+        labels = torch.stack(data_sample)
+        inputs = torch.stack(inputs)
+        with self.deliver:
+            outputs = self.linear(inputs)
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss)
+            return outputs
+        else:
+            return outputs
+
+
+class DummyDataset(Dataset):
+    METAINFO = dict()  # type: ignore
+    data = torch.randn(12, 2)
+    label = torch.ones(12)
+
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
+    def __len__(self):
+        return self.data.size(0)
+
+    def __getitem__(self, index):
+        return dict(inputs=self.data[index], data_sample=self.label[index])
+
+
 class TestFuncOutputsDeliver(TestCase):
 
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        # `FileHandler` should be closed in Windows, otherwise we cannot
+        # delete the temporary directory
+        logging.shutdown()
+        MMLogger._instance_dict.clear()
+        self.temp_dir.cleanup()
+
     def test_init(self):
 
         with self.assertRaisesRegex(TypeError, 'func_path should be'):
@@ -14,19 +78,25 @@ def test_init(self):
         with self.assertRaisesRegex(AssertionError, 'func_path must have at '):
             _ = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_func')
 
+    def test_context_manager(self):
+        import toy_module
+
+        delivery = FunctionOutputsDelivery(max_keep_data=2, func_path='aaa.bb')
         with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
-            _ = FunctionOutputsDelivery(max_keep_data=1, func_path='aaa.bb')
+            with delivery:
+                _ = toy_module.toy_func()
 
+        delivery = FunctionOutputsDelivery(
+            max_keep_data=1, func_path='toy_module.bb')
         with self.assertRaisesRegex(AssertionError, 'bb is not in toy_mod'):
-            _ = FunctionOutputsDelivery(
-                max_keep_data=1, func_path='toy_module.bb')
+            with delivery:
+                _ = toy_module.toy_func()
 
+        delivery = FunctionOutputsDelivery(
+            max_keep_data=1, func_path='toy_module.TOY_VAR')
         with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
-            _ = FunctionOutputsDelivery(
-                max_keep_data=1, func_path='toy_module.TOY_VAR')
-
-    def test_context_manager(self):
-        import toy_module
+            with delivery:
+                _ = toy_module.toy_func()
 
         delivery = FunctionOutputsDelivery(
             max_keep_data=2, func_path='toy_module.toy_func')
@@ -52,3 +122,42 @@ def test_context_manager(self):
         with self.assertRaisesRegex(AssertionError, 'pop from an empty queue'):
             with delivery:
                 _ = toy_module.toy_func()
+
+    def test_ema_hook(self):
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        model = ToyModel().to(device)
+        evaluator = Evaluator([])
+        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
+        runner = Runner(
+            model=model,
+            train_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=evaluator,
+            work_dir=self.temp_dir.name,
+            default_scope='mmrazor',
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(ToyModel().parameters())),
+            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', )],
+            experiment_name='test_func_outputs_deliver')
+        runner.train()
+        for hook in runner.hooks:
+            if isinstance(hook, EMAHook):
+                self.assertTrue(
+                    isinstance(hook.ema_model, ExponentialMovingAverage))
+
+        self.assertTrue(
+            osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
+        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+        self.assertTrue('ema_state_dict' in checkpoint)
+        self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
diff --git a/tests/test_core/test_recorders/test_func_inputs_recorder.py b/tests/test_core/test_recorders/test_func_inputs_recorder.py
new file mode 100644
index 000000000..6fa9655a1
--- /dev/null
+++ b/tests/test_core/test_recorders/test_func_inputs_recorder.py
@@ -0,0 +1,138 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import os.path as osp
+import tempfile
+from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.nn as nn
+from mmengine.evaluator import Evaluator
+from mmengine.hooks import EMAHook
+from mmengine.logging import MMLogger
+from mmengine.model import BaseModel, ExponentialMovingAverage
+from mmengine.optim import OptimWrapper
+from mmengine.runner import Runner
+from torch.utils.data import Dataset
+
+from mmrazor.models.task_modules import FunctionInputsRecorder, RecorderManager
+
+
+class ToyModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(2, 1)
+        # test FunctionInputsRecorder when ema_hook is used
+        recorders_cfg = dict(
+            out=dict(type='FunctionInputs', source='toy_mod.toy_func'))
+        self.recorders = RecorderManager(recorders_cfg)
+        self.recorders.initialize(self)
+
+    def forward(self, inputs, data_sample, mode='tensor'):
+        labels = torch.stack(data_sample)
+        inputs = torch.stack(inputs)
+        with self.recorders:
+            outputs = self.linear(inputs)
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss)
+            return outputs
+        else:
+            return outputs
+
+
+class DummyDataset(Dataset):
+    METAINFO = dict()  # type: ignore
+    data = torch.randn(12, 2)
+    label = torch.ones(12)
+
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
+    def __len__(self):
+        return self.data.size(0)
+
+    def __getitem__(self, index):
+        return dict(inputs=self.data[index], data_sample=self.label[index])
+
+
+class TestFuncInputsRecorder(TestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        # `FileHandler` should be closed in Windows, otherwise we cannot
+        # delete the temporary directory
+        logging.shutdown()
+        MMLogger._instance_dict.clear()
+        self.temp_dir.cleanup()
+
+    def test_context_manager(self):
+        from toy_mod import execute_toy_func2 as execute_toy_func
+
+        recorder = FunctionInputsRecorder('toy_mod.toy_func2')
+        recorder.initialize()
+
+        with recorder:
+            execute_toy_func(1, 2)
+            execute_toy_func(1, b=2)
+            execute_toy_func(b=2, a=1)
+
+        self.assertTrue(
+            recorder.get_record_data(record_idx=0, data_idx=0) == 1)
+        self.assertTrue(
+            recorder.get_record_data(record_idx=0, data_idx=1) == 2)
+
+        self.assertTrue(
+            recorder.get_record_data(record_idx=1, data_idx=0) == 1)
+        self.assertTrue(
+            recorder.get_record_data(record_idx=1, data_idx=1) == 2)
+
+        self.assertTrue(
+            recorder.get_record_data(record_idx=2, data_idx=0) == 1)
+        self.assertTrue(
+            recorder.get_record_data(record_idx=2, data_idx=1) == 2)
+
+    def test_ema_hook(self):
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        model = ToyModel().to(device)
+        evaluator = Evaluator([])
+        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
+        runner = Runner(
+            model=model,
+            train_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=evaluator,
+            work_dir=self.temp_dir.name,
+            default_scope='mmrazor',
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(ToyModel().parameters())),
+            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', )],
+            experiment_name='test_func_inputs_recorder')
+        runner.train()
+        for hook in runner.hooks:
+            if isinstance(hook, EMAHook):
+                self.assertTrue(
+                    isinstance(hook.ema_model, ExponentialMovingAverage))
+
+        self.assertTrue(
+            osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
+        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+        self.assertTrue('ema_state_dict' in checkpoint)
+        self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
diff --git a/tests/test_core/test_recorders/test_func_outputs_recorder.py b/tests/test_core/test_recorders/test_func_outputs_recorder.py
index a6c3be2ba..1d6561495 100644
--- a/tests/test_core/test_recorders/test_func_outputs_recorder.py
+++ b/tests/test_core/test_recorders/test_func_outputs_recorder.py
@@ -16,17 +16,26 @@ def test_init(self):
         with self.assertRaisesRegex(AssertionError, 'source must have at '):
             _ = FunctionOutputsRecorder('aaaaa')
 
+    def test_context_manager(self):
+        from toy_mod import execute_toy_func
+
+        recorder = FunctionOutputsRecorder('aaa.bbb')
+        recorder.initialize()
         with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
-            _ = FunctionOutputsRecorder('aaa.bbb')
+            with recorder:
+                execute_toy_func(1)
 
+        recorder = FunctionOutputsRecorder('toy_mod.aaa')
+        recorder.initialize()
         with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
-            _ = FunctionOutputsRecorder('toy_mod.aaa')
+            with recorder:
+                execute_toy_func(1)
 
+        recorder = FunctionOutputsRecorder('toy_mod.TOY_VAR')
+        recorder.initialize()
         with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
-            _ = FunctionOutputsRecorder('toy_mod.TOY_VAR')
-
-    def test_context_manager(self):
-        from toy_mod import execute_toy_func
+            with recorder:
+                execute_toy_func(1)
 
         recorder = FunctionOutputsRecorder('toy_mod.toy_func')
         recorder.initialize()
diff --git a/tests/test_core/test_recorders/test_method_inputs_recorder.py b/tests/test_core/test_recorders/test_method_inputs_recorder.py
new file mode 100644
index 000000000..7450a231c
--- /dev/null
+++ b/tests/test_core/test_recorders/test_method_inputs_recorder.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+from mmrazor.models.task_modules import MethodInputsRecorder
+
+
+class TestFuncOutputsRecorder(TestCase):
+
+    def test_context_manager(self):
+        from toy_mod import ToyClass
+
+        toy = ToyClass()
+
+        recorder = MethodInputsRecorder('toy_mod.ToyClass.func')
+        recorder.initialize()
+
+        with recorder:
+            _ = toy.func(x=1, y=2)
+            _ = toy.func(1, y=2)
+            _ = toy.func(y=2, x=1)
+
+        self.assertTrue(
+            recorder.get_record_data(record_idx=0, data_idx=0) == 1)
+        self.assertTrue(
+            recorder.get_record_data(record_idx=0, data_idx=1) == 2)
+
+        self.assertTrue(
+            recorder.get_record_data(record_idx=1, data_idx=0) == 1)
+        self.assertTrue(
+            recorder.get_record_data(record_idx=1, data_idx=1) == 2)
+
+        self.assertTrue(
+            recorder.get_record_data(record_idx=2, data_idx=0) == 1)
+        self.assertTrue(
+            recorder.get_record_data(record_idx=2, data_idx=1) == 2)
diff --git a/tests/test_core/test_recorders/toy_mod.py b/tests/test_core/test_recorders/toy_mod.py
index 3cc331476..0df3e2d70 100644
--- a/tests/test_core/test_recorders/toy_mod.py
+++ b/tests/test_core/test_recorders/toy_mod.py
@@ -8,6 +8,10 @@ def toy_func(a):
     return a
 
 
+def toy_func2(a, b):
+    return a, b
+
+
 def toy_list_func(a):
     return [a, a, a]
 
@@ -16,6 +20,10 @@ def execute_toy_func(a):
     toy_func(a)
 
 
+def execute_toy_func2(a, b):
+    toy_func2(a, b)
+
+
 def execute_toy_list_func(a):
     toy_list_func(a)
 
@@ -31,6 +39,9 @@ def toy(self):
         self._count += 1
         return self._count
 
+    def func(self, x, y=0):
+        return x + y
+
     def __call__(self):
         self._count += 1
         return self._count
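
Taken together, the delivery and recorder changes above follow one pattern: buffer intermediate results in a collections.deque with maxlen instead of a queue.Queue, resolve the wrapped function only inside __enter__, and drop the module/function references in __exit__ so the owning model stays picklable and deep-copyable when EMAHook is used. The following is a minimal, self-contained sketch of that pattern, not the actual FunctionOutputsDelivery implementation; the names ToyDelivery, func_path and override_data are illustrative stand-ins.

# A minimal sketch of the defer-import / drop-references pattern from the
# patch above; names are illustrative, not the real mmrazor API.
import importlib
from collections import deque
from types import FunctionType


class ToyDelivery:

    def __init__(self, func_path: str, max_keep_data: int = 1) -> None:
        # Only the dotted path is stored here; the actual import happens in
        # __enter__, so the instance itself holds nothing unpicklable.
        self.func_path = func_path
        # deque(maxlen=N) takes over from Queue(maxsize=N).
        self.data_queue: deque = deque([], maxlen=max_keep_data)
        self.override_data = False

    def __enter__(self) -> None:
        module_path, func_name = self.func_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        origin_func = getattr(module, func_name)
        assert isinstance(origin_func, FunctionType)
        # Keep the references only while the context is active.
        self.module, self.func_name, self.origin_func = \
            module, func_name, origin_func

        def wrapped(*args, **kwargs):
            if self.override_data:
                assert len(self.data_queue) > 0, 'pop from an empty queue'
                return self.data_queue.popleft()
            # A full deque would silently drop its oldest item on append,
            # so overflow is turned into an explicit failure instead.
            assert len(self.data_queue) < self.data_queue.maxlen, \
                'push into a full queue'
            outputs = origin_func(*args, **kwargs)
            self.data_queue.append(outputs)
            return outputs

        setattr(module, func_name, wrapped)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        setattr(self.module, self.func_name, self.origin_func)
        # The module and cached function references can break pickling or
        # deep-copying of whatever owns this delivery (e.g. when EMAHook
        # clones the model), so drop them on exit, as the patch does.
        del self.module
        del self.origin_func

Because a deque with maxlen discards its oldest entry silently once full, keeping the explicit length check (as the patch does) turns buffer overflow into a loud assertion failure rather than silent data loss.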