[Bug] (suggested fix) mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams() fails if there are modules present in other modes but not in forward mode='tensor' #634

Open
elisa-aleman opened this issue Apr 5, 2024 · 3 comments
Labels
bug Something isn't working


@elisa-aleman

elisa-aleman commented Apr 5, 2024

Describe the bug

In models where there are modules that exist only in mode 'predict' or 'loss' but not in 'tensor', the following code fails with a KeyError when it looks those modules up in the state dict of the 'tensor'-mode model. For example, this happens when one mode's model has duplicated modules but the other doesn't.

mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams()#L124-L148

        def traverse(module, prefix):
            for name, child in module._modules.items():
                if module is None:
                    continue
                child_name = f'{prefix}{name}'
                if isinstance(child, FakeQuantizeBase):
                    for name, param in child.named_parameters():
                        param_name = f'{child_name}.{name}'
                        src_param = src_state_dict[param_name]  ## Here
                        if src_param.shape == param.shape:
                            param.data.copy_(src_param)
                        else:
                            requirs_grad = param.requires_grad
                            param.requires_grad = False
                            param.resize_(src_param.shape)
                            param.requires_grad = requirs_grad
                            param.data.copy_(src_param)
                    for name, buffer in child.named_buffers():
                        buffer_name = f'{child_name}.{name}'
                        src_buffer = src_state_dict[buffer_name] # here
                        if src_buffer.shape == buffer.shape:
                            buffer.data.copy_(src_buffer)
                        else:
                            buffer.resize_(src_buffer.shape)
                            buffer.data.copy_(src_buffer)

Additional Context

I have been trying to quantize mmpose.TopdownPoseEstimator, applying the fix for the torch 2.0.0 incompatibility suggested in mmrazor #632, a fix for nn.Parameters inside TopdownPoseEstimator not being traced (mmrazor #633), and a fix for untraceable methods of mmpose.TopdownPoseEstimator (mmpose #3012).

Because the flip test (inference on a horizontally flipped copy of the input) is added to the predict forward graph, the predict-mode model contains not only duplicated modules but also extra loose (leaf) activation_post_process_<N> modules, and these make the syncing fail.

Reproduces the error - code sample

I cannot currently provide the configuration, but the executing code is this:

from mmrazor.models.algorithms.quantization.mm_architecture import MMArchitectureQuant
from mmengine import Config

cfg = Config.fromfile('qat_rtmpose-t_8xb256-420e_coco-256x192.py')

qtopdown = MMArchitectureQuant(
    data_preprocessor=cfg.data_preprocessor,
    architecture=cfg.architecture,
    quantizer=cfg.model.quantizer,
    input_shapes=cfg.model.input_shapes
)

Reproduces the problem - error message

Traceback (most recent call last):
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 91, in __init__
    self.sync_qparams('tensor')
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 156, in sync_qparams
    .....redacted
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 143, in traverse
    src_buffer = src_state_dict[buffer_name]
                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
KeyError: 'backbone.stem.0.conv_dup1.weight_fake_quant.fake_quant_enabled'

And after patching that, a second error:

Traceback (most recent call last):
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 91, in __init__
    self.sync_qparams('tensor')
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 156, in sync_qparams
    .....redacted
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 157, in traverse
    raise KeyError(f"{buffer_name} in mode '{mode}' but not found in source mode '{tensor}', sync_qparams() failed.")
KeyError: "activation_post_process_123.fake_quant_enabled in mode 'predict' but not found in source mode 'tensor', sync_qparams() failed."

Post related information - suggested fix


*EDIT: While this fix allows syncing nodes that don't exist in the source mode, it causes model deployment to fail later on.

For duplicate modules, I figure one can copy the state_dict entry with the non-suffixed name (for example, use conv for conv_dup1), but I don't have a suggestion yet for modules that don't exist in the source mode at all.
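A minimal sketch of that name mapping, assuming duplicated modules are always named with a numbered _dup suffix, as in conv_dup1 from the traceback above. strip_dup_suffix is a hypothetical helper, not part of mmrazor. Unlike the section.split('_dup')[0] approach used in the patch below, the anchored regex only removes a trailing _dup<N>, so a module whose name merely contains "_dup" (e.g. a hypothetical my_duplicate layer) is left intact:

```python
import re

# Matches a trailing "_dup<N>" on a single dotted-name component,
# e.g. "conv_dup1" -> "conv". Assumes duplicates are always numbered.
_DUP_SUFFIX = re.compile(r"_dup\d+$")


def strip_dup_suffix(qualified_name: str) -> str:
    """Map a duplicated module's state_dict key to the original key.

    Only a trailing ``_dup<N>`` on each component is removed, so names
    that merely contain "_dup" (e.g. "my_duplicate") are not truncated.
    """
    return ".".join(_DUP_SUFFIX.sub("", part) for part in qualified_name.split("."))


key = "backbone.stem.0.conv_dup1.weight_fake_quant.fake_quant_enabled"
print(strip_dup_suffix(key))
# backbone.stem.0.conv.weight_fake_quant.fake_quant_enabled
```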

For the leaf activation_post_process nodes, I can skip most of the copying, since much of that state is reset in MMArchitectureQuant.__init__() anyway.
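Skipping those buffers needs a way to recognize them by name. A minimal sketch, assuming the relevant buffer names are the ones in the skip_buffer_sync list of the patch below; should_skip_missing_buffer is a hypothetical helper. It compares only the last component of the dotted key, because a substring test on the whole key would also match module names (a hypothetical upscale layer contains "scale"):

```python
# Buffer names assumed to be either reset after sync_qparams() runs or
# left at their defaults, taken from skip_buffer_sync in the suggested patch.
SKIP_BUFFER_SYNC = frozenset({
    "fake_quant_enabled",
    "observer_enabled",
    "scale",
    "zero_point",
    "min_val",
    "max_val",
    "eps",
})


def should_skip_missing_buffer(buffer_name: str) -> bool:
    """Return True if a buffer absent from the source mode can be skipped.

    Only the final dotted component (the buffer's own name) is compared,
    so module paths that merely contain one of the words do not match.
    """
    leaf = buffer_name.rsplit(".", 1)[-1]
    return leaf in SKIP_BUFFER_SYNC


# "upscale" is a hypothetical module name used to show the difference.
name = "neck.upscale.weight_fake_quant.activation_post_process.histogram"
print(any(s in name for s in SKIP_BUFFER_SYNC))  # True: substring hit in "upscale"
print(should_skip_missing_buffer(name))  # False: leaf is "histogram"
print(should_skip_missing_buffer("activation_post_process_123.fake_quant_enabled"))  # True
```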

mmrazor/models/algorithms/quantization/mm_architecture.py

@@ -121,7 +121,7 @@ class MMArchitectureQuant(BaseAlgorithm):
                 in some subtle ways, so we need to sync them here.
         """
 
-        def traverse(module, prefix):
+        def traverse(module, prefix, mode, src_mode):
             for name, child in module._modules.items():
                 if module is None:
                     continue
@@ -129,7 +129,14 @@ class MMArchitectureQuant(BaseAlgorithm):
                 if isinstance(child, FakeQuantizeBase):
                     for name, param in child.named_parameters():
                         param_name = f'{child_name}.{name}'
-                        src_param = src_state_dict[param_name]
+                        src_param = src_state_dict.get(param_name)
+                        if '_dup' in param_name and src_param is None:
+                            param_name = '.'.join([section.split('_dup')[0] for section in param_name.split('.')])
+                            src_param = src_state_dict.get(param_name)
+                        if src_param is None:
+                            print(src_state_dict)
+                            print(child)
+                            raise KeyError(f"{param_name} in mode: '{mode}' but not found in source mode: '{src_mode}', sync_qparams() failed.")
                         if src_param.shape == param.shape:
                             param.data.copy_(src_param)
                         else:
@@ -138,22 +145,42 @@ class MMArchitectureQuant(BaseAlgorithm):
                             param.resize_(src_param.shape)
                             param.requires_grad = requirs_grad
                             param.data.copy_(src_param)
+                    # These are either reset after sync_qparams() is called, or are left as default (eps)
+                    # so there's no need to sync them if there's not a match
+                    skip_buffer_sync = [
+                        "fake_quant_enabled",
+                        "observer_enabled",
+                        "scale",
+                        "zero_point",
+                        "min_val",
+                        "max_val",
+                        "eps",
+                    ]
                     for name, buffer in child.named_buffers():
                         buffer_name = f'{child_name}.{name}'
-                        src_buffer = src_state_dict[buffer_name]
+                        src_buffer = src_state_dict.get(buffer_name)
+                        if '_dup' in buffer_name and src_buffer is None:
+                            buffer_name = '.'.join([section.split('_dup')[0] for section in buffer_name.split('.')])
+                            src_buffer = src_state_dict.get(buffer_name)
+                        if any([s in buffer_name for s in skip_buffer_sync]) and src_buffer is None:
+                            continue
+                            src_buffer = torch.tensor([1], dtype=torch.uint8)
+                        if src_buffer is None:
+                            print(src_state_dict)
+                            print(child)
+                            raise KeyError(f"{buffer_name} in mode: '{mode}' but not found in source mode: '{src_mode}', sync_qparams() failed.")
                         if src_buffer.shape == buffer.shape:
                             buffer.data.copy_(src_buffer)
                         else:
                             buffer.resize_(src_buffer.shape)
                             buffer.data.copy_(src_buffer)
                 else:
-                    traverse(child, f'{child_name}.')
+                    traverse(child, f'{child_name}.', mode, src_mode)
         src_state_dict = self.qmodels[src_mode].state_dict()
         for mode in self.forward_modes:
             if mode == src_mode:
                 continue
-            traverse(self.qmodels[mode], '')
+            traverse(self.qmodels[mode], '', mode, src_mode)

     def _get_rewriter_context_in_mmdeploy(self, deploy_cfg):
         """Get rewriter context in mmdeploy according to the deploy related
@elisa-aleman elisa-aleman added the bug Something isn't working label Apr 5, 2024
@elisa-aleman elisa-aleman changed the title [Bug] mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_params() fails if there are modules present in other modes but not in forward mode='tensor' [Bug] mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams() fails if there are modules present in other modes but not in forward mode='tensor' Apr 10, 2024
@elisa-aleman
Author

Added more context and suggested a fix

@elisa-aleman elisa-aleman changed the title [Bug] mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams() fails if there are modules present in other modes but not in forward mode='tensor' [Bug] (suggested fix) mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams() fails if there are modules present in other modes but not in forward mode='tensor' Apr 10, 2024
@elisa-aleman
Author

elisa-aleman commented Apr 18, 2024

After trying to deploy the quantized model, I realized the suggested fix might be unnecessary and could cause further issues: mmdeploy/tools/deploy.py forces model.architecture.test_cfg.flip_test=False for pose estimators, so the quantized state_dict would contain extra weights, which makes the model deployment fail.
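One way to confirm which extra entries cause the mismatch is to diff the keys of the quantized checkpoint against those of the model that deployment builds (with flip_test=False). A minimal, framework-agnostic sketch; diff_state_dict_keys is a hypothetical helper that accepts any mapping, including the OrderedDict returned by torch.nn.Module.state_dict():

```python
from typing import List, Mapping, Tuple


def diff_state_dict_keys(
    checkpoint: Mapping[str, object], model: Mapping[str, object]
) -> Tuple[List[str], List[str]]:
    """Return (unexpected, missing) keys for loading `checkpoint` into `model`.

    unexpected: keys in the checkpoint that the model does not have, such as
        fake-quant entries that only exist in the flip-test branch.
    missing: keys the model expects that the checkpoint does not provide.
    """
    ckpt_keys, model_keys = set(checkpoint), set(model)
    unexpected = sorted(ckpt_keys - model_keys)
    missing = sorted(model_keys - ckpt_keys)
    return unexpected, missing


# Toy values stand in for tensors; only the keys matter here.
ckpt = {"backbone.conv.weight": 0, "activation_post_process_123.scale": 1}
model = {"backbone.conv.weight": 0}
print(diff_state_dict_keys(ckpt, model))
# (['activation_post_process_123.scale'], [])
```

PyTorch's own nn.Module.load_state_dict(state_dict, strict=False) reports the same two lists as its missing_keys and unexpected_keys fields; with strict=True it raises an error instead.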

I then tried:

python /tools/train.py \
    ${qat_topdown_cfg} \
    --cfg-options \
         model.architecture.test_cfg.flip_test=False \
    --work-dir /path/here/

But the model still fails to sync without my patch.

@elisa-aleman
Author

I realized that sync_qparams() is also called during the training loop with 'loss' as the source mode, so my previous fix actually discards any progress made during training. I suggest this new fix instead, which skips (rather than resets) fake-quant values that are not found in the source mode. I haven't finished deploying this model yet, so it is subject to change.

@@ -121,7 +121,7 @@ class MMArchitectureQuant(BaseAlgorithm):
                 in some subtle ways, so we need to sync them here.
         """
 
-        def traverse(module, prefix):
+        def traverse(module, prefix, mode, src_mode):
             for name, child in module._modules.items():
                 if module is None:
                     continue
@@ -129,7 +129,13 @@ class MMArchitectureQuant(BaseAlgorithm):
                 if isinstance(child, FakeQuantizeBase):
                     for name, param in child.named_parameters():
                         param_name = f'{child_name}.{name}'
-                        src_param = src_state_dict[param_name]
+                        src_param = src_state_dict.get(param_name)
+                        if '_dup' in param_name and src_param is None:
+                            param_name = '.'.join([section.split('_dup')[0] for section in param_name.split('.')])
+                            src_param = src_state_dict.get(param_name)
+                        if src_param is None:
+                            print(f"{param_name} in mode: '{mode}' but not found in source mode: '{src_mode}', skipping sync.")
+                            continue
                         if src_param.shape == param.shape:
                             param.data.copy_(src_param)
                         else:
@@ -140,20 +146,26 @@ class MMArchitectureQuant(BaseAlgorithm):
                             param.data.copy_(src_param)
                     for name, buffer in child.named_buffers():
                         buffer_name = f'{child_name}.{name}'
-                        src_buffer = src_state_dict[buffer_name]
+                        src_buffer = src_state_dict.get(buffer_name)
+                        if '_dup' in buffer_name and src_buffer is None:
+                            buffer_name = '.'.join([section.split('_dup')[0] for section in buffer_name.split('.')])
+                            src_buffer = src_state_dict.get(buffer_name)
+                        if src_buffer is None:
+                            print(f"{buffer_name} in mode: '{mode}' but not found in source mode: '{src_mode}', skipping sync.")
+                            continue
                         if src_buffer.shape == buffer.shape:
                             buffer.data.copy_(src_buffer)
                         else:
                             buffer.resize_(src_buffer.shape)
                             buffer.data.copy_(src_buffer)
                 else:
-                    traverse(child, f'{child_name}.')
+                    traverse(child, f'{child_name}.', mode, src_mode)
         src_state_dict = self.qmodels[src_mode].state_dict()
         for mode in self.forward_modes:
             if mode == src_mode:
                 continue
-            traverse(self.qmodels[mode], '')
+            traverse(self.qmodels[mode], '', mode, src_mode)

     def _get_rewriter_context_in_mmdeploy(self, deploy_cfg):
         """Get rewriter context in mmdeploy according to the deploy related
