99 changes: 53 additions & 46 deletions src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -18,6 +18,7 @@

from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -31,7 +32,7 @@


_PARSED_TORCH_VERSION = version.parse(torch.__version__)

_TORCH_PRE_112 = _PARSED_TORCH_VERSION < version.parse("1.12.0")

__all__ = [
"QATWrapper",
@@ -41,7 +42,6 @@
"add_quant_dequant",
"remove_activation_qat_by_layer_name",
"get_qat_qconfig",
"fix_observer_quant_range",
"freeze_bn_stats",
"fuse_module_conv_bn_relus",
"prepare_embeddings_qat",
@@ -450,17 +450,18 @@ def compute_range(dtype: torch.dtype, bits: int):

:param dtype: data type.
:param bits: number of bits.
:return: minimum limit, maximum limit
:return: minimum limit, maximum limit, whether the range is customized
"""
bits = bits if bits else 8
is_custom = bits != 8
if dtype == torch.qint8:
quant_min = -(2 ** (bits - 1))
quant_max = (2 ** (bits - 1)) - 1
elif dtype == torch.quint8:
quant_min = 0
quant_max = (2 ** bits) - 1

return quant_min, quant_max
return quant_min, quant_max, is_custom
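For concreteness, a quick sketch of the new three-value return, using illustrative calls against the qint8/quint8 branches above:

```python
import torch

# 8-bit ranges match the dtype defaults, so is_custom is False
assert compute_range(torch.qint8, 8) == (-128, 127, False)
assert compute_range(torch.quint8, 8) == (0, 255, False)

# any bit width other than 8 is flagged as a customized range
assert compute_range(torch.qint8, 4) == (-8, 7, True)
assert compute_range(torch.quint8, 4) == (0, 15, True)
```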


def configure_module_default_qconfigs(module: Module):
@@ -571,15 +572,59 @@ def get_observer(
qconfig_kwargs: Dict[str, Any],
):
qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine
quant_min, quant_max = compute_range(dtype, bits)
quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits)

observer_cls = torch_quantization.MovingAverageMinMaxObserver
observer_kwargs = dict(
observer=torch_quantization.MovingAverageMinMaxObserver,
quant_min=quant_min,
quant_max=quant_max,
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
)

"""
in torch 1.9.1, quant_min and quant_max are not passed to observer:
https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109
however in 1.12.0, this is fixed so both are passed to observer:
https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132

Passing quant_min/quant_max to observer means the observer will have
`self.has_customized_qrange == True` in both 1.9.1 and 1.12.0.

For whatever reason, both versions calculate zero point for
quint8 differently **if there is a customized_qrange**
1. customized qrange has zero point of 127
2. non-customized has zero point of 128.
source:
https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293

**we want to ensure that the zero point is 128**
see https://github.com/neuralmagic/sparseml/pull/604
"""
if is_custom_qrange:
# for both versions we need to include the custom min/max values in kwargs
observer_kwargs["quant_min"] = quant_min
observer_kwargs["quant_max"] = quant_max
if _TORCH_PRE_112:
# pre 1.12, the observer doesn't get passed the quant_min/quant_max values,
# so we patch them into the constructor of the observer
observer_cls = partial(
observer_cls, quant_min=quant_min, quant_max=quant_max
)
else:
# if using a non-custom qrange, we can rely on the default values used by
# the observers
if _TORCH_PRE_112:
# pre 1.12, the observer doesn't get passed the quant_min/quant_max values,
# so we are safe to pass these to FakeQuantize
observer_kwargs["quant_min"] = quant_min
observer_kwargs["quant_max"] = quant_max
else:
# post 1.12 we cannot pass them to the observer since that will set
# has_customized_qrange. instead we rely on the default values
# being equal to the `quant_min` and `quant_max` here.
pass

observer_kwargs["observer"] = observer_cls
observer_kwargs.update(qconfig_kwargs or {})
observer = torch_quantization.FakeQuantize.with_args(
**observer_kwargs,
@@ -588,44 +633,6 @@ def get_observer(
return observer
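The quint8 zero-point discrepancy described in the comment block above can be reproduced directly with torch's observers. A minimal sketch, assuming torch >= 1.12 where quant_min/quant_max may be passed straight to the observer:

```python
import torch
from torch import quantization as torch_quantization

# default qrange: symmetric quint8 pins the zero point at 128
obs = torch_quantization.MovingAverageMinMaxObserver(
    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric
)
obs(torch.randn(16))
_, zp = obs.calculate_qparams()  # zero point == 128

# customized qrange: the same range passed explicitly flips
# has_customized_qrange, and the zero point becomes the midpoint, 127
obs_custom = torch_quantization.MovingAverageMinMaxObserver(
    dtype=torch.quint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=0,
    quant_max=255,
)
obs_custom(torch.randn(16))
_, zp_custom = obs_custom.calculate_qparams()  # zero point == 127
```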


def fix_observer_quant_range(module: Module):
"""
As of torch 1.10.2 there is a bug in FakeQuantize initialization where
quant_min and quant_max of FakeQuantize are not propagated to its
activation_post_process observer. This function propagates FakeQuantize quant
ranges to their Observer objects

:param module: Module object to propagate FakeQuantize quant ranges of. Propagates
in place
"""
for submodule in module.modules():
if isinstance(submodule, torch_quantization.FakeQuantize):
fake_quantize = submodule
elif hasattr(submodule, "activation_post_process") and isinstance(
submodule.activation_post_process, torch_quantization.FakeQuantize
):
fake_quantize = submodule.activation_post_process
else:
continue

# continue if fake_quantize quant range not set, or observer quant range is set
observer = fake_quantize.activation_post_process
if (
fake_quantize.quant_min is None
or fake_quantize.quant_max is None
or (observer.quant_min is not None or observer.quant_max is not None)
or ( # do not propagate default uint8 symmetric range
observer.qscheme == torch.per_tensor_symmetric
and fake_quantize.quant_min == 0
and fake_quantize.quant_max == 255
)
):
continue
observer.quant_min = fake_quantize.quant_min
observer.quant_max = fake_quantize.quant_max
observer.has_customized_qrange = True
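This manual propagation is superseded by the functools.partial patching in get_observer above: on torch < 1.12, the custom range is pre-bound into the observer constructor itself. A minimal sketch of that workaround, using a hypothetical 4-bit quint8 range:

```python
from functools import partial

import torch
from torch import quantization as torch_quantization

# pre-1.12 FakeQuantize does not forward quant_min/quant_max to its
# observer, so pre-bind them on the observer class instead
ObserverWithRange = partial(
    torch_quantization.MovingAverageMinMaxObserver, quant_min=0, quant_max=15
)
fq = torch_quantization.FakeQuantize.with_args(
    observer=ObserverWithRange, quant_min=0, quant_max=15, dtype=torch.quint8
)()
# the inner observer now sees the custom range regardless of torch version
assert fq.activation_post_process.quant_min == 0
assert fq.activation_post_process.quant_max == 15
```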


def freeze_bn_stats(module: Module):
if hasattr(module, "freeze_bn_stats"):
module.freeze_bn_stats()
src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
@@ -57,7 +57,6 @@
configure_module_bn_wrappers,
configure_module_default_qconfigs,
configure_module_qat_wrappers,
fix_observer_quant_range,
freeze_bn_stats,
fuse_module_conv_bn_relus,
get_qat_qconfig,
@@ -713,9 +712,6 @@ def _enable_module_qat(self, module: Module):
if self._quantize_embeddings:
prepare_embeddings_qat(module, qproperties)

# propagate custom quant min/max range from FakeQuantize to Observer objects
fix_observer_quant_range(module)

self._qat_enabled = True
self._calibrate_if_possible(module)

58 changes: 58 additions & 0 deletions tests/sparseml/pytorch/sparsification/quantization/test_helpers.py
@@ -28,6 +28,10 @@
get_qat_qconfig,
prepare_embeddings_qat,
)
from sparseml.pytorch.sparsification.quantization.helpers import get_observer
from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
QuantizationModifier,
)


try:
@@ -283,3 +287,57 @@ def test_prepare_embeddings_qat():
module(torch.arange(10))
observed_range_min = observer.activation_post_process.min_val.item()
assert orig_range_min != observed_range_min


def test_zero_point_is_128():
# see https://github.com/neuralmagic/sparseml/pull/604

# wrap a QATMatMul in a container module so the modifier can wrap it
dummy_sequential = torch.nn.Sequential(_QATMatMul())
QuantizationModifier().apply(dummy_sequential)
qat_matmul = dummy_sequential[0]
_ = qat_matmul(torch.randn(10, 10), torch.randn(10, 10))

fq = qat_matmul.input_quant_stubs[1].activation_post_process
assert fq.zero_point[0] == 128


def test_standard_qrange_zero_points():
bits = 8

fake_quantize = get_observer(True, torch.qint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == -128
assert fake_quantize.quant_max == 127
_, zero_point = fake_quantize.calculate_qparams()
assert zero_point[0] == 0

fake_quantize = get_observer(True, torch.quint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == 0
assert fake_quantize.quant_max == 255
_, zero_point = fake_quantize.calculate_qparams()
assert zero_point[0] == 128


def test_custom_qrange_zero_points():
# non 8 bits is what makes it a custom qrange
bits = 4

fake_quantize = get_observer(True, torch.qint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == -8
assert fake_quantize.quant_max == 7
assert fake_quantize.activation_post_process.quant_min == -8
assert fake_quantize.activation_post_process.quant_max == 7
_, zero_point = fake_quantize.calculate_qparams()
assert zero_point[0] == 0

fake_quantize = get_observer(True, torch.quint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == 0
assert fake_quantize.quant_max == 15
assert fake_quantize.activation_post_process.quant_min == 0
assert fake_quantize.activation_post_process.quant_max == 15
_, zero_point = fake_quantize.calculate_qparams()
assert zero_point[0] == 7
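The expected zero point of 7 follows from the midpoint rule torch applies to customized qranges (the observer.py branch linked in the helper's comment), sketched here:

```python
# symmetric quint8 with a customized qrange uses the range midpoint
quant_min, quant_max = 0, 15  # 4-bit quint8 range
zero_point = (quant_min + quant_max) // 2
assert zero_point == 7  # vs. the fixed 128 of the default 0..255 range
```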
tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py
@@ -98,10 +98,9 @@ def _test_quantizable_module(module, qat_expected, modifier):
if modifier.activation_bits is not None:
expected_quant_min = 0
expected_quant_max = (1 << modifier.activation_bits) - 1
activation_quant_properties = module.qconfig.activation.p.keywords

assert activation_quant_properties["quant_min"] == expected_quant_min
assert activation_quant_properties["quant_max"] == expected_quant_max
assert module.activation_post_process.quant_min == expected_quant_min
assert module.activation_post_process.quant_max == expected_quant_max

if isinstance(module, Linear):
assert isinstance(module.activation_post_process, Identity) == (