diff --git a/src/sparseml/pytorch/optim/modifier_quantization.py b/src/sparseml/pytorch/optim/modifier_quantization.py index b0e545d65e0..5e5f2321587 100644 --- a/src/sparseml/pytorch/optim/modifier_quantization.py +++ b/src/sparseml/pytorch/optim/modifier_quantization.py @@ -51,6 +51,7 @@ add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, + fix_observer_quant_range, fuse_module_conv_bn_relus, get_qat_qconfig, prepare_embeddings_qat, @@ -85,6 +86,7 @@ class QuantizationModifier(ScheduledModifier): | disable_quantization_observer_epoch: 2.0 | freeze_bn_stats_epoch: 3.0 | reduce_range: False + | activation_bits: False :param start_epoch: The epoch to start the modifier at :param submodules: List of submodule names to perform QAT on. Leave None to quantize @@ -114,10 +116,16 @@ class QuantizationModifier(ScheduledModifier): transformer based models such as BERT where the quantized MatMul outputs are kept at 32 bits of precision and fake quantizing the outputs harm training recovery. Default is True + :param activation_bits: Number of bits to use for setting quant min/max values for + activations. Default is None, which will quantize activations to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. When None, the entire calibration_dataloader is used :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. """ def __init__( @@ -132,8 +140,11 @@ def __init__( quantize_embeddings: bool = True, reduce_range: bool = False, quantize_linear_activations: bool = True, + activation_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_module_types: Union[List[str], None] = None, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -158,12 +169,16 @@ def __init__( self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range self._quantize_linear_activations = quantize_linear_activations + self._activation_bits = activation_bits self._exclude_module_types = exclude_module_types self._modules_to_quantize = None self._qat_enabled = False self._quantization_observer_disabled = False self._bn_stats_frozen = False + self._activation_qconfig_kwargs = activation_qconfig_kwargs + self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._calibration_dataloader = None self._calibration_function = None self._num_calibration_steps = num_calibration_steps @@ -309,6 +324,32 @@ def exclude_module_types(self) -> Union[List[str], None]: """ return self._exclude_module_types + @ModifierProp() + def activation_bits(self) -> Optional[int]: + """ + :return: Number of bits to be use for setting quant min/max values for + activations. Default is None, which will quantize activations to 8 bits. 
+ """ + return self._activation_bits + + @ModifierProp() + def activation_qconfig_kwargs(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._activation_qconfig_kwargs + + @ModifierProp() + def weight_qconfig_kwargs(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for weights + + """ + return self._weight_qconfig_kwargs + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: """ @@ -457,11 +498,22 @@ def _enable_module_qat(self, module: Module): self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + # prepare each module / submodule for quantization - qconfig = get_qat_qconfig(reduce_range=self._reduce_range) + qconfig = get_qat_qconfig( + reduce_range=self._reduce_range, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) - configure_module_qat_wrappers(quant_module, reduce_range=self._reduce_range) + configure_module_qat_wrappers( + quant_module, + reduce_range=self._reduce_range, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig # wrap all conv / linear blocks in with quantization observers @@ -480,7 +532,15 @@ def _enable_module_qat(self, module: Module): # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) if self._quantize_embeddings: - prepare_embeddings_qat(module, reduce_range=self._reduce_range) + prepare_embeddings_qat( + module, + reduce_range=self._reduce_range, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + ) + + # propagate custom quant min/max range from FakeQuantize to Observer objects + fix_observer_quant_range(module) self._qat_enabled = True self._calibrate_if_possible(module) @@ -533,6 +593,35 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + activation_qconfig_kwargs = ( + self.activation_qconfig_kwargs.copy() + if self.activation_qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for activation_bits + if self.activation_bits and ( + activation_qconfig_kwargs.get("quant_min") + or activation_qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min with activation_bits enabled" + ) + + if self.activation_bits: + quant_min = 0 + quant_max = 2 ** self.activation_bits - 1 + dtype = torch.quint8 + activation_qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + return activation_qconfig_kwargs + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index b2936a7047f..475d0845076 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -44,15 +44,15 @@ QATLinear = None QATConv2d = None +from sparseml.utils import create_dirs, save_numpy + + try: from torch.nn.qat 
import Conv3d as QATConv3d except Exception as _err: quant_conv3d_err = _err QATConv3d = None -from sparseml.utils import create_dirs, save_numpy - - __all__ = [ "default_device", "device_of", diff --git a/src/sparseml/pytorch/utils/quantization/helpers.py b/src/sparseml/pytorch/utils/quantization/helpers.py index 3876661ce8a..2f3bac7c85b 100644 --- a/src/sparseml/pytorch/utils/quantization/helpers.py +++ b/src/sparseml/pytorch/utils/quantization/helpers.py @@ -17,7 +17,7 @@ """ from copy import deepcopy -from typing import Any, Callable, List, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU @@ -40,11 +40,11 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "fix_observer_quant_range", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] - _QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers @@ -98,16 +98,29 @@ class QATWrapper(Module): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of activations. + Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. + Default is {} """ @staticmethod - def from_module(module: Module, reduce_range: bool = None) -> "QATWrapper": + def from_module( + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. Default is None, will only override qat_wrapper_kwargs if set to a bool value + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. Default is {} :return: QATWrapper object created using the given Module as the forward function. 
Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -124,6 +137,19 @@ def from_module(module: Module, reduce_range: bool = None) -> "QATWrapper": else reduce_range or qat_wrapper_kwargs["reduce_range"] ) + qat_wrapper_kwargs["activation_qconfig_kwargs"] = ( + activation_qconfig_kwargs + if "activation_qconfig_kwargs" not in qat_wrapper_kwargs + else activation_qconfig_kwargs + or qat_wrapper_kwargs["activation_qconfig_kwargs"] + ) + + qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( + weight_qconfig_kwargs + if "weight_qconfig_kwargs" not in qat_wrapper_kwargs + else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] + ) + return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -139,6 +165,8 @@ def __init__( "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -159,11 +187,24 @@ def __init__( self.forward_fn = forward_fn self._reduce_range = reduce_range + self._activation_qconfig_kwargs = activation_qconfig_kwargs + self._weight_qconfig_kwargs = weight_qconfig_kwargs + self.input_qconfigs = self._load_qconfigs( - "input_qconfigs", num_input_quant_stubs, input_qconfigs, self._reduce_range + name="input_qconfigs", + expected_len=num_input_quant_stubs, + qconfigs=input_qconfigs, + reduce_range=self._reduce_range, + activation_qconfig_kwargs=self._activation_qconfig_kwargs, + weight_qconfig_kwargs=self._weight_qconfig_kwargs, ) self.output_qconfigs = self._load_qconfigs( - "output_qconfigs", num_outputs, output_qconfigs, self._reduce_range + name="output_qconfigs", + expected_len=num_outputs, + qconfigs=output_qconfigs, + reduce_range=self._reduce_range, + activation_qconfig_kwargs=self._activation_qconfig_kwargs, + weight_qconfig_kwargs=self._weight_qconfig_kwargs, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -247,7 +288,9 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - redcuce_range: bool = False, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -279,13 +322,20 @@ def _load_qconfigs( qconfigs[idx] = get_qat_qconfig( symmetric_activations=(qconfig == "symmetric"), - reduce_range=redcuce_range, + reduce_range=reduce_range, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) return qconfigs -def configure_module_qat_wrappers(module: Module, reduce_range: bool = False): +def configure_module_qat_wrappers( + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, +): """ if any submodule of the given module has the attribute wrap_qat == True, then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. @@ -296,6 +346,10 @@ def configure_module_qat_wrappers(module: Module, reduce_range: bool = False): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. 
Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): @@ -303,10 +357,20 @@ def configure_module_qat_wrappers(module: Module, reduce_range: bool = False): setattr( module, child_name, - QATWrapper.from_module(child_module, reduce_range=reduce_range), + QATWrapper.from_module( + module=child_module, + reduce_range=reduce_range, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, + ), ) # recurse on child module - configure_module_qat_wrappers(child_module, reduce_range=reduce_range) + configure_module_qat_wrappers( + module=child_module, + reduce_range=reduce_range, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, + ) def configure_module_default_qconfigs(module: Module): @@ -359,8 +423,8 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ e.x. ["Linear"] """ for submodule in module.modules(): - if submodule.__class__.__name__ in layer_class_names and ( - hasattr(submodule, "qconfig") + if submodule.__class__.__name__ in layer_class_names and hasattr( + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -372,6 +436,8 @@ def get_qat_qconfig( symmetric_activations: bool = False, symmetric_weights: bool = True, reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -383,6 +449,10 @@ def get_qat_qconfig( :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. 
The difference between this and torch.quantization.default_qat_qconfig is that the activation observer @@ -391,7 +461,7 @@ def get_qat_qconfig( activation_qscheme = ( torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - activation_observer = torch_quantization.FakeQuantize.with_args( + activation_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=0, quant_max=255, @@ -399,10 +469,14 @@ def get_qat_qconfig( qscheme=activation_qscheme, reduce_range=reduce_range, ) + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, + ) weight_qscheme = ( torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine ) - weight_observer = torch_quantization.FakeQuantize.with_args( + weight_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, @@ -410,12 +484,50 @@ def get_qat_qconfig( qscheme=weight_qscheme, reduce_range=reduce_range, ) + + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) +def fix_observer_quant_range(module: Module): + """ + As of torch 1.10.2 there is a bug in FakeQuantize initialization where + quant_min and quant_max of FakeQuantize are not propagated to its + activation_post_process observer. This function propagates FakeQuantize quant + ranges to their Observer objects + + :param module: Module object to propagate FakeQuantize quant ranges of. Propagates + in place + """ + for submodule in module.modules(): + if isinstance(submodule, torch_quantization.FakeQuantize): + fake_quantize = submodule + elif hasattr(submodule, "activation_post_process") and isinstance( + submodule.activation_post_process, torch_quantization.FakeQuantize + ): + fake_quantize = submodule.activation_post_process + else: + continue + + # continue if fake_quantize quant range not set, or observer quant range is set + observer = fake_quantize.activation_post_process + if ( + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) + ): + continue + observer.quant_min = fake_quantize.quant_min + observer.quant_max = fake_quantize.quant_max + observer.has_customized_qrange = True + + def fuse_module_conv_bn_relus( module: Module, inplace: bool = True, @@ -501,6 +613,8 @@ def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -509,12 +623,21 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range + :param activation_qconfig_kwargs: additional kwargs for quantizing activations. + Default is {}. + :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. + Default is {}. :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. 
Default is False """ if qconfig is None: - qconfig = get_qat_qconfig(symmetric_weights=False, reduce_range=reduce_range) + qconfig = get_qat_qconfig( + symmetric_weights=False, + reduce_range=reduce_range, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, + ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) diff --git a/tests/sparseml/pytorch/optim/test_modifier_quantization.py b/tests/sparseml/pytorch/optim/test_modifier_quantization.py index d269141b49a..80702762687 100644 --- a/tests/sparseml/pytorch/optim/test_modifier_quantization.py +++ b/tests/sparseml/pytorch/optim/test_modifier_quantization.py @@ -53,6 +53,10 @@ start_epoch=0.0, quantize_linear_activations=False, ), + lambda: QuantizationModifier( + start_epoch=0.0, + activation_bits=4, + ), lambda: QuantizationModifier( start_epoch=0.0, exclude_module_types=["Linear"], @@ -86,6 +90,17 @@ def _test_quantizable_module(module, qat_expected, modifier): ) if module.qconfig.activation is not Identity: assert module.qconfig.activation.p.keywords["reduce_range"] == reduce_range + if modifier.activation_bits is not None: + expected_quant_min = 0 + expected_quant_max = (1 << modifier.activation_bits) - 1 + + assert ( + module.qconfig.activation.p.keywords["quant_min"] == expected_quant_min + ) + assert ( + module.qconfig.activation.p.keywords["quant_max"] == expected_quant_max + ) + if isinstance(module, Linear): assert isinstance(module.activation_post_process, Identity) == ( not quantize_linear_activations @@ -103,11 +118,8 @@ def _test_qat_applied(modifier, model): submodules = [""] for module in model.modules(): if _is_quantizable_module(module): - _test_quantizable_module( - module, - True, - modifier, - ) + _test_quantizable_module(module, True, modifier) + else: assert not hasattr(model, "qconfig") or model.qconfig is None submodules = modifier.submodules @@ -211,6 +223,11 @@ def test_quantization_modifier_yaml(): quantize_linear_activations = False num_calibration_steps = 2 exclude_module_types = ["LayerNorm", "Tanh"] + activation_bits = 4 + averaging_constant = 0.05 + activation_qconfig_kwargs = dict( + averaging_constant=averaging_constant, + ) yaml_str = f""" !QuantizationModifier start_epoch: {start_epoch} @@ -223,6 +240,8 @@ def test_quantization_modifier_yaml(): quantize_linear_activations: {quantize_linear_activations} num_calibration_steps: {num_calibration_steps} exclude_module_types: {exclude_module_types} + activation_bits: {activation_bits} + activation_qconfig_kwargs: {activation_qconfig_kwargs} """ yaml_modifier = QuantizationModifier.load_obj( yaml_str @@ -239,8 +258,10 @@ def test_quantization_modifier_yaml(): quantize_embeddings=quantize_embeddings, reduce_range=reduce_range, quantize_linear_activations=quantize_linear_activations, + activation_bits=activation_bits, num_calibration_steps=num_calibration_steps, exclude_module_types=exclude_module_types, + activation_qconfig_kwargs=activation_qconfig_kwargs, ) assert isinstance(yaml_modifier, QuantizationModifier) @@ -284,6 +305,11 @@ def test_quantization_modifier_yaml(): == serialized_modifier.quantize_linear_activations == obj_modifier.quantize_linear_activations ) + assert ( + yaml_modifier.activation_bits + == serialized_modifier.activation_bits + == obj_modifier.activation_bits + ) assert ( yaml_modifier.num_calibration_steps == serialized_modifier.num_calibration_steps @@ -294,3 +320,8 @@ def test_quantization_modifier_yaml(): == 
serialized_modifier.exclude_module_types
         == obj_modifier.exclude_module_types
     )
+    assert (
+        yaml_modifier.activation_qconfig_kwargs
+        == serialized_modifier.activation_qconfig_kwargs
+        == obj_modifier.activation_qconfig_kwargs
+    )
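
Reviewer note (illustrative only, not part of the patch): based on the constructor arguments added above and the values exercised in the updated tests, the new options could be configured roughly as follows; the derived activation range mirrors the arithmetic in _get_updated_activation_qconfig_kwargs.

from sparseml.pytorch.optim.modifier_quantization import QuantizationModifier

# Hypothetical configuration; the argument values mirror the updated tests.
modifier = QuantizationModifier(
    start_epoch=0.0,
    activation_bits=4,
    activation_qconfig_kwargs={"averaging_constant": 0.05},
)

# With activation_bits=4 the modifier derives an unsigned 4-bit activation range:
#   quant_min = 0
#   quant_max = 2 ** 4 - 1 = 15
#   dtype     = torch.quint8
# Passing quant_min / quant_max directly in activation_qconfig_kwargs together
# with activation_bits raises a ValueError in _get_updated_activation_qconfig_kwargs.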
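
A second sketch, assuming the same import path the modifier uses for the quantization helpers, of how the new qconfig kwargs flow into get_qat_qconfig; the .p.keywords access matches what the updated tests assert on.

import torch

from sparseml.pytorch.utils.quantization import get_qat_qconfig

# Override the default 8-bit activation fake-quant range with a 4-bit one.
qconfig = get_qat_qconfig(
    activation_qconfig_kwargs={"quant_min": 0, "quant_max": 15, "dtype": torch.quint8},
)

# qconfig.activation is a FakeQuantize.with_args partial; the kwargs above
# replace the defaults (quant_min=0, quant_max=255, dtype=torch.quint8),
# while qconfig.weight keeps the symmetric INT8 defaults from the helper.
assert qconfig.activation.p.keywords["quant_max"] == 15
assert qconfig.weight.p.keywords["quant_max"] == 127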
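
Finally, a sketch of the torch behavior that fix_observer_quant_range works around, assuming an affected torch release (the docstring cites 1.10.2): the FakeQuantize module records the custom range, but its attached observer does not, so the helper copies the range over in place.

import torch
from torch import quantization as torch_quantization

fake_quantize = torch_quantization.FakeQuantize.with_args(
    observer=torch_quantization.MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=15,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
)()

# On affected releases the custom range stays on the FakeQuantize module only.
print(fake_quantize.quant_max)                          # 15
print(fake_quantize.activation_post_process.quant_max)  # None on affected releases

# What fix_observer_quant_range does for each such pair, in place:
observer = fake_quantize.activation_post_process
observer.quant_min = fake_quantize.quant_min
observer.quant_max = fake_quantize.quant_max
observer.has_customized_qrange = True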