diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index 067c35bd3c64..c47982f0c0cc 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -726,6 +726,20 @@ def forward(self, x):
         ref_res = ref_m(data)
         self.assertEqual(res, ref_res)
 
+    @skipIfNoFBGEMM
+    def test_convtranspose_per_channel_fails_early(self):
+        r"""
+        Verifies that attempting to quantize a ConvTranspose module with per-channel
+        weight observers fails in the prepare step, as opposed to the convert step.
+        """
+        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
+        m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+        with self.assertRaises(AssertionError) as context:
+            mp = torch.quantization.prepare(m)
+        self.assertTrue(
+            str(context.exception) ==
+            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
+
 
 @skipIfNoFBGEMM
 class TestPostTrainingDynamic(QuantizationTestCase):
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 545e70a2c5e6..d014bd31f02e 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1278,6 +1278,21 @@ def test_fp32_input_fp32_output(self):
         self._test_quantized_inputs_outputs(
             prepare_custom_config_dict, prepare_count_check, convert_count_check)
 
+    @skipIfNoFBGEMM
+    def test_convtranspose_per_channel_fails_early(self):
+        r"""
+        Verifies that attempting to quantize a ConvTranspose module with per-channel
+        weight observers fails in the prepare step, as opposed to the convert step.
+        """
+        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
+        m.eval()
+        qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
+        with self.assertRaises(AssertionError) as context:
+            mp = prepare_fx(m, qconfig_dict)
+        self.assertTrue(
+            str(context.exception) ==
+            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py
index 8da4ad6bb182..2d91d8ab6b3e 100644
--- a/torch/quantization/qconfig.py
+++ b/torch/quantization/qconfig.py
@@ -3,6 +3,8 @@
 from .fake_quantize import *
 import torch.nn as nn
 
+from typing import Union
+
 class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
     """
     Describes how to quantize a layer or a part of the network by providing
@@ -109,3 +111,18 @@ def get_default_qat_qconfig(backend='fbgemm'):
     else:
         qconfig = default_qat_qconfig
     return qconfig
+
+def assert_valid_qconfig(qconfig: Union[QConfig, QConfigDynamic],
+                         mod: torch.nn.Module) -> None:
+    is_conv_transpose_mod = (
+        isinstance(mod, torch.nn.ConvTranspose1d) or
+        isinstance(mod, torch.nn.ConvTranspose2d) or
+        isinstance(mod, torch.nn.ConvTranspose3d))
+    if is_conv_transpose_mod:
+        example_observer = qconfig.weight()
+        is_per_channel = (
+            isinstance(example_observer, torch.quantization.PerChannelMinMaxObserver) or
+            isinstance(example_observer, torch.quantization.MovingAveragePerChannelMinMaxObserver)
+        )
+        assert not is_per_channel, \
+            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'
diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py
index 1be867e0a299..a57a4ea6bcb8 100644
--- a/torch/quantization/quantize.py
+++ b/torch/quantization/quantize.py
@@ -50,6 +50,8 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
     module_qconfig = qconfig_dict.get(prefix, module_qconfig)
     module_qconfig = getattr(module, 'qconfig', module_qconfig)
 
+    torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)
+
     module.qconfig = module_qconfig
     for name, child in module.named_children():
         module_prefix = prefix + '.' + name if prefix else name
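
For reviewers, a minimal sketch (not part of the patch) of the user-visible behavior after this change, assuming a build where eager-mode quantization still lives under the torch.quantization namespace:

import torch

m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)).eval()

# 'fbgemm' defaults to a per-channel weight observer, so prepare() now
# fails early with the message asserted in the tests above.
m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
try:
    torch.quantization.prepare(m)
except AssertionError as e:
    print(e)  # Per channel weight observer is not supported yet for ConvTranspose{n}d.

# A per-tensor weight observer remains supported for ConvTranspose{n}d.
m.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.default_observer,
    weight=torch.quantization.default_weight_observer)
mp = torch.quantization.prepare(m)

Since the check is called from _propagate_qconfig_helper, it fires once per module during prepare / prepare_fx, before any observers are inserted; previously this configuration only failed later, in the convert step.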