From aee427cd08c0ea3a019af9b41856cb15b908a214 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 3 Nov 2022 18:07:01 +0000 Subject: [PATCH 1/5] Ensuring has_custimzed_qrange==False for torch <=1.9.1 and >=1.12.0 --- .../sparsification/quantization/helpers.py | 77 +++++++++---------- .../quantization/modifier_quantization.py | 4 - .../quantization/test_helpers.py | 16 ++++ 3 files changed, 51 insertions(+), 46 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ae35353e6f5..1d523d673e1 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -18,6 +18,7 @@ from copy import deepcopy from dataclasses import dataclass, field +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -31,7 +32,7 @@ _PARSED_TORCH_VERSION = version.parse(torch.__version__) - +_TORCH_PRE_112 = _PARSED_TORCH_VERSION < version.parse("1.12.0") __all__ = [ "QATWrapper", @@ -41,7 +42,6 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", - "fix_observer_quant_range", "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", @@ -572,14 +572,45 @@ def get_observer( ): qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range, ) + + # in torch 1.9.1, quant_min and quant_max are not passed to observer: + # https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109 + # however in 1.12.0, this is fixed so both are passed to observer: + # https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132 + # + # Passing quant_min/quant_max to observer means the observer will have + # `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0. + # + # For whatever reason, both versions calculate zero point for + # quint8 differently **if there is a customized_qrange** + # 1. customized qrange has zero point of 127 + # 2. non-customized has zero point of 128. + # source: + # https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293 + # + # **we want to ensure that the zero point is 128** + # see https://github.com/neuralmagic/sparseml/pull/604 + # + # NOTE: This assumes we *never* want a customized qrange, and that the + # compute_range above doesn't return customized qranges + if _TORCH_PRE_112: + # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, + # so we are safe to pass these to FakeQuantize + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + else: + # post 1.12 we cannot pass them to the observer since that will set + # has_customized_qrange. instead we rely on the default values + # being equal to the `quant_min` and `quant_max` here. + pass + observer_kwargs.update(qconfig_kwargs or {}) observer = torch_quantization.FakeQuantize.with_args( **observer_kwargs, @@ -588,44 +619,6 @@ def get_observer( return observer -def fix_observer_quant_range(module: Module): - """ - As of torch 1.10.2 there is a bug in FakeQuantize initialization where - quant_min and quant_max of FakeQuantize are not propagated to its - activation_post_process observer. This function propagates FakeQuantize quant - ranges to their Observer objects - - :param module: Module object to propagate FakeQuantize quant ranges of. Propagates - in place - """ - for submodule in module.modules(): - if isinstance(submodule, torch_quantization.FakeQuantize): - fake_quantize = submodule - elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize - ): - fake_quantize = submodule.activation_post_process - else: - continue - - # continue if fake_quantize quant range not set, or observer quant range is set - observer = fake_quantize.activation_post_process - if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) - or ( # do not propagate default uint8 symmetric range - observer.qscheme == torch.per_tensor_symmetric - and fake_quantize.quant_min == 0 - and fake_quantize.quant_max == 255 - ) - ): - continue - observer.quant_min = fake_quantize.quant_min - observer.quant_max = fake_quantize.quant_max - observer.has_customized_qrange = True - - def freeze_bn_stats(module: Module): if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 8be35f73229..e1bbedba851 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -57,7 +57,6 @@ configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, - fix_observer_quant_range, freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, @@ -713,9 +712,6 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat(module, qproperties) - # propagate custom quant min/max range from FakeQuantize to Observer objects - fix_observer_quant_range(module) - self._qat_enabled = True self._calibrate_if_possible(module) diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py index 14c8556fb2f..0fed29617a3 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py @@ -28,6 +28,9 @@ get_qat_qconfig, prepare_embeddings_qat, ) +from sparseml.pytorch.sparsification.quantization.modifier_quantization import ( + QuantizationModifier, +) try: @@ -283,3 +286,16 @@ def test_prepare_embeddings_qat(): module(torch.arange(10)) observed_range_min = observer.activation_post_process.min_val.item() assert orig_range_min != observed_range_min + + +def test_zero_point_is_128(): + # see https://github.com/neuralmagic/sparseml/pull/604 + + # give QATMatMul a layer to be wrapped + dummy_sequential = torch.nn.Sequential(_QATMatMul()) + QuantizationModifier().apply(dummy_sequential) + qat_matmul = dummy_sequential[0] + _ = qat_matmul(torch.randn(10, 10), torch.randn(10, 10)) + + fq = qat_matmul.input_quant_stubs[1].activation_post_process + assert fq.zero_point[0] == 128 From 30de66f416de0e386a0c2828c2f73a7dea5fb7d0 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 3 Nov 2022 18:19:35 +0000 Subject: [PATCH 2/5] Handling custom qrange in get_observer --- .../sparsification/quantization/helpers.py | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 1d523d673e1..11d088a5459 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -450,9 +450,10 @@ def compute_range(dtype: torch.dtype, bits: int): :param dtype: data type. :param bits: number of bits. - :return: minimum limit, maximum limit + :return: minimum limit, maximum limit, whether the range is customized """ bits = bits if bits else 8 + is_custom = bits != 8 if dtype == torch.qint8: quant_min = -(2 ** (bits - 1)) quant_max = (2 ** (bits - 1)) - 1 @@ -460,7 +461,7 @@ def compute_range(dtype: torch.dtype, bits: int): quant_min = 0 quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, is_custom def configure_module_default_qconfigs(module: Module): @@ -571,10 +572,10 @@ def get_observer( qconfig_kwargs: Dict[str, Any], ): qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits) + observer_cls = torch_quantization.MovingAverageMinMaxObserver observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range, @@ -597,20 +598,31 @@ def get_observer( # # **we want to ensure that the zero point is 128** # see https://github.com/neuralmagic/sparseml/pull/604 - # - # NOTE: This assumes we *never* want a customized qrange, and that the - # compute_range above doesn't return customized qranges - if _TORCH_PRE_112: - # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, - # so we are safe to pass these to FakeQuantize + if is_custom_qrange: + # for both versions we need to include the custom min/max values in kwargs observer_kwargs["quant_min"] = quant_min observer_kwargs["quant_max"] = quant_max + if _TORCH_PRE_112: + # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, + # so we patch them in to the constructor of the observer + observer_cls = partial( + observer_cls, quant_min=quant_min, quant_max=quant_max + ) else: - # post 1.12 we cannot pass them to the observer since that will set - # has_customized_qrange. instead we rely on the default values - # being equal to the `quant_min` and `quant_max` here. - pass + # if using a non custom qrange, we can rely on default values used by + # the observers + if _TORCH_PRE_112: + # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, + # so we are safe to pass these to FakeQuantize + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + else: + # post 1.12 we cannot pass them to the observer since that will set + # has_customized_qrange. instead we rely on the default values + # being equal to the `quant_min` and `quant_max` here. + pass + observer_kwargs["observer"] = observer_cls observer_kwargs.update(qconfig_kwargs or {}) observer = torch_quantization.FakeQuantize.with_args( **observer_kwargs, From d10132ae86e955e1107946c47dc3e4874e234b66 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 3 Nov 2022 18:41:33 +0000 Subject: [PATCH 3/5] Adding more tests for zero points & fixing failing modifier tests --- .../quantization/test_helpers.py | 46 +++++++++++++++++++ .../test_modifier_quantization.py | 5 +- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py index 0fed29617a3..64cd194cde9 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py @@ -28,6 +28,7 @@ get_qat_qconfig, prepare_embeddings_qat, ) +from sparseml.pytorch.sparsification.quantization.helpers import get_observer from sparseml.pytorch.sparsification.quantization.modifier_quantization import ( QuantizationModifier, ) @@ -299,3 +300,48 @@ def test_zero_point_is_128(): fq = qat_matmul.input_quant_stubs[1].activation_post_process assert fq.zero_point[0] == 128 + + +def test_standard_qrange_zero_points(): + bits = 8 + + fake_quantize = get_observer(True, torch.qint8, bits, False, {})() + fake_quantize(torch.randn(10, 10)) + assert fake_quantize.quant_min == -128 + assert fake_quantize.quant_max == 127 + assert fake_quantize.activation_post_process.quant_min == -128 + assert fake_quantize.activation_post_process.quant_max == 127 + _, zero_point = fake_quantize.calculate_qparams() + assert zero_point[0] == 0 + + fake_quantize = get_observer(True, torch.quint8, bits, False, {})() + fake_quantize(torch.randn(10, 10)) + assert fake_quantize.quant_min == 0 + assert fake_quantize.quant_max == 255 + assert fake_quantize.activation_post_process.quant_min == 0 + assert fake_quantize.activation_post_process.quant_max == 255 + _, zero_point = fake_quantize.calculate_qparams() + assert zero_point[0] == 128 + + +def test_custom_qrange_zero_points(): + # non 8 bits is what makes it a custom qrange + bits = 4 + + fake_quantize = get_observer(True, torch.qint8, bits, False, {})() + fake_quantize(torch.randn(10, 10)) + assert fake_quantize.quant_min == -8 + assert fake_quantize.quant_max == 7 + assert fake_quantize.activation_post_process.quant_min == -8 + assert fake_quantize.activation_post_process.quant_max == 7 + _, zero_point = fake_quantize.calculate_qparams() + assert zero_point[0] == 0 + + fake_quantize = get_observer(True, torch.quint8, bits, False, {})() + fake_quantize(torch.randn(10, 10)) + assert fake_quantize.quant_min == 0 + assert fake_quantize.quant_max == 15 + assert fake_quantize.activation_post_process.quant_min == 0 + assert fake_quantize.activation_post_process.quant_max == 15 + _, zero_point = fake_quantize.calculate_qparams() + assert zero_point[0] == 7 diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py index dbaf88f287b..7ad69bcb17a 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py @@ -98,10 +98,9 @@ def _test_quantizable_module(module, qat_expected, modifier): if modifier.activation_bits is not None: expected_quant_min = 0 expected_quant_max = (1 << modifier.activation_bits) - 1 - activation_quant_properties = module.qconfig.activation.p.keywords - assert activation_quant_properties["quant_min"] == expected_quant_min - assert activation_quant_properties["quant_max"] == expected_quant_max + module.activation_post_process.quant_min == expected_quant_min + module.activation_post_process.quant_max == expected_quant_max if isinstance(module, Linear): assert isinstance(module.activation_post_process, Identity) == ( From c3831d5825b7a69f972ef1c05853bdb1fbae13f5 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 3 Nov 2022 18:54:44 +0000 Subject: [PATCH 4/5] Fixing tests --- .../pytorch/sparsification/quantization/test_helpers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py index 64cd194cde9..a56041e5445 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py @@ -309,8 +309,6 @@ def test_standard_qrange_zero_points(): fake_quantize(torch.randn(10, 10)) assert fake_quantize.quant_min == -128 assert fake_quantize.quant_max == 127 - assert fake_quantize.activation_post_process.quant_min == -128 - assert fake_quantize.activation_post_process.quant_max == 127 _, zero_point = fake_quantize.calculate_qparams() assert zero_point[0] == 0 @@ -318,8 +316,6 @@ def test_standard_qrange_zero_points(): fake_quantize(torch.randn(10, 10)) assert fake_quantize.quant_min == 0 assert fake_quantize.quant_max == 255 - assert fake_quantize.activation_post_process.quant_min == 0 - assert fake_quantize.activation_post_process.quant_max == 255 _, zero_point = fake_quantize.calculate_qparams() assert zero_point[0] == 128 From b2e0f171fc61316c7bfacbe62b6b1d41e09b187b Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 8 Nov 2022 16:14:05 +0000 Subject: [PATCH 5/5] Triple quote comment --- .../sparsification/quantization/helpers.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 11d088a5459..b0c61527ee5 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -581,23 +581,25 @@ def get_observer( reduce_range=reduce_range, ) - # in torch 1.9.1, quant_min and quant_max are not passed to observer: - # https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109 - # however in 1.12.0, this is fixed so both are passed to observer: - # https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132 - # - # Passing quant_min/quant_max to observer means the observer will have - # `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0. - # - # For whatever reason, both versions calculate zero point for - # quint8 differently **if there is a customized_qrange** - # 1. customized qrange has zero point of 127 - # 2. non-customized has zero point of 128. - # source: - # https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293 - # - # **we want to ensure that the zero point is 128** - # see https://github.com/neuralmagic/sparseml/pull/604 + """ + in torch 1.9.1, quant_min and quant_max are not passed to observer: + https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109 + however in 1.12.0, this is fixed so both are passed to observer: + https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132 + + Passing quant_min/quant_max to observer means the observer will have + `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0. + + For whatever reason, both versions calculate zero point for + quint8 differently **if there is a customized_qrange** + 1. customized qrange has zero point of 127 + 2. non-customized has zero point of 128. + source: + https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293 + + **we want to ensure that the zero point is 128** + see https://github.com/neuralmagic/sparseml/pull/604 + """ if is_custom_qrange: # for both versions we need to include the custom min/max values in kwargs observer_kwargs["quant_min"] = quant_min