torch.fx.proxy.TraceError: class MMArchitectureQuant #621

Open
Levi-zhan opened this issue Jan 5, 2024 · 11 comments
Labels
bug Something isn't working

Comments

@Levi-zhan

Describe the bug

torch.fx.proxy.TraceError: class MMArchitectureQuant in mmrazor/models/algorithms/quantization/mm_architecture.py: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

I am currently trying to quantize a segmentation model with the configuration file below, and I got the error above. Could you help me figure out how to solve it? Thank you.

The base configuration file is a segmentation model I modified from DDRNet, with only 3 classes; all other settings are unchanged.

_base_ = [
    'mmseg::ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k-1024x1024_label3.py',
    '../../deploy_cfgs/mmseg/set_tensorrt-int8-explicit-1024x1024_label3.py'
]

_base_.val_dataloader.batch_size = 32

test_cfg = dict(
    type='mmrazor.PTQLoop',
    calibrate_dataloader=_base_.val_dataloader,
    calibrate_steps=32,
)

float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth'  # noqa: E501

global_qconfig = dict(
    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
    w_qscheme=dict(
        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)

crop_size = (1024, 1024)
model = dict(
    _delete_=True,
    type='mmrazor.MMArchitectureQuant',
    data_preprocessor=dict(
        type='mmseg.SegDataPreProcessor',
        size=crop_size,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_val=0,
        seg_pad_val=255),
    architecture=_base_.model,
    deploy_cfg=_base_.deploy_cfg,
    float_checkpoint=float_checkpoint,
    quantizer=dict(
        type='mmrazor.TensorRTQuantizer',
        global_qconfig=global_qconfig,
        tracer=dict(
            type='mmrazor.CustomTracer',
            skipped_methods=[
                'mmseg.models.decode_heads.ddr_head.DDRHead.loss_by_feat',
            ])))

model_wrapper_cfg = dict(
    type='mmrazor.MMArchitectureQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

custom_hooks = []

May I ask where my configuration file is written incorrectly? Thank you!

Levi-zhan added the bug label on Jan 5, 2024
@elisa-aleman

This might be tangentially related to what I encountered in the mmpose TopdownEstimator in issue #3012

You might need to refactor the model so that there are no self-referencing methods within it, and instead point to wrapped outer methods.

I haven't checked whether that's the case for mmseg, but it might point you in the right direction.
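
Roughly, the pattern to aim for is something like this minimal, self-contained sketch (the names are illustrative, not mmseg or mmpose code): the untraceable dict handling moves into a module-level function marked with @torch.fx.wrap, so the tracer records it as a single leaf call.

import torch
import torch.fx


# Illustrative names only: the dict handling lives in a module-level function
# marked as a leaf, so the tracer never iterates over a Proxy.
@torch.fx.wrap
def merge_losses(loss_decode: dict, loss_aux: dict) -> dict:
    losses = dict()
    losses.update(loss_decode)
    losses.update(loss_aux)
    return losses


class ToyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        decode = {'loss_ce': self.linear(x).sum()}
        aux = {'aux.loss_ce': self.linear(x).mean()}
        # Recorded as a single call_function node instead of being traced into.
        return merge_losses(decode, aux)


traced = torch.fx.symbolic_trace(ToyModel())
print(traced.graph)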

@Veccoy

Veccoy commented May 23, 2024

Hi, I have the same problem with the class EncoderDecoder from the segmentors of MMSegmentation (line 208). Did you manage to refactor your model, and if so, how?

@elisa-aleman

Hi, I have the same problem with the class EncoderDecoder from the segmentors of MMSegmentation (line 208). Did you manage to refactor your model, and if so, how?

Yes, I haven't posted an issue yet, but you should mimic the structure in mmpretrain.models.heads.cls_head.ClsHead, where there are additional _get_loss and _get_predict methods that handle all the untraceable code, so that only the code where forward is called on the input gets traced.

@Veccoy

Veccoy commented May 28, 2024

Thank you. I have changed the following argument of the MMRazor CustomTracer to fit the EncoderDecoder class:

skipped_methods=[
                'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
                'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat']

Both the auxiliary head (FCNHead) and the decode head (PSPHead) use the same predict and loss functions.

Moreover, I have taken the whole code of the EncoderDecoder predict method out of the class (except for the self.inference() call) by creating functions with a @torch.fx.wrap decorator.

def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`], optional): The seg data
                samples. It usually includes information such as `metainfo`
                and `gt_sem_seg`.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contain:

            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
            - ``seg_logits``(PixelData): Predicted logits of semantic
                segmentation before normalization.
        """
        batch_img_metas = _prepare_batch(inputs, data_samples)

        seg_logits = self.inference(inputs, batch_img_metas)

        return postprocess_result(self.decode_head, seg_logits, data_samples)
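
For reference, a sketch of what the wrapped _prepare_batch helper can look like (hypothetical placement; the body mirrors the meta-info handling in mmseg's stock EncoderDecoder.predict):

import torch.fx
from torch import Tensor


# Hypothetical module-level helper: moved out of the class and wrapped so the
# tracer treats the data_samples loop as a single leaf call.
@torch.fx.wrap
def _prepare_batch(inputs: Tensor, data_samples=None) -> list:
    if data_samples is not None:
        return [data_sample.metainfo for data_sample in data_samples]
    # No data samples: derive placeholder metas from the input tensor shape.
    return [
        dict(
            ori_shape=inputs.shape[2:],
            img_shape=inputs.shape[2:],
            pad_shape=inputs.shape[2:],
            padding_size=[0, 0, 0, 0])
    ] * inputs.shape[0]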

The problem now is that when the EncoderDecoder loss function is called, it calls the EncoderDecoder _decode_head_forward_train and _auxiliary_head_forward_train functions, which try to update a dictionary of losses. I can't make the same changes you made in the mmpose TopdownEstimator for the loss function, as those two functions update the dictionary.

Do I have to pass the EncoderDecoder loss function entirely to skipped_methods, or is this a bigger issue?

Here is the full log of the issue:

/opt/conda/lib/python3.10/site-packages/mmseg/models/backbones/resnet.py:431: UserWarning: DeprecationWarning: pretrained is a deprecated, please use "init_cfg" instead
  warnings.warn('DeprecationWarning: pretrained is a deprecated, '
/opt/conda/lib/python3.10/site-packages/mmseg/models/builder.py:36: UserWarning: ``build_loss`` would be deprecated soon, please use ``mmseg.registry.MODELS.build()`` 
  warnings.warn('``build_loss`` would be deprecated soon, please use '
/opt/conda/lib/python3.10/site-packages/mmseg/models/losses/cross_entropy_loss.py:235: UserWarning: Default ``avg_non_ignore`` is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set ``avg_non_ignore=True``.
  warnings.warn(
Loads checkpoint by local backend from path: /workspace/mmlab/MMR/qat/seg/pspnet_r18-d8_512x1024_80k_cityscapes_20201225_021458-09ffa746.pth
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/workspace/mmlab/mmrazor/tools/train.py", line 121, in <module>
    main()
  File "/workspace/mmlab/mmrazor/tools/train.py", line 114, in main
    runner = Runner.from_cfg(cfg)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 462, in from_cfg
    runner = cls(
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 429, in __init__
    self.model = self.build_model(model)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 836, in build_model
    model = MODELS.build(model)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
    return self.build_func(cfg, *args, **kwargs, registry=self)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 232, in build_model_from_cfg
    return build_from_cfg(cfg, registry, default_args)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 121, in build_from_cfg
    obj = obj_cls(**args)  # type: ignore
  File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 90, in __init__
    self.qmodels = self._build_qmodels(self.architecture)
  File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 300, in _build_qmodels
    observed_module = self.quantizer.prepare(model, concrete_args)
  File "/workspace/mmlab/mmrazor/mmrazor/models/quantizers/native_quantizer.py", line 231, in prepare
    traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
  File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 422, in trace
    'output', (self.create_arg(fn(*args)), ), {},
  File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/base.py", line 94, in forward
    return self.loss(inputs, data_samples)
  File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 179, in loss
    loss_decode = self._decode_head_forward_train(x, data_samples)
  File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 143, in _decode_head_forward_train
    losses.update(add_prefix(loss_decode, 'decode'))
  File "/opt/conda/lib/python3.10/site-packages/mmseg/utils/misc.py", line 24, in add_prefix
    for name, value in inputs.items():
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 274, in __iter__
    return self.tracer.iter(self)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 183, in iter
    raise TraceError('Proxy object cannot be iterated. This can be '
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

@elisa-aleman

elisa-aleman commented May 29, 2024

@Veccoy

Passing the entire loss function to skipped_methods would prevent the fake-quantize observers from being calibrated, but anything inside the loss function that does not call the head's forward can be refactored into another method, which you can then skip. Basically, you want the tracer to trace all nodes that are common between forward, predict, and loss, but not necessarily anything else.

In this case something like this should work:

    def _get_loss(self, x: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            x (Tensor): forward call result.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        losses = dict()

        loss_decode = self._decode_head_forward_train(x, data_samples)
        losses.update(loss_decode)

        if self.with_auxiliary_head:
            loss_aux = self._auxiliary_head_forward_train(x, data_samples)
            losses.update(loss_aux)

        return losses

    

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Input images.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(inputs)

        losses = self._get_loss(x, data_samples)

        return losses

with a config that skips _get_loss
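
For concreteness, that tracer config might look roughly like this (the EncoderDecoder._get_loss path is an assumption about where the refactored method ends up; the head paths are the ones you already listed):

quantizer = dict(
    type='mmrazor.TensorRTQuantizer',
    global_qconfig=global_qconfig,
    tracer=dict(
        type='mmrazor.CustomTracer',
        skipped_methods=[
            'mmseg.models.segmentors.encoder_decoder.EncoderDecoder._get_loss',
            'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
            'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
        ]))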

@Veccoy

Veccoy commented May 29, 2024

Thank you for your answer. Unfortunately, this doesn't work (see traceback below). It seems to be a malfunction in the trace function when dealing with the 'loss' mode.

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/workspace/mmlab/mmrazor/tools/train.py", line 121, in <module>
    main()
  File "/workspace/mmlab/mmrazor/tools/train.py", line 114, in main
    runner = Runner.from_cfg(cfg)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 462, in from_cfg
    runner = cls(
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 429, in __init__
    self.model = self.build_model(model)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 836, in build_model
    model = MODELS.build(model)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
    return self.build_func(cfg, *args, **kwargs, registry=self)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 232, in build_model_from_cfg
    return build_from_cfg(cfg, registry, default_args)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 121, in build_from_cfg
    obj = obj_cls(**args)  # type: ignore
  File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 90, in __init__
    self.qmodels = self._build_qmodels(self.architecture)
  File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 300, in _build_qmodels
    observed_module = self.quantizer.prepare(model, concrete_args)
  File "/workspace/mmlab/mmrazor/mmrazor/models/quantizers/native_quantizer.py", line 231, in prepare
    traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
  File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 422, in trace
    'output', (self.create_arg(fn(*args)), ), {},
  File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/base.py", line 94, in forward
    return self.loss(inputs, data_samples)
  File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 205, in loss
    losses = self._get_loss(x, data_samples)
  File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 72, in wrapped_method
    return self.tracer.call_method(mod, self.name, method, args,
  File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 317, in call_method
    return self.create_proxy('call_method', name, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 66, in create_proxy
    args_ = self.create_arg(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 344, in create_arg
    return super().create_arg(a)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 140, in create_arg
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 140, in <genexpr>
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 298, in create_arg
    return self.create_node("get_attr", n_, (), {})
  File "/opt/conda/lib/python3.10/site-packages/torch/ao/quantization/fx/tracer.py", line 114, in create_node
    node = super().create_node(kind, target, args, kwargs, name, type_expr)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 46, in create_node
    return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/graph.py", line 777, in create_node
    name = self._graph_namespace.create_name(candidate, None)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/graph.py", line 137, in create_name
    if candidate[0].isdigit():
IndexError: string index out of range

When the trace function of CustomTracer is called, it calls torch.fx's create_arg method for the forward method of EncoderDecoder and for several of its modules. However, one of these modules is the EncoderDecoder itself (not a submodule), which should not happen. It enters create_arg and crashes in this condition because the EncoderDecoder module has no name n_ (an empty string).

I think the problem comes from the fact that the _get_loss function is still in the EncoderDecoder class: this makes the EncoderDecoder model appear in the arguments of the create_arg method. I had the same issue and traceback when tracing the 'predict' mode, and I made some changes (see my earlier comment): I took the _prepare_batch and postprocess_result functions out of the class and put the @torch.fx.wrap decorator on top, which enables tracing of the 'predict' mode.

@elisa-aleman

elisa-aleman commented May 30, 2024

@Veccoy

The above Traceback makes me think that you didn't add EncoderDecoder._get_loss to skipped_methods. Can you tell me if that is the case?

EDIT: I see, so EncoderDecoder is not a submodule, sorry. In that case, you'll need to refactor the loss function so it doesn't use .update on dicts, since that is what makes it untraceable.

EDIT 2: Or, alternatively, factor the dict handling out of the class and decorate it with @torch.fx.wrap

EDIT 3: You might also need to refactor and skip the refactored code from the decoder head and auxiliary head losses when they also handle dictionaries.
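
A minimal sketch of what EDIT 2/3 could look like for the add_prefix crash site (prefixed_update is a hypothetical name; mmseg's real helper is mmseg.utils.add_prefix):

import torch.fx


# Hypothetical module-level replacement for losses.update(add_prefix(...)):
# because it is wrapped, the tracer records one leaf call instead of tracing
# the dict iteration that raises the Proxy error.
@torch.fx.wrap
def prefixed_update(losses: dict, new_losses: dict, prefix: str) -> dict:
    for name, value in new_losses.items():
        losses[f'{prefix}.{name}'] = value
    return losses

_decode_head_forward_train would then call prefixed_update(losses, loss_decode, 'decode') instead of losses.update(add_prefix(loss_decode, 'decode')).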

@Veccoy

Veccoy commented May 31, 2024

Thank you! Indeed, it works by refactoring the dict handling and the batch preparation in, respectively, the loss and predict methods of the EncoderDecoder class and the postprocess_result method of the BaseSegmentor class, and decorating them with @torch.fx.wrap. I also put the BaseDecodeHead.predict_by_feat, PSPHead.loss_by_feat and FCNHead.loss methods in the skipped_methods argument.

What is the difference between the @torch.fx.wrap decorator and the skipped_methods argument, given that both handle untraceable code? When should one be used instead of the other?

@elisa-aleman

Thank you! Indeed, it works by refactoring the dict handling and the batch preparation in, respectively, the loss and predict methods of the EncoderDecoder class and the postprocess_result method of the BaseSegmentor class, and decorating them with @torch.fx.wrap. I also put the BaseDecodeHead.predict_by_feat, PSPHead.loss_by_feat and FCNHead.loss methods in the skipped_methods argument.

Do make sure that FCNHead.loss doesn't have any nodes in common with EncoderDecoder.forward, or the fake quants won't calibrate correctly.

What is the difference between the @torch.fx.wrap decorator and the skipped_methods argument, given that both handle untraceable code? When should one be used instead of the other?

@torch.fx.wrap is mainly for functions, and I use it for things that either repeat across classes or that live on the root class I'm trying to trace. In contrast, skipped_methods works only on submodule methods, but if you can skip something without refactoring, that is theoretically more convenient.
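
To make the contrast concrete, a rough sketch (the paths and the postprocess_result body are placeholders):

import torch.fx


# 1) @torch.fx.wrap marks a module-level *function* as a leaf: it shows up as a
#    single call_function node and its body is never traced.
@torch.fx.wrap
def postprocess_result(seg_logits, data_samples):
    ...  # untraceable post-processing stays outside the graph


# 2) skipped_methods tells MMRazor's CustomTracer to record a submodule *method*
#    as a single call_method node, without editing the library source at all.
tracer = dict(
    type='mmrazor.CustomTracer',
    skipped_methods=[
        'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
    ])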

@Veccoy

Veccoy commented May 31, 2024

How do you check whether methods have nodes in common? FCNHead and PSPHead both inherit the same loss method from the BaseDecodeHead class, which only does the forward of the head and the computation of the loss. But these heads are submodules inside the EncoderDecoder model.

@elisa-aleman

Anything that has a forward calculation must not be skipped. One way to check is to add a printout of the traced FX graph within MMRazor's CustomTracer.
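
For example, without patching CustomTracer itself, a low-effort check is to dump the graphs MMArchitectureQuant builds per forward mode (a sketch; algorithm stands for the built MMArchitectureQuant instance, and the mode keys are assumptions based on its default forward modes):

# Compare which call_module / call_method nodes appear in each traced mode.
for mode, qmodel in algorithm.qmodels.items():  # e.g. 'tensor', 'predict', 'loss'
    print(f'--- {mode} ---')
    print(qmodel.graph)  # each qmodel is a torch.fx GraphModule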
