quant: nice error message on convtranspose with per-channel weight (pytorch#49899)

Summary:
Pull Request resolved: pytorch#49899

Per-channel weight observers are not supported yet for conv transpose. Add an
error message that fails instantly, instead of making the user wait until after
calibration/training finishes.
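
As a minimal illustrative sketch of the new fail-fast behavior (not taken verbatim from the commit; assumes an FBGEMM-enabled build, since the fbgemm default qconfig is the one that uses a per-channel weight observer):

```python
import torch

m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
m.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# With this change the unsupported observer is rejected here, at prepare()
# time, instead of surfacing only after calibration, at convert() time.
try:
    torch.quantization.prepare(m)
except AssertionError as e:
    print(e)  # Per channel weight observer is not supported yet for ConvTranspose{n}d.
```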

Test Plan:
```
python test/test_quantization.py TestPostTrainingStatic.test_convtranspose_per_channel_fails_early
python test/test_quantization.py TestQuantizeFx.test_convtranspose_per_channel_fails_early
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25717151

fbshipit-source-id: 093e5979030ec185e3e0d56c45d7ce7338bf94b6
vkuzo authored and hwangdeyu committed Jan 14, 2021
1 parent d1568a1 commit 499f740
Showing 4 changed files with 48 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/quantization/test_quantize.py
@@ -726,6 +726,20 @@ def forward(self, x):
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        with self.assertRaises(AssertionError) as context:
            mp = torch.quantization.prepare(m)
        self.assertTrue(
            str(context.exception) ==
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')


@skipIfNoFBGEMM
class TestPostTrainingDynamic(QuantizationTestCase):
15 changes: 15 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -1278,6 +1278,21 @@ def test_fp32_input_fp32_output(self):
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.eval()
        qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
        with self.assertRaises(AssertionError) as context:
            mp = prepare_fx(m, qconfig_dict)
        self.assertTrue(
            str(context.exception) ==
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
    """Unit tests for individual ops
17 changes: 17 additions & 0 deletions torch/quantization/qconfig.py
@@ -3,6 +3,8 @@
from .fake_quantize import *
import torch.nn as nn

from typing import Union

class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
"""
Describes how to quantize a layer or a part of the network by providing
@@ -109,3 +111,18 @@ def get_default_qat_qconfig(backend='fbgemm'):
    else:
        qconfig = default_qat_qconfig
    return qconfig

def assert_valid_qconfig(qconfig: Union[QConfig, QConfigDynamic],
                         mod: torch.nn.Module) -> None:
    is_conv_transpose_mod = (
        isinstance(mod, torch.nn.ConvTranspose1d) or
        isinstance(mod, torch.nn.ConvTranspose2d) or
        isinstance(mod, torch.nn.ConvTranspose3d))
    if is_conv_transpose_mod:
        example_observer = qconfig.weight()
        is_per_channel = (
            isinstance(example_observer, torch.quantization.PerChannelMinMaxObserver) or
            isinstance(example_observer, torch.quantization.MovingAveragePerChannelMinMaxObserver)
        )
        assert not is_per_channel, \
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'
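
The check calls `qconfig.weight()` to instantiate one example observer and inspects its type, so per-tensor weight observers still pass. A usage sketch (not part of the diff; assumes the `torch.quantization` namespace of this release, where qnnpack's default qconfig uses a per-tensor weight observer and fbgemm's uses a per-channel one):

```python
import torch
from torch.quantization.qconfig import assert_valid_qconfig

mod = torch.nn.ConvTranspose2d(1, 1, 1)

# Per-tensor weight observer: no error.
assert_valid_qconfig(torch.quantization.get_default_qconfig('qnnpack'), mod)

# Per-channel weight observer: asserts.
try:
    assert_valid_qconfig(torch.quantization.get_default_qconfig('fbgemm'), mod)
except AssertionError as e:
    print(e)
```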
2 changes: 2 additions & 0 deletions torch/quantization/quantize.py
@@ -50,6 +50,8 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)

    module.qconfig = module_qconfig
    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
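Since the eager `prepare` and FX `prepare_fx` tests above both pass with this single non-test call site, both flows evidently route qconfig assignment through `_propagate_qconfig_helper`; the helper visits every child module, so the assert fires for each ConvTranspose submodule during the prepare step.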
