The problem about using RECORDER in dev1.x (RecorderManager) #433

Closed
cape-zck opened this issue Jan 12, 2023 · 2 comments

cape-zck commented Jan 12, 2023

Thanks to the MMRazor team for their contribution. I am now learning how to use the RECORDER module in MMRazor dev1.x, but I ran into a couple of problems.

1. I could not find mmrazor.core in dev1.x. However, I found modules such as MethodOutputsRecorder and RecorderManager in mmrazor.models.task_modules, so this problem seems to be solved.

2. When I tested the RecorderManager demo, I found that if I use ModuleOutputs and MethodOutputs in RecorderManager at the same time, the forward process seems to run twice, and reading the MethodOutputs record reports an out-of-range error. If the method (toy_func) is not defined in the same Python file, this error does not seem to happen.

For example:

import torch
import random
from mmengine import ConfigDict
from torch import nn
from mmrazor.models.task_modules import RecorderManager

class Toy():
    def toy_func(self):
        return random.randint(0, 1000000)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.toy = Toy()

    def forward(self, x):
        return self.conv2(self.conv1(x)) + self.toy.toy_func()

# configure multi-recorders
conv1_rec = ConfigDict(type='ModuleOutputs', source='conv1')
conv2_rec = ConfigDict(type='ModuleOutputs', source='conv2')
func_rec = ConfigDict(type='MethodOutputs', source='demo.Toy.toy_func')
# instantiate RecorderManager with a dict that contains recorders' configs,
# you can customize their keys.
manager = RecorderManager(
    {'conv1_rec': conv1_rec,
     'conv2_rec': conv2_rec,
     'func_rec': func_rec})

model = ToyModel()
# initialize is to make specified module can be recorded by
# registering customized forward hook.
manager.initialize(model)

x = torch.rand(1, 1, 1, 1)
with manager:
    out = model(x)

conv2_out = manager.get_recorder('conv2_rec').get_record_data()
'''
 - mmengine - WARNING - The "task util" registry in mmrazor did not set import location. Fallback to call `mmrazor.utils.register_all_modules` instead.
tensor([[[[-0.2194]]]], grad_fn=<ThnnConv2DBackward>)
tensor([[[[-1.8282]]]], grad_fn=<ThnnConv2DBackward>)
'''

func_rec = manager.get_recorder('func_rec').get_record_data()
'''
569682
Traceback (most recent call last):
  File "demo.py", line 50, in <module>
    func_rec = manager.get_recorder('func_rec').get_record_data()
  File "mmrazor\mmrazor\models\task_modules\recorder\base_recorder.py", line 78, in get_record_data
    assert record_idx < len(self._data_buffer), \
AssertionError: record_idx is illegal. The length of data_buffer is 0, but record_idx is 0.
'''

But when I remove the MethodOutputs recorder, everything seems normal again:

import torch
import random
from mmengine import ConfigDict
from torch import nn
from mmrazor.models.task_modules import RecorderManager

class Toy():
    def toy_func(self):
        return random.randint(0, 1000000)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.toy = Toy()

    def forward(self, x):
        return self.conv2(self.conv1(x)) + self.toy.toy_func()

# configure multi-recorders
conv1_rec = ConfigDict(type='ModuleOutputs', source='conv1')
conv2_rec = ConfigDict(type='ModuleOutputs', source='conv2')
#func_rec = ConfigDict(type='MethodOutputs', source='demo.Toy.toy_func')
# instantiate RecorderManager with a dict that contains recorders' configs,
# you can customize their keys.
manager = RecorderManager(
    {'conv1_rec': conv1_rec,
     'conv2_rec': conv2_rec})

model = ToyModel()
# initialize is to make specified module can be recorded by
# registering customized forward hook.
manager.initialize(model)

x = torch.rand(1, 1, 1, 1)
with manager:
    out = model(x)

conv2_out = manager.get_recorder('conv2_rec').get_record_data()

'''
- mmengine - WARNING - The "task util" registry in mmrazor did not set import location. Fallback to call `mmrazor.utils.register_all_modules` instead.
tensor([[[[-0.7945]]]], grad_fn=<ThnnConv2DBackward>)

'''

HIT-cwh commented Feb 3, 2023

Hi @cape-zck! Thank you for your issue.
If the source is set to demo.Toy.toy_func, the method output recorder first imports the target demo.Toy via import_modules_from_strings('demo.Toy') and then wraps its toy_func method with the method_record_wrapper method of MethodOutputsRecorder. This is how the output of a specific class method gets recorded.
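
The import-then-wrap step described above can be sketched in plain Python. This is only an illustration of the idea, not MMRazor's actual code: `record_method_outputs` is a hypothetical helper, and the standard-library method `fractions.Fraction.limit_denominator` stands in for `demo.Toy.toy_func`.

```python
import importlib
from contextlib import contextmanager
from fractions import Fraction
from functools import wraps


@contextmanager
def record_method_outputs(source):
    """Hypothetical helper: record outputs of 'module.Class.method'."""
    module_path, class_name, method_name = source.rsplit('.', 2)
    # Import the module named in the string and look up the class on it.
    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)
    original = getattr(cls, method_name)
    buffer = []

    @wraps(original)
    def wrapper(*args, **kwargs):
        out = original(*args, **kwargs)
        buffer.append(out)  # keep every return value
        return out

    # Patch the method on the class object that was just imported.
    setattr(cls, method_name, wrapper)
    try:
        yield buffer
    finally:
        # Restore the original method when recording ends.
        setattr(cls, method_name, original)


with record_method_outputs('fractions.Fraction.limit_denominator') as recorded:
    Fraction(355, 113).limit_denominator(10)
```

Only calls that go through the patched class object end up in the buffer; a call through any other class object, even one with identical source code, would leave the buffer empty.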

However, the Toy.toy_func method used during the forward pass is not the same object as demo.Toy.toy_func. As a result, demo.Toy.toy_func is never actually executed, and its output is not recorded by the MethodOutputsRecorder. This is why an out-of-range error is reported.
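
One likely reason the two methods differ: when `demo.py` is run as a script, its classes live in the `__main__` module, while `import demo` executes the file a second time and creates a separate `demo.Toy` class (which would also explain the forward pass appearing to run twice). The sketch below reproduces this with the standard library only; the module names `script_copy` and `imported_copy` and the helper `load_as` are made up for the illustration.

```python
import importlib.util
import tempfile
from pathlib import Path

SOURCE = """
class Toy:
    def toy_func(self):
        return 42
"""


def load_as(name, path):
    """Execute the file at `path` as a fresh module called `name`."""
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "demo.py"
    path.write_text(SOURCE)
    script_copy = load_as("script_copy", path)      # stands in for `python demo.py`
    imported_copy = load_as("imported_copy", path)  # stands in for `import demo`

records = []


def record_wrapper(func):
    def wrapped(*args, **kwargs):
        out = func(*args, **kwargs)
        records.append(out)
        return out
    return wrapped


# Patch only the imported copy, as a string-based recorder would.
imported_copy.Toy.toy_func = record_wrapper(imported_copy.Toy.toy_func)

same_class = script_copy.Toy is imported_copy.Toy
script_copy.Toy().toy_func()       # the model's call: goes to the unpatched class
records_after_script_call = list(records)
imported_copy.Toy().toy_func()     # a call through the patched class
records_after_import_call = list(records)
```

Here `same_class` is `False`, and the buffer stays empty until the patched copy is called. Defining `Toy` in a separate module that both the model and the recorder's source string refer to should avoid the mismatch, which fits the earlier observation in this thread that the error disappears when the method lives in another file.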

humu789 pushed a commit to humu789/mmrazor that referenced this issue Feb 13, 2023
* [refactor][API2.0]  Add onnx export and jit trace (open-mmlab#419)
pppppM closed this as completed Mar 7, 2023
@pkyzh2006

Hi @cape-zck ! Thank you for your issue. If we set the source module name to demo.Toy.toy_func, the method output recorder will first import this target module demo.Toy by import_modules_from_strings('demo.Toy') and then wrap its toy_func method by the method_record_wrapper method in MethodOutputsRecorder. In this way, we can record the output data of a specific class method.

While the Toy.toy_func method you used during the forward pass is not the same as demo.Toy.toy_func. So the method demo.Toy.toy_func is not executed actually and the output data of demo.Toy.toy_func is not recorded by the MethodOutputsRecorder. Hence it would report an out-of-range error.

Could you help me check #636?
