From 4d386e0754e465db0cd0b8c83a4f2957aabe5d89 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 001/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 153 +++++++++++------- .../quantization/modifier_quantization.py | 90 ++++++----- 2 files changed, 145 insertions(+), 98 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2f3bac7c85b..75d11c67c31 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,20 +31,21 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "get_updated_qconfig_kwargs", "fix_observer_quant_range", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -106,10 +106,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -141,7 +141,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +153,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +285,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + 
name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +331,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +383,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +398,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +424,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +509,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -518,9 +518,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -529,9 +529,9 @@ def fix_observer_quant_range(module: Module): def 
fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -566,14 +566,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -610,11 +610,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -643,6 +643,37 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for bits + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min when number of bits is set" + ) + + if bits: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + + return qconfig_kwargs + + def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 15ad82299d9..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,12 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -139,8 +141,11 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = 
False,
-        quantize_linear_activations: bool = True,
+        quantize_linear_output_activations: bool = False,
+        quantize_conv_output_activations: bool = False,
+        quantize_add_input_activations: bool = True,
         activation_bits: Optional[int] = None,
+        weight_bits: Optional[int] = None,
         num_calibration_steps: Optional[int] = None,
         exclude_module_types: Union[List[str], None] = None,
         activation_qconfig_kwargs: Optional[Dict[str, Any]] = None,
         weight_qconfig_kwargs: Optional[Dict[str, Any]] = None,
@@ -168,8 +173,11 @@ def __init__(
         self._freeze_bn_stats_epoch = freeze_bn_stats_epoch
         self._quantize_embeddings = quantize_embeddings
         self._reduce_range = reduce_range
-        self._quantize_linear_activations = quantize_linear_activations
+        self._quantize_linear_output_activations = quantize_linear_output_activations
+        self._quantize_conv_output_activations = quantize_conv_output_activations
+        self._quantize_add_input_activations = quantize_add_input_activations
         self._activation_bits = activation_bits
+        self._weight_bits = weight_bits
         self._exclude_module_types = exclude_module_types

         self._modules_to_quantize = None
@@ -306,7 +314,7 @@ def reduce_range(self) -> bool:
         return self._reduce_range

     @ModifierProp()
-    def quantize_linear_activations(self) -> bool:
+    def quantize_linear_output_activations(self) -> bool:
         """
         :return: if False, FakeQuantize ops will not be run
             for activations of fully connected layers. this is important for quantizing
@@ -314,7 +322,15 @@ def quantize_linear_activations(self) -> bool:
             are kept at 32 bits of precision and fake quantizing the outputs harm
            training recovery
         """
-        return self._quantize_linear_activations
+        return self._quantize_linear_output_activations
+
+    @ModifierProp()
+    def quantize_conv_output_activations(self) -> bool:
+        """
+        :return: if False, FakeQuantize ops will not be run
+            for activations of convolutional layers.
+        """
+        return self._quantize_conv_output_activations

     @ModifierProp()
     def exclude_module_types(self) -> Union[List[str], None]:
@@ -332,6 +348,15 @@ def activation_bits(self) -> Optional[int]:
         """
         return self._activation_bits

+    @ModifierProp()
+    def weight_bits(self) -> Optional[int]:
+        """
+        :return: Number of bits to be used for setting quant min/max values for
+            weights. Default is None, which will quantize weights to 8 bits. 
+ """ + return self._weight_bits + + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -391,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -499,12 +527,25 @@ def _enable_module_qat(self, module: Module): fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() + + to_remove_layer_name = [] + if not self._quantize_linear_output_activations: + to_remove_layer_name.extend(["Linear", "LinearReLu"]) + + if not self._quantize_conv_output_activations: + to_remove_layer_name.extend( + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + ) # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -512,7 +553,7 @@ def _enable_module_qat(self, module: Module): quant_module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -521,9 +562,7 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - - if not self._quantize_linear_activations: - remove_activation_qat_by_layer_name(quant_module, ["Linear"]) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types if self._exclude_module_types: @@ -536,7 +575,7 @@ def _enable_module_qat(self, module: Module): module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -594,33 +633,10 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - activation_qconfig_kwargs = ( - self.activation_qconfig_kwargs.copy() - if self.activation_qconfig_kwargs - else {} - ) - - # update qconfig_kwargs for activation_bits - if self.activation_bits and ( - activation_qconfig_kwargs.get("quant_min") - or activation_qconfig_kwargs.get("quant_max") - ): - raise ValueError( - "Cannot override quant_max and quant_min with activation_bits enabled" - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) - if self.activation_bits: - quant_min = 0 - quant_max = 2 ** self.activation_bits - 1 - dtype = torch.quint8 - activation_qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - return 
activation_qconfig_kwargs + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From d875f97d7ee15b3f921058c9fb07ce0a47d682f3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 002/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. --- .../pytorch/models/classification/resnet.py | 94 +++++++++---------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 40aef8a3c69..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,13 +41,11 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: FloatFunctional = None - __all__ = [ "ResNetSectionSettings", "ResNet", @@ -141,6 +139,23 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: return in_channels != out_channels or stride > 1 +class _AddReLU(Module): + def __init__(self): + super().__init__() + if FloatFunctional: + self.functional = FloatFunctional() + self.wrap_qat = True + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + else: + self.functional = ReLU(num_channels=out_channels, inplace=True) + + def forward(self, x, y): + if isinstance(self.functional, FloatFunctional): + return self.functional.add_relu(x, y) + else: + return self.functional(x + y) + + class _BasicBlock(Module): def __init__(self, in_channels: int, out_channels: int, stride: int = 1): super().__init__() @@ -164,11 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -181,12 +192,7 @@ def forward(self, inp: Tensor): out = self.bn2(out) identity_val = self.identity(inp) if self.identity is not None else inp - - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -199,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -230,11 +236,7 @@ def __init__( else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -252,11 +254,7 @@ def forward(self, inp: Tensor): identity_val = self.identity(inp) if self.identity is not None else inp - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -323,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - 
groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -439,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -481,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 831e4743939ad739fda76ffd9350c2284d0961a4 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 003/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 75d11c67c31..f28656f1712 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. 
Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 
@@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 479d3c9667527894edd00bb1e0d1e591e9fd4d49 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 004/218] Minor fixes. Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 61 +++++---- .../sparsification/quantization/helpers.py | 129 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 115 insertions(+), 108 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if 
use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index f28656f1712..ef4445a0d5f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - 
reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -518,9 +520,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -529,9 +531,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -566,14 +568,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = 
".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -610,11 +612,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -644,17 +646,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 30be2e40786c6048631789687098c3a08a84aee5 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 005/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 215 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++- 2 files changed, 167 insertions(+), 85 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ef4445a0d5f..c4f165d23ef 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + 
"torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -520,9 +597,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -530,10 +607,15 @@ def 
fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -568,14 +650,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -612,11 +694,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -646,10 +728,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 89b025f4c791a31cec89dd799ee9314604702f66 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 006/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index c4f165d23ef..64958570e2d 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if 
not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From 5865ca6287f653972bfd798eab43f4817285c5c5 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 007/218] Set BN fusing back as default. --- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 91cd835546f83d184dc28346f1019cda5aa89afa Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 008/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 64958570e2d..a43d69d947b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 9819a156b724803b1b705229af72ec5b7641057f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 009/218] Fixed custom freeze_bn_stats. --- .../sparsification/quantization/helpers.py | 251 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 185 insertions(+), 112 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a43d69d947b..6110a499b70 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + 
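        # editor's note (not part of the original patch): BNWrapper mirrors the
        # wrapped BatchNorm's attributes through properties like this one so that
        # code reading or writing them on the original module keeps working on
        # the wrapper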
self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -597,9 +651,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -608,14 +662,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -650,14 +704,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -694,11 +748,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -727,26 +781,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) 
-> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) 
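Illustrative example (editor's addition, not part of the patch): with the mode argument introduced in this commit, the helper maps a bit width onto a signed range for symmetric (weight) quantization and an unsigned range for asymmetric (activation) quantization. A minimal sketch, assuming `get_updated_qconfig_kwargs` returns the updated kwargs dict:

    from sparseml.pytorch.sparsification.quantization.helpers import (
        get_updated_qconfig_kwargs,
    )

    # 4-bit symmetric weights -> signed range centered on zero
    get_updated_qconfig_kwargs({}, bits=4, mode="symmetric")
    # quant_min=-8, quant_max=7, dtype=torch.qint8

    # 4-bit asymmetric activations -> unsigned range
    get_updated_qconfig_kwargs({}, bits=4, mode="asymmetric")
    # quant_min=0, quant_max=15, dtype=torch.quint8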
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 749ca72d99152144a4bce489dfb4785ca9ded736 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 010/218] Temporary files for evaluating changes to graphs. --- sandbox/quantization_recipe.yaml | 7 +++ sandbox/quantization_test.py | 23 ++++++++ .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 3 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 sandbox/quantization_recipe.yaml create mode 100644 sandbox/quantization_test.py diff --git a/sandbox/quantization_recipe.yaml b/sandbox/quantization_recipe.yaml new file mode 100644 index 00000000000..411dd6f025a --- /dev/null +++ b/sandbox/quantization_recipe.yaml @@ -0,0 +1,7 @@ +quantization_modifiers: + - !QuantizationModifier + start_epoch: -1.0 + model_fuse_fn_name: no_fuse + submodules: + - input + - sections diff --git a/sandbox/quantization_test.py b/sandbox/quantization_test.py new file mode 100644 index 00000000000..ea6fba5acd5 --- /dev/null +++ b/sandbox/quantization_test.py @@ -0,0 +1,23 @@ +import torch +from sparseml.pytorch.utils import ModuleExporter +from sparseml.pytorch.models import ModelRegistry +from sparseml.pytorch.optim import ScheduledModifierManager + +model = ModelRegistry.create( + key='resnet50', + pretrained=False, + pretrained_dataset="imagenet", + num_classes=1000 +) + + +ScheduledModifierManager.from_yaml("quantization_recipe.yaml").apply(model, epoch=float("inf")) + +print(model) + +exporter = ModuleExporter(model, ".") +exporter.export_onnx( + torch.randn(1, 3, 224, 224), + "quantized_test.onnx", + convert_qat=False, +) \ No newline at end of file diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + 
groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 601cdeaa6ee21914d9b4a9f7dc67484c8e820c35 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 011/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. --- .../sparsification/quantization/helpers.py | 213 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 137 insertions(+), 134 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6110a499b70..8ae045de9e8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: 
F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: 
Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -651,9 +657,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -662,14 +668,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: 
Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -704,14 +710,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -748,11 +754,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -782,17 +788,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = 
activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - 
self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From f7aecf1225612d61402b830bf5172c1c0fd06f71 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 012/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 8ae045de9e8..027c7514c32 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + 
activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module 
has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) 
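    # e.g. compute_range(torch.qint8, 8) -> (-128, 127, torch.qint8) and
    # compute_range(torch.quint8, 8) -> (0, 255, torch.quint8); a None dtype or
    # bit width falls back to 8-bit torch.quint8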
observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -756,9 +868,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -776,11 +894,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + 
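            # tensorrt disabled: leave scheme and dtype unset so get_qat_qconfig
            # falls back to its defaults (asymmetric activations, symmetric weights)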
_symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From d994b45359209cc048f1c43d2fa8938ec5702786 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 013/218] Included check to account for when weight_qconfig_kwatgs is None. 
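Illustrative sketch (editor's addition, not part of the commit): with the added guard, a None kwargs dict passes through untouched while the "minmaxobserver" shorthand is still mapped onto the torch observer class. The standalone helper name below is hypothetical; it only mirrors the guarded property:

    from torch import quantization as torch_quantization

    def resolve_weight_qconfig_kwargs(weight_qconfig_kwargs):
        # None passes through untouched
        if weight_qconfig_kwargs is not None and "observer" in weight_qconfig_kwargs:
            kwargs = weight_qconfig_kwargs.copy()
            if kwargs["observer"] == "minmaxobserver":
                kwargs["observer"] = torch_quantization.MinMaxObserver
            return kwargs
        return weight_qconfig_kwargs

    assert resolve_weight_qconfig_kwargs(None) is None
    assert (
        resolve_weight_qconfig_kwargs({"observer": "minmaxobserver"})["observer"]
        is torch_quantization.MinMaxObserver
    )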
--- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 3099efb0b0399e9c1f83a6a6d5615d514e93f1ed Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 014/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 147 +++++++++++------- .../quantization/modifier_quantization.py | 90 ++++++----- 2 files changed, 142 insertions(+), 95 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index dddd41326d2..e10224bbce7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,20 +31,21 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "get_updated_qconfig_kwargs", "fix_observer_quant_range", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -106,10 +106,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -141,7 +141,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +153,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: 
Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +285,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +331,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +383,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +398,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +424,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +509,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") 
and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +534,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +571,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +615,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -648,6 +648,37 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for bits + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min when number of bits is set" + ) + + if bits: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + + return qconfig_kwargs + + def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 15ad82299d9..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,12 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, 
fix_observer_quant_range, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -139,8 +141,11 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_activations: bool = True, + quantize_linear_output_activations: bool = False, + quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -168,8 +173,11 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_activations = quantize_linear_activations + self._quantize_linear_output_activations = quantize_linear_output_activations + self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits + self._weight_bits = weight_bits self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -306,7 +314,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_activations(self) -> bool: + def quantize_linear_output_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -314,7 +322,15 @@ def quantize_linear_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_activations + return self._quantize_linear_output_activations + + @ModifierProp() + def quantize_conv_output_activations(self) -> bool: + """ + :return: if False, FakeQuantize ops will not be run + for activations of convolutional layers. + """ + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -332,6 +348,15 @@ def activation_bits(self) -> Optional[int]: """ return self._activation_bits + @ModifierProp() + def weight_bits(self) -> Optional[int]: + """ + :return: Number of bits to be use for setting quant min/max values for + activations. Default is None, which will quantize activations to 8 bits. 
+ """ + return self._weight_bits + + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -391,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -499,12 +527,25 @@ def _enable_module_qat(self, module: Module): fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() + + to_remove_layer_name = [] + if not self._quantize_linear_output_activations: + to_remove_layer_name.extend(["Linear", "LinearReLu"]) + + if not self._quantize_conv_output_activations: + to_remove_layer_name.extend( + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + ) # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -512,7 +553,7 @@ def _enable_module_qat(self, module: Module): quant_module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -521,9 +562,7 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - - if not self._quantize_linear_activations: - remove_activation_qat_by_layer_name(quant_module, ["Linear"]) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types if self._exclude_module_types: @@ -536,7 +575,7 @@ def _enable_module_qat(self, module: Module): module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -594,33 +633,10 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - activation_qconfig_kwargs = ( - self.activation_qconfig_kwargs.copy() - if self.activation_qconfig_kwargs - else {} - ) - - # update qconfig_kwargs for activation_bits - if self.activation_bits and ( - activation_qconfig_kwargs.get("quant_min") - or activation_qconfig_kwargs.get("quant_max") - ): - raise ValueError( - "Cannot override quant_max and quant_min with activation_bits enabled" - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) - if self.activation_bits: - quant_min = 0 - quant_max = 2 ** self.activation_bits - 1 - dtype = torch.quint8 - activation_qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - return 
activation_qconfig_kwargs + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 36bc220111110f2493e567f4c866612038af2633 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 015/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. --- .../pytorch/models/classification/resnet.py | 94 +++++++++---------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 40aef8a3c69..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,13 +41,11 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: FloatFunctional = None - __all__ = [ "ResNetSectionSettings", "ResNet", @@ -141,6 +139,23 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: return in_channels != out_channels or stride > 1 +class _AddReLU(Module): + def __init__(self): + super().__init__() + if FloatFunctional: + self.functional = FloatFunctional() + self.wrap_qat = True + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + else: + self.functional = ReLU(num_channels=out_channels, inplace=True) + + def forward(self, x, y): + if isinstance(self.functional, FloatFunctional): + return self.functional.add_relu(x, y) + else: + return self.functional(x + y) + + class _BasicBlock(Module): def __init__(self, in_channels: int, out_channels: int, stride: int = 1): super().__init__() @@ -164,11 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -181,12 +192,7 @@ def forward(self, inp: Tensor): out = self.bn2(out) identity_val = self.identity(inp) if self.identity is not None else inp - - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -199,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -230,11 +236,7 @@ def __init__( else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -252,11 +254,7 @@ def forward(self, inp: Tensor): identity_val = self.identity(inp) if self.identity is not None else inp - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -323,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - 
groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -439,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -481,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 722f866a5b20a15a5753c49c759b8a622846245a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 016/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e10224bbce7..ec69ded82c8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. 
Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 
@@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From b55348229e7d2ed406cf12566ad5a2b931bba30c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 017/218] Minor fixes. Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 61 ++++----- .../sparsification/quantization/helpers.py | 123 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 112 insertions(+), 105 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if 
use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ec69ded82c8..2c1ac640d6e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - 
reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +536,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +573,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or 
[Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +617,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -649,17 +651,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From aef2347c0c98f680cb7222fe17cea6849d70c4f0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 018/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 209 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++-- 2 files changed, 164 insertions(+), 82 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2c1ac640d6e..a44369550b1 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ 
+ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -535,10 +612,15 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -573,14 +655,14 @@ def 
fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -617,11 +699,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -651,10 +733,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. 
Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From df88638b19ba57373f42e2ca0a6a6aabda42399d Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 019/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a44369550b1..48ed0708eae 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From d0b6354832a7599bf311e57551a7f9cf057e965f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 020/218] Set BN fusing back as default. 
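
A small illustrative check of the behaviour the pass-through properties below
are meant to preserve (toy example only, assuming BNWrapper as defined in
this patch):

    import torch

    from sparseml.pytorch.sparsification.quantization.helpers import BNWrapper

    bn = torch.nn.BatchNorm2d(8)
    wrapped = BNWrapper(bn)

    # Buffers and hyperparameters read through the wrapper as if unwrapped.
    assert wrapped.running_mean is bn.running_mean
    assert wrapped.eps == bn.eps

    # Once frozen, train() no longer flips the BatchNorm back to training mode.
    wrapped.freeze_bn_stats()
    wrapped.train()
    assert wrapped.bn.training is False
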
--- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From fbf22068fe49aed619353f3db1da837cb4f0ecc9 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 021/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 48ed0708eae..71f6553fc44 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 2312af89ee32cf027d7f6d4fc25bdf1747faba40 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 022/218] Fixed custom freeze_bn_stats. 
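The custom `freeze_bn_stats` helper is now keyed off a `freeze_bn_stats` attribute instead of a fixed set of fused module types, so standalone `BNWrapper` instances and fused Conv-BN modules are both frozen the same way. A rough usage sketch, assuming the model has already been passed through `configure_module_bn_wrappers`; the toy `Sequential` model is illustrative only:

import torch
from torch.nn import BatchNorm2d, Conv2d, Sequential

def freeze_bn_stats(module):
    # Any module exposing freeze_bn_stats (BNWrapper, fused ConvBn, ...) is frozen
    if hasattr(module, "freeze_bn_stats"):
        module.freeze_bn_stats()

model = Sequential(Conv2d(3, 8, 3), BatchNorm2d(8))
# configure_module_bn_wrappers(model)  # wraps the BatchNorm2d in a BNWrapper
model.apply(freeze_bn_stats)  # mirrors quant_module.apply(freeze_bn_stats) in the modifier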
--- .../sparsification/quantization/helpers.py | 245 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 182 insertions(+), 109 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 71f6553fc44..e09c0e29690 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def 
from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -613,14 +667,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + 
inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -655,14 +709,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -699,11 +753,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -732,26 +786,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From f3bb04b8f5af16eec7618ecf25c1543f5eafaea9 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 023/218] Temporary files for evaluating changes to graphs. 
--- sandbox/quantization_recipe.yaml | 7 +++ sandbox/quantization_test.py | 23 ++++++++ .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 3 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 sandbox/quantization_recipe.yaml create mode 100644 sandbox/quantization_test.py diff --git a/sandbox/quantization_recipe.yaml b/sandbox/quantization_recipe.yaml new file mode 100644 index 00000000000..411dd6f025a --- /dev/null +++ b/sandbox/quantization_recipe.yaml @@ -0,0 +1,7 @@ +quantization_modifiers: + - !QuantizationModifier + start_epoch: -1.0 + model_fuse_fn_name: no_fuse + submodules: + - input + - sections diff --git a/sandbox/quantization_test.py b/sandbox/quantization_test.py new file mode 100644 index 00000000000..ea6fba5acd5 --- /dev/null +++ b/sandbox/quantization_test.py @@ -0,0 +1,23 @@ +import torch +from sparseml.pytorch.utils import ModuleExporter +from sparseml.pytorch.models import ModelRegistry +from sparseml.pytorch.optim import ScheduledModifierManager + +model = ModelRegistry.create( + key='resnet50', + pretrained=False, + pretrained_dataset="imagenet", + num_classes=1000 +) + + +ScheduledModifierManager.from_yaml("quantization_recipe.yaml").apply(model, epoch=float("inf")) + +print(model) + +exporter = ModuleExporter(model, ".") +exporter.export_onnx( + torch.randn(1, 3, 224, 224), + "quantized_test.onnx", + convert_qat=False, +) \ No newline at end of file diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: 
List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 5bca8d75c6a9bc3fc003553f4fafde10e61812d3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 024/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. --- .../sparsification/quantization/helpers.py | 207 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 134 insertions(+), 131 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e09c0e29690..57b919470e4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: 
Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> 
"torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -667,14 +673,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -709,14 +715,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == 
current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -753,11 +759,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -787,17 +793,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 55b0a407b0fd43bc0c819b2f2f4962fa17b6de18 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 025/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 57b919470e4..2ae713c16aa 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -761,9 +873,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -781,11 +899,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 5730071ec36937339ecb81eaddce531cea09fc61 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 026/218] Included check to account for when weight_qconfig_kwatgs is None. --- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 4938803e140bb4ab52729435c35574a963ac29a8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 14:20:19 -0400 Subject: [PATCH 027/218] Modified argument names for backwards compatibility. 
--- .../quantization/modifier_quantization.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..73a50e0f9c4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -141,8 +141,8 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_output_activations: bool = False, - quantize_conv_output_activations: bool = False, + quantize_linear_activations: bool = False, + quantize_conv_activations: bool = False, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, @@ -174,8 +174,8 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_output_activations = quantize_linear_output_activations - self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_linear_activations = quantize_linear_activations + self._quantize_conv_activations = quantize_conv_activations self._activation_bits = activation_bits self._weight_bits = weight_bits self._exclude_batchnorm = exclude_batchnorm @@ -320,7 +320,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_output_activations(self) -> bool: + def quantize_linear_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -328,15 +328,15 @@ def quantize_linear_output_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_output_activations + return self._quantize_linear_activations @ModifierProp() - def quantize_conv_output_activations(self) -> bool: + def quantize_conv_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_conv_output_activations + return self._quantize_conv_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -544,10 +544,10 @@ def _enable_module_qat(self, module: Module): module_fuse_fn(**self._model_fuse_fn_kwargs) to_remove_layer_name = [] - if not self._quantize_linear_output_activations: + if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) - if not self._quantize_conv_output_activations: + if not self._quantize_conv_activations: to_remove_layer_name.extend( ["Conv1d", "Conv2d", "Conv3d", "ConvBn1d", "ConvBn2d", "ConvBn3d", From 70f57046c6540abf24fc56284f890c69af86fc14 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:51 -0400 Subject: [PATCH 028/218] Updated documentation to reflect changes. 
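
As context for the parameters documented below, a minimal sketch of building
the symmetric INT8 configuration used on the tensorrt path (illustrative only,
not part of the committed change):

    import torch

    from sparseml.pytorch.sparsification.quantization.helpers import get_qat_qconfig

    # symmetric activations and weights, both quantized as 8-bit signed values
    qconfig = get_qat_qconfig(
        symmetric_activations=True,
        symmetric_weights=True,
        activation_dtype=torch.qint8,
        weight_dtype=torch.qint8,
        activation_bits=8,  # 8 is also the documented default
        weight_bits=8,
    )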
--- .../sparsification/quantization/helpers.py | 118 ++++++++++++------ 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2ae713c16aa..bf49eb670b4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -69,8 +69,14 @@ else None ) - +# class BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ def __init__(self, module: Module): super().__init__() self.bn = module @@ -220,14 +226,25 @@ def from_module( ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param reduce_range: if True, the quantization range will be reduced by one - bit. This may prevent overflow issues with model execution on certain - hardware. Default is None, will only override qat_wrapper_kwargs if set - to a bool value + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: QATWrapper object created using the given Module as the forward function. Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -293,6 +310,7 @@ def from_module( else weight_bits or qat_wrapper_kwargs["weight_bits"] ) + # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -516,19 +534,10 @@ def _load_qconfigs( def configure_module_bn_wrappers(module: Module): """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` + Wrap any BatchNormalization modules that are not fused with convolutions + with BNWrapper to enable freezing/unfreezing of BN statistics :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. 
Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required if not hasattr(module, 'freeze_bn_stats'): @@ -562,14 +571,25 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} - """ + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -605,6 +625,13 @@ def configure_module_qat_wrappers( def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. If None dtype is set to torch.quint8. + :param bits: number of bits. If None is set to 8. + :return: minimum limit, maximum limit, data type + """ dtype = dtype if dtype else torch.quint8 bits = bits if bits else 8 if dtype == torch.qint8: @@ -689,18 +716,24 @@ def get_qat_qconfig( ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - UINT8 quantization range with zero point set to 128. Otherwise activations - will use asymmetric quantization with any zero point. Default is False + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. :param symmetric_weights: if True, weights will have a symmetric - INT8 quantization range with zero point set to 0. Otherwise activations - will use asymmetric quantization with any zero point. Default is True + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. 
+ :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer @@ -890,14 +923,25 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param activation_qconfig_kwargs: additional kwargs for quantizing activations. - Default is {}. - :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. - Default is {}. + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False - """ + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ if qconfig is None: if symmetric_weights is None: _symmetric_weights = False From dc773f0bf92d1afcf233f4dae4dec74e02a4e4f3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:57 -0400 Subject: [PATCH 029/218] Updated documentation to reflect changes. --- .../quantization/modifier_quantization.py | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 73a50e0f9c4..4f912b3d8bb 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -113,21 +113,26 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm training - recovery. Default is True + :param quantize_linear_activations: if True, FakeQuantize ops will be run + for output activations of fully connected layers. Default is False. 
+    :param quantize_conv_activations: if True, FakeQuantize ops will be run
+        for output activations of convolutional layers. Default is False.
     :param activation_bits: Number of bits to use for setting quant min/max values for
-        activations. Default is None, which will quantize activations to 8 bits.
+        activations. Default is None, which will quantize activations to 8 bits.
+    :param weight_bits: Number of bits to use for setting quant min/max values for
+        weights. Default is None, which will quantize weights to 8 bits.
     :param num_calibration_steps: Number of steps to run post training calibration for.
-        When None, the entire calibration_dataloader is used
+        When None, the entire calibration_dataloader is used
+    :param exclude_batchnorm: If True, do not propagate quantization qconfigs to
+        batch-normalization modules
     :param exclude_module_types: optional list of module class names to not propagate
         quantization configs to. Default is None
     :param activation_qconfig_kwargs: Additional kwargs for quantization of
-        activations.
+        activations.
     :param weight_qconfig_kwargs: Additional kwargs for quantization of
-        weights.
+        weights.
+    :param tensorrt: if True, sets the quantization configuration for compatibility
+        with explicit quantization as supported by TensorRT 8.2.
     """

     def __init__(
@@ -232,11 +237,12 @@ def submodules(self, value: Union[List[str], None]):
     def model_fuse_fn_name(self) -> Union[str, None]:
         """
         :return: Name of model function to fuse the model in place prior
-            to performing QAT. None to uses the default function
+            to performing QAT. None uses the default function.
+            If the tensorrt flag is True, the default is 'no_fuse'; otherwise
             `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
         """
         if self._tensorrt:
-            fuse_fn = 'no_fuse'
+            fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse'
         else:
             fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus'
         return fuse_fn

     @model_fuse_fn_name.setter
@@ -322,19 +328,16 @@ def reduce_range(self) -> bool:
     @ModifierProp()
     def quantize_linear_activations(self) -> bool:
         """
-        :return: if False, FakeQuantize ops will not be run
-            for activations of fully connected layers. this is important for quantizing
-            transformer based models such as BERT where the quantized MatMul outputs
-            are kept at 32 bits of precision and fake quantizing the outputs harm
-            training recovery
+        :return: if True, FakeQuantize ops will be run for output activations
+            of fully connected layers
         """
         return self._quantize_linear_activations

     @ModifierProp()
     def quantize_conv_activations(self) -> bool:
         """
-        :return: if False, FakeQuantize ops will not be run
-            for activations of convolutional layers.
+        :return: if True, FakeQuantize ops will be run for output activations
+            of convolutional layers
         """
         return self._quantize_conv_activations

     @ModifierProp()
     def exclude_module_types(self) -> Union[List[str], None]:
@@ -358,7 +361,7 @@ def activation_bits(self) -> Optional[int]:
     def weight_bits(self) -> Optional[int]:
         """
         :return: Number of bits to be use for setting quant min/max values for
-            activations. Default is None, which will quantize activations to 8 bits.
+            weights. Default is None, which will quantize weights to 8 bits.
""" return self._weight_bits @@ -543,6 +546,7 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) + # build list of layer types that should not quantize output activations to_remove_layer_name = [] if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) @@ -557,9 +561,16 @@ def _enable_module_qat(self, module: Module): if len(to_remove_layer_name) == 0: to_remove_layer_name = None + # fix for freezing batchnorm statistics when not fusing BN with convs. + # pytorch only supports freezing batchnorm statistics for fused modules. + # this fix wraps BN modules adding with a new module class that supports + # methods related to freezing/unfreezing BN statistics. configure_module_bn_wrappers(module) - # prepare each module / submodule for quantization + # set qconfig. + # if tensorrt flag is used, set activation and weights to symmetric + # quantization. + # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: _symmetric_activations = True _activation_dtype = torch.qint8 @@ -582,6 +593,8 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + + # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( @@ -596,13 +609,17 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig + # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) + + # Remove output quantization from appropriate modules if to_remove_layer_name: remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) @@ -611,6 +628,8 @@ def _enable_module_qat(self, module: Module): if self._exclude_module_types: to_exclude.extend(self._exclude_module_types) + # if exclude_batchnorm flag is used, add batch norm layers to list of + # modules to exclude qconfig if self._exclude_batchnorm: to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) From 8e2ddcb73072c91249d196612e4fe4558f7b235c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:42:27 -0400 Subject: [PATCH 030/218] Updated documentation to reflect changes. --- src/sparseml/pytorch/models/classification/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3a7a5169447 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,6 +140,10 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): + """ + Wrapper for the FloatFunctional class that enables QATWrapper used to + quantize the first input to the Add operation + """ def __init__(self, num_channels): super().__init__() if FloatFunctional: From dc561f8f41f7c007663be1755257aec6df31f99c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:52:15 -0400 Subject: [PATCH 031/218] Fixed default weights data type. 
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index bf49eb670b4..3d3687d2c28 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -751,7 +751,7 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, From 8c0a1e8c1718a94f8a7d128bb9ede1ba82c97b8a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:02:48 -0400 Subject: [PATCH 032/218] Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 54 ++-- .../sparsification/quantization/helpers.py | 247 +++++++++--------- .../quantization/modifier_quantization.py | 44 +++- 3 files changed, 186 insertions(+), 159 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3a7a5169447..cd8b979c3ad 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -144,12 +145,13 @@ class _AddReLU(Module): Wrapper for the FloatFunctional class that enables QATWrapper used to quantize the first input to the Add operation """ + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -209,12 +211,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -325,12 +327,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -441,15 +443,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -483,10 +485,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: 
int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 3d3687d2c28..0a4d2fc2e4e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,7 +71,7 @@ else None ) -# + class BNWrapper(Module): """ Wraps BatchNormalization module to expose methods needed to enable @@ -77,6 +79,7 @@ class BNWrapper(Module): :param module: BatchNormalization module to be wrapped """ + def __init__(self, module: Module): super().__init__() self.bn = module @@ -213,16 +216,16 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -241,8 +244,10 @@ def from_module( activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. 
:return: QATWrapper object created using the given Module as the forward @@ -277,7 +282,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -315,26 +320,26 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -471,18 +476,18 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -540,29 +545,29 @@ def configure_module_bn_wrappers(module: Module): :param module: module to potentially wrap the submodules of """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if 
type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -589,7 +594,7 @@ def configure_module_qat_wrappers( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -654,7 +659,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -669,9 +674,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -695,7 +700,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -704,15 +709,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + 
symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -739,8 +744,13 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) + activation_observer = get_observer( + symmetric_activations, + activation_dtype, + activation_bits, + reduce_range, + activation_qconfig_kwargs, + ) if symmetric_weights is None: _symmetric_weights = True else: @@ -751,17 +761,23 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer( + _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + ) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) +def get_observer( + symmetric: Optional[bool], + dtype: Optional[torch.dtype], + bits: Optional[int], + reduce_range: bool, + qconfig_kwargs: Dict[str, Any], +): + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, @@ -793,7 +809,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -818,14 +834,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -860,14 +876,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and 
submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -904,17 +920,17 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -941,7 +957,7 @@ def prepare_embeddings_qat( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" if qconfig is None: if symmetric_weights is None: _symmetric_weights = False @@ -965,24 +981,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -2 ** (bits - 1) + quant_min = -(2 ** (bits - 1)) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 4f912b3d8bb..30e1aefbe15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -242,9 +242,15 @@ def model_fuse_fn_name(self) -> Union[str, None]: `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" if self._tensorrt: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + ) else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name + if self._model_fuse_fn_name + else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -365,7 +371,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -382,7 +387,10 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: + if ( + self._weight_qconfig_kwargs is not None + and "observer" in self._weight_qconfig_kwargs + ): kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver @@ -390,8 +398,6 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: else: return self._weight_qconfig_kwargs - - @ModifierProp() def num_calibration_steps(self) -> Optional[int]: """ @@ -532,7 +538,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -553,10 +559,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -591,7 +607,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # prepare each module / submodule for quantization @@ -607,7 +623,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # set quantization config (asymmetric activations, symmetric weights) @@ -631,7 +647,7 @@ def _enable_module_qat(self, module: Module): # if exclude_batchnorm flag is used, add batch norm layers to list of # modules to exclude qconfig if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude if self._exclude_module_types: From a771a2ab8708e5dea78c3b6cd796826ab9c57f19 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:53:05 -0400 Subject: [PATCH 033/218] Removed unused method --- .../sparsification/quantization/helpers.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 
0a4d2fc2e4e..cc78d06ae0f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -41,7 +41,6 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", - "get_updated_qconfig_kwargs", "fix_observer_quant_range", "freeze_bn_stats", "fuse_module_conv_bn_relus", @@ -980,36 +979,6 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} - - # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): - raise ValueError( - "Cannot override quant_max and quant_min when number of bits is set" - ) - - if bits: - if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - - qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - - return qconfig_kwargs - - def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() From 8f9d9d76969b95b7e38c5c7076bf9a05302427cc Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 034/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 147 +++++++++++------- .../quantization/modifier_quantization.py | 90 ++++++----- 2 files changed, 142 insertions(+), 95 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index dddd41326d2..e10224bbce7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,20 +31,21 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "get_updated_qconfig_kwargs", "fix_observer_quant_range", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -106,10 +106,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -141,7 +141,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +153,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - 
num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +285,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +331,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +383,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +398,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +424,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + 
activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +509,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +534,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +571,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +615,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -648,6 +648,37 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for bits + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min when number of bits is set" + ) + + if bits: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + + return qconfig_kwargs + + def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 15ad82299d9..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ 
b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,12 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -139,8 +141,11 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_activations: bool = True, + quantize_linear_output_activations: bool = False, + quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -168,8 +173,11 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_activations = quantize_linear_activations + self._quantize_linear_output_activations = quantize_linear_output_activations + self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits + self._weight_bits = weight_bits self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -306,7 +314,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_activations(self) -> bool: + def quantize_linear_output_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -314,7 +322,15 @@ def quantize_linear_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_activations + return self._quantize_linear_output_activations + + @ModifierProp() + def quantize_conv_output_activations(self) -> bool: + """ + :return: if False, FakeQuantize ops will not be run + for activations of convolutional layers. + """ + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -332,6 +348,15 @@ def activation_bits(self) -> Optional[int]: """ return self._activation_bits + @ModifierProp() + def weight_bits(self) -> Optional[int]: + """ + :return: Number of bits to be use for setting quant min/max values for + activations. Default is None, which will quantize activations to 8 bits. 
+ """ + return self._weight_bits + + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -391,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -499,12 +527,25 @@ def _enable_module_qat(self, module: Module): fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() + + to_remove_layer_name = [] + if not self._quantize_linear_output_activations: + to_remove_layer_name.extend(["Linear", "LinearReLu"]) + + if not self._quantize_conv_output_activations: + to_remove_layer_name.extend( + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + ) # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -512,7 +553,7 @@ def _enable_module_qat(self, module: Module): quant_module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -521,9 +562,7 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - - if not self._quantize_linear_activations: - remove_activation_qat_by_layer_name(quant_module, ["Linear"]) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types if self._exclude_module_types: @@ -536,7 +575,7 @@ def _enable_module_qat(self, module: Module): module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -594,33 +633,10 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - activation_qconfig_kwargs = ( - self.activation_qconfig_kwargs.copy() - if self.activation_qconfig_kwargs - else {} - ) - - # update qconfig_kwargs for activation_bits - if self.activation_bits and ( - activation_qconfig_kwargs.get("quant_min") - or activation_qconfig_kwargs.get("quant_max") - ): - raise ValueError( - "Cannot override quant_max and quant_min with activation_bits enabled" - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) - if self.activation_bits: - quant_min = 0 - quant_max = 2 ** self.activation_bits - 1 - dtype = torch.quint8 - activation_qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - return 
activation_qconfig_kwargs + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 1293a21c14a01196b6ad5536c38e0c3d1dfbe678 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 035/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. --- .../pytorch/models/classification/resnet.py | 94 +++++++++---------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 40aef8a3c69..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,13 +41,11 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: FloatFunctional = None - __all__ = [ "ResNetSectionSettings", "ResNet", @@ -141,6 +139,23 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: return in_channels != out_channels or stride > 1 +class _AddReLU(Module): + def __init__(self): + super().__init__() + if FloatFunctional: + self.functional = FloatFunctional() + self.wrap_qat = True + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + else: + self.functional = ReLU(num_channels=out_channels, inplace=True) + + def forward(self, x, y): + if isinstance(self.functional, FloatFunctional): + return self.functional.add_relu(x, y) + else: + return self.functional(x + y) + + class _BasicBlock(Module): def __init__(self, in_channels: int, out_channels: int, stride: int = 1): super().__init__() @@ -164,11 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -181,12 +192,7 @@ def forward(self, inp: Tensor): out = self.bn2(out) identity_val = self.identity(inp) if self.identity is not None else inp - - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -199,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -230,11 +236,7 @@ def __init__( else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -252,11 +254,7 @@ def forward(self, inp: Tensor): identity_val = self.identity(inp) if self.identity is not None else inp - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -323,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - 
groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -439,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -481,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From db277ca222b2fffffe6bce129f13363494c022d1 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 036/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e10224bbce7..ec69ded82c8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. 
Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 
@@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 04a1ddff314f156cd0e31d24b21387a40e4303c9 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 037/218] Minor fixes. Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 61 ++++----- .../sparsification/quantization/helpers.py | 123 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 112 insertions(+), 105 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if 
use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ec69ded82c8..2c1ac640d6e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - 
reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +536,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +573,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or 
[Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +617,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -649,17 +651,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From c2bd48eb9962c8acdb6a05a7bab46c5529725f17 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 038/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 209 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++-- 2 files changed, 164 insertions(+), 82 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2c1ac640d6e..a44369550b1 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ 
+ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -535,10 +612,15 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -573,14 +655,14 @@ def 
fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -617,11 +699,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -651,10 +733,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. 
Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From b5248d43e2959ea7e390a7df88f56eb5aec16109 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 039/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a44369550b1..48ed0708eae 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From a64d626e925c2c0338d10beb5b83269faf8524bb Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 040/218] Set BN fusing back as default. 
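With this change the modifier fuses Conv-BN-ReLU stacks again by default: when model_fuse_fn_name is left unset it now resolves to "conv_bn_relus" and fuse_module_conv_bn_relus runs before QAT is prepared. A rough sketch of what that default does, shown for illustration; the toy model is an assumption, and the exact fused module types depend on the torch version this series targets.

from torch import nn
from sparseml.pytorch.sparsification.quantization.helpers import (
    fuse_module_conv_bn_relus,
)

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# equivalent of leaving model_fuse_fn_name as None under the new default
fused = fuse_module_conv_bn_relus(model, inplace=True)
print(fused)  # eligible Conv2d -> BatchNorm2d -> ReLU triples are fused in place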
--- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 6ece7fc7a1dc235b759c25284e403a66597b544b Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 041/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 48ed0708eae..71f6553fc44 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From c26ec4c5c93316c99cb3cd4506f4dda3bbb29b8f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 042/218] Fixed custom freeze_bn_stats. 
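Part of the diff below gives get_updated_qconfig_kwargs a mode argument so that activation and weight ranges are derived differently from a requested bit width: the modifier passes "asymmetric" for activations (unsigned range) and "symmetric" for weights (signed range). A small worked sketch of that arithmetic, re-implemented standalone for illustration rather than imported from the patched module:

import torch


def quant_range(bits, mode):
    # mirrors the updated helper: symmetric -> signed range, asymmetric -> unsigned
    if mode == "symmetric":
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1, torch.qint8
    return 0, 2 ** bits - 1, torch.quint8


# 8-bit activations (asymmetric): 0 .. 255, stored as quint8
assert quant_range(8, "asymmetric") == (0, 255, torch.quint8)

# 4-bit weights (symmetric): -8 .. 7, carried in a qint8 tensor
assert quant_range(4, "symmetric") == (-8, 7, torch.qint8)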
--- .../sparsification/quantization/helpers.py | 245 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 182 insertions(+), 109 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 71f6553fc44..e09c0e29690 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def 
from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -613,14 +667,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + 
inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -655,14 +709,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -699,11 +753,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -732,26 +786,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 58b60fe98751c5853e5f6cff100b92f63a893a3e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 043/218] Temporary files for evaluating changes to graphs. 
--- sandbox/quantization_recipe.yaml | 7 +++ sandbox/quantization_test.py | 23 ++++++++ .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 3 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 sandbox/quantization_recipe.yaml create mode 100644 sandbox/quantization_test.py diff --git a/sandbox/quantization_recipe.yaml b/sandbox/quantization_recipe.yaml new file mode 100644 index 00000000000..411dd6f025a --- /dev/null +++ b/sandbox/quantization_recipe.yaml @@ -0,0 +1,7 @@ +quantization_modifiers: + - !QuantizationModifier + start_epoch: -1.0 + model_fuse_fn_name: no_fuse + submodules: + - input + - sections diff --git a/sandbox/quantization_test.py b/sandbox/quantization_test.py new file mode 100644 index 00000000000..ea6fba5acd5 --- /dev/null +++ b/sandbox/quantization_test.py @@ -0,0 +1,23 @@ +import torch +from sparseml.pytorch.utils import ModuleExporter +from sparseml.pytorch.models import ModelRegistry +from sparseml.pytorch.optim import ScheduledModifierManager + +model = ModelRegistry.create( + key='resnet50', + pretrained=False, + pretrained_dataset="imagenet", + num_classes=1000 +) + + +ScheduledModifierManager.from_yaml("quantization_recipe.yaml").apply(model, epoch=float("inf")) + +print(model) + +exporter = ModuleExporter(model, ".") +exporter.export_onnx( + torch.randn(1, 3, 224, 224), + "quantized_test.onnx", + convert_qat=False, +) \ No newline at end of file diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: 
List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 38a502a90ee6d4b34015ed2fc0e99a18797cf491 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 044/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. --- .../sparsification/quantization/helpers.py | 207 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 134 insertions(+), 131 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e09c0e29690..57b919470e4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: 
Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> 
"torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -667,14 +673,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -709,14 +715,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == 
current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -753,11 +759,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -787,17 +793,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 77e5267171d654071a6dde3bd02c9fdcac2dc7c1 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 045/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 57b919470e4..2ae713c16aa 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -761,9 +873,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -781,11 +899,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 2a25845fe00bdabec79de1a33bf9639f930d7d6b Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 046/218] Included check to account for when weight_qconfig_kwatgs is None. 
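
The weight_qconfig_kwargs property previously assumed a kwargs dict was always supplied, so constructing the modifier without one raised a TypeError on the membership check; it now falls through and returns None unchanged. A minimal standalone sketch of the guarded lookup, assuming only that torch.quantization is available (the helper name below is hypothetical, for illustration only):

    from typing import Any, Dict, Optional

    from torch import quantization as torch_quantization

    def resolve_weight_qconfig_kwargs(
        weight_qconfig_kwargs: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        # None, or a dict without an "observer" entry, passes through unchanged;
        # only the "minmaxobserver" shorthand is mapped to the torch observer class
        if weight_qconfig_kwargs is not None and "observer" in weight_qconfig_kwargs:
            kwargs = weight_qconfig_kwargs.copy()
            if kwargs["observer"] == "minmaxobserver":
                kwargs["observer"] = torch_quantization.MinMaxObserver
            return kwargs
        return weight_qconfig_kwargs

    assert resolve_weight_qconfig_kwargs(None) is None
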
--- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 9ed579d81d9cee28c51a67cd75c0a87984ff7cc0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 047/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 377 +++--------------- .../quantization/modifier_quantization.py | 130 ++---- 2 files changed, 87 insertions(+), 420 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2ae713c16aa..75d11c67c31 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,21 +32,20 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -70,106 +69,6 @@ ) -class BNWrapper(Module): - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return 
self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -208,15 +107,9 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, + reduce_range: bool = None, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -238,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -269,31 +150,6 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -308,15 +164,9 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): super().__init__() @@ -336,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - 
self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -456,15 +288,9 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -494,66 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` - - :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
Default is {} - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): - for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) - # recurse on child module - configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -578,45 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -642,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES + type(module) in QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): @@ -677,15 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -706,44 +458,41 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - - -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + activation_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, + ) + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, + reduce_range=reduce_range, ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -769,14 +518,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) - or ( # do not propagate default uint8 symmetric range - observer.qscheme == torch.per_tensor_symmetric - and fake_quantize.quant_min == 0 - and fake_quantize.quant_max == 255 - ) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -784,11 +528,6 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( module: Module, inplace: bool = True, @@ -873,15 +612,9 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - 
weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -899,28 +632,18 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): qconfig_kwargs = ( qconfig_kwargs.copy() if qconfig_kwargs @@ -937,15 +660,9 @@ def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): ) if bits: - if mode == "symmetric": - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,14 +143,13 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -176,9 +175,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -187,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -235,11 +233,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = 'no_fuse' - else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -268,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -336,7 +330,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_conv_output_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -379,15 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs - - + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -397,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -439,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -524,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -542,10 +522,16 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + to_remove_layer_name.extend(["Linear", "LinearReLu"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -554,47 +540,20 @@ def _enable_module_qat(self, module: Module): "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - configure_module_bn_wrappers(module) # prepare each module / submodule for quantization - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - qconfig = get_qat_qconfig( - 
symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -603,18 +562,9 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -623,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -688,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From 43cf40407090537629a29b019b372db2efb9b347 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 048/218] Added _Add_ReLU module that enables QATWrapper for quantization. 
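The new _AddReLU module routes the residual add followed by ReLU through FloatFunctional and opts into QAT wrapping via the wrap_qat / qat_wrapper_kwargs attributes that configure_module_qat_wrappers() looks for. A rough, untested sketch of that contract using a toy block (the _ToyAddReLU and _ToyBlock names are made up purely for illustration; wrap_qat, qat_wrapper_kwargs, FloatFunctional, and configure_module_qat_wrappers are the pieces this change relies on):

    from torch.nn import Module
    from torch.nn.quantized import FloatFunctional

    from sparseml.pytorch.sparsification.quantization.helpers import (
        configure_module_qat_wrappers,
    )


    class _ToyAddReLU(Module):
        def __init__(self):
            super().__init__()
            # FloatFunctional lets torch.quantization observe the add+relu output
            self.functional = FloatFunctional()
            # opt-in hook: any submodule with wrap_qat == True is replaced by
            # QATWrapper.from_module(submodule, ...), using qat_wrapper_kwargs
            self.wrap_qat = True
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}

        def forward(self, x, y):
            return self.functional.add_relu(x, y)


    class _ToyBlock(Module):
        def __init__(self):
            super().__init__()
            self.add_relu = _ToyAddReLU()

        def forward(self, x, y):
            return self.add_relu(x, y)


    block = _ToyBlock()
    configure_module_qat_wrappers(block)
    # the child should now be a QATWrapper around the original _ToyAddReLU
    print(type(block.add_relu).__name__)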
--- src/sparseml/pytorch/models/classification/resnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,14 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -236,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() From 05a8735d9479be2f24ca5c21ed449f2f2b428985 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 049/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 75d11c67c31..f28656f1712 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ 
class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 458e2713baba3d80ef4eed228da1a9cf9dd370d9 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 050/218] Minor fixes. Style and quality fixes. 
--- .../pytorch/models/classification/resnet.py | 61 +++++---- .../sparsification/quantization/helpers.py | 129 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 115 insertions(+), 108 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index f28656f1712..ef4445a0d5f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as 
torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and 
hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -518,9 +520,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -529,9 +531,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -566,14 +568,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -610,11 +612,11 @@ def 
fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -644,17 +646,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 10b12448d007947809436c2ad304858eed2cb86e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 051/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 215 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++- 2 files changed, 167 insertions(+), 85 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ef4445a0d5f..c4f165d23ef 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + 
"torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -520,9 +597,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -530,10 +607,15 @@ def 
fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -568,14 +650,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -612,11 +694,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -646,10 +728,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 2bd1cfd3e083cce41287195f67bd6e8a490b066e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 052/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index c4f165d23ef..64958570e2d 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if 
not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From 4c6c1299cc6a47961821a87f182ef667ac0db73e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 053/218] Set BN fusing back as default. --- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 91831fdf9b4ff7bffc651042dbd3009e7ec68560 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 054/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 64958570e2d..a43d69d947b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 37b73da224a6deaea974b677f28c9e2ed7972e7c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 055/218] Fixed custom freeze_bn_stats. --- .../sparsification/quantization/helpers.py | 251 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 185 insertions(+), 112 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a43d69d947b..6110a499b70 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + 
self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -597,9 +651,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -608,14 +662,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -650,14 +704,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -694,11 +748,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -727,26 +781,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) 
-> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) 
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 8e06b4171dda3e09d02e53d3b245fcec4a7f4a90 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 056/218] Temporary files for evaluating changes to graphs. --- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 21acec138cb48817378c6b5606db442bfe6a0ded Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 057/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
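Note (illustrative): with this change the quant_min/quant_max limits are no
longer hard-coded per observer but derived from the target dtype and bit
width inside get_qat_qconfig (via the compute_range helper added below). A
minimal sketch of that derivation, using a hypothetical standalone function
name rather than the in-tree helper, is:

    import torch

    def example_quant_range(dtype: torch.dtype = torch.quint8, bits: int = 8):
        # signed (torch.qint8-style) range: [-2^(bits-1), 2^(bits-1) - 1]
        if dtype == torch.qint8:
            return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
        # unsigned (torch.quint8-style) range: [0, 2^bits - 1]
        return 0, (2 ** bits) - 1

    print(example_quant_range(torch.qint8, 8))   # (-128, 127), symmetric / TensorRT style
    print(example_quant_range(torch.quint8, 8))  # (0, 255), default asymmetric activations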
--- .../sparsification/quantization/helpers.py | 213 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 137 insertions(+), 134 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6110a499b70..8ae045de9e8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -651,9 +657,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -662,14 +668,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -704,14 +710,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): 
submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -748,11 +754,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -782,17 +788,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 9274d5dd5ee890f3fea6edf1d6fad9f55c06a456 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 058/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 8ae045de9e8..027c7514c32 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -756,9 +868,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -776,11 +894,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From e62ba98ce157784d5006a54d605922445fde2cc1 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 059/218] Included check to account for when weight_qconfig_kwatgs is None. --- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From e6e7aae4864131922733e605a954a34666934711 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 14:20:19 -0400 Subject: [PATCH 060/218] Modified argument names for backwards compatibility. 
--- .../quantization/modifier_quantization.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..73a50e0f9c4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -141,8 +141,8 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_output_activations: bool = False, - quantize_conv_output_activations: bool = False, + quantize_linear_activations: bool = False, + quantize_conv_activations: bool = False, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, @@ -174,8 +174,8 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_output_activations = quantize_linear_output_activations - self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_linear_activations = quantize_linear_activations + self._quantize_conv_activations = quantize_conv_activations self._activation_bits = activation_bits self._weight_bits = weight_bits self._exclude_batchnorm = exclude_batchnorm @@ -320,7 +320,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_output_activations(self) -> bool: + def quantize_linear_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -328,15 +328,15 @@ def quantize_linear_output_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_output_activations + return self._quantize_linear_activations @ModifierProp() - def quantize_conv_output_activations(self) -> bool: + def quantize_conv_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_conv_output_activations + return self._quantize_conv_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -544,10 +544,10 @@ def _enable_module_qat(self, module: Module): module_fuse_fn(**self._model_fuse_fn_kwargs) to_remove_layer_name = [] - if not self._quantize_linear_output_activations: + if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) - if not self._quantize_conv_output_activations: + if not self._quantize_conv_activations: to_remove_layer_name.extend( ["Conv1d", "Conv2d", "Conv3d", "ConvBn1d", "ConvBn2d", "ConvBn3d", From a87fb618a8a223cc42ced679f6e6351b3c4d7c0e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:51 -0400 Subject: [PATCH 061/218] Updated documentation to reflect changes. 
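Note (illustrative): the docstring updates below describe the expanded
get_qat_qconfig signature (symmetric flags, dtypes, and bit widths). A short
usage sketch of the documented defaults, assuming a torch install with
quantization support:

    import torch

    from sparseml.pytorch.sparsification.quantization.helpers import get_qat_qconfig

    # asymmetric uint8 activations and symmetric int8 weights, 8 bits each,
    # matching the documented default behavior
    qconfig = get_qat_qconfig(
        symmetric_activations=False,
        symmetric_weights=True,
        activation_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        activation_bits=8,
        weight_bits=8,
    )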
--- .../sparsification/quantization/helpers.py | 118 ++++++++++++------ 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 027c7514c32..bc9aeb6d58c 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -69,8 +69,14 @@ else None ) - +# class BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ def __init__(self, module: Module): super().__init__() self.bn = module @@ -220,14 +226,25 @@ def from_module( ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param reduce_range: if True, the quantization range will be reduced by one - bit. This may prevent overflow issues with model execution on certain - hardware. Default is None, will only override qat_wrapper_kwargs if set - to a bool value + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: QATWrapper object created using the given Module as the forward function. Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -293,6 +310,7 @@ def from_module( else weight_bits or qat_wrapper_kwargs["weight_bits"] ) + # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -516,19 +534,10 @@ def _load_qconfigs( def configure_module_bn_wrappers(module: Module): """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` + Wrap any BatchNormalization modules that are not fused with convolutions + with BNWrapper to enable freezing/unfreezing of BN statistics :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. 
Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required if not hasattr(module, 'freeze_bn_stats'): @@ -562,14 +571,25 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} - """ + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -605,6 +625,13 @@ def configure_module_qat_wrappers( def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. If None dtype is set to torch.quint8. + :param bits: number of bits. If None is set to 8. + :return: minimum limit, maximum limit, data type + """ dtype = dtype if dtype else torch.quint8 bits = bits if bits else 8 if dtype == torch.qint8: @@ -689,18 +716,24 @@ def get_qat_qconfig( ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - UINT8 quantization range with zero point set to 128. Otherwise activations - will use asymmetric quantization with any zero point. Default is False + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. :param symmetric_weights: if True, weights will have a symmetric - INT8 quantization range with zero point set to 0. Otherwise activations - will use asymmetric quantization with any zero point. Default is True + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. 
+ :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer @@ -885,14 +918,25 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param activation_qconfig_kwargs: additional kwargs for quantizing activations. - Default is {}. - :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. - Default is {}. + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False - """ + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ if qconfig is None: if symmetric_weights is None: _symmetric_weights = False From 6d8ee7d6e05664d6733a03e26b480322e26f14cc Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:57 -0400 Subject: [PATCH 062/218] Updated documentation to reflect changes. --- .../quantization/modifier_quantization.py | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 73a50e0f9c4..4f912b3d8bb 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -113,21 +113,26 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm training - recovery. Default is True + :param quantize_linear_activations: if True, FakeQuantize ops will be run + for output activations of fully connected layers. Default is False. 
+ :param quantize_conv_activations: if True, FakeQuantize ops will be run + for output activations of convolutional layers. Default is False. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. + :param weight_bits: Number of bits to use for setting quant min/max values for + weights. Default is None, which will quantize weights to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used + When None, the entire calibration_dataloader is used + :param exclude_batchnorm: If True, do not propagate quantization qconfigs to + batch-normalization modules :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. + :param tenssorrt: if True sets quantization configuration for compatibility with + explict quantization as supported by TensorRT 8.2. """ def __init__( @@ -232,11 +237,12 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None to uses the default function + to performing QAT. None sets to default function. + If tensorrt flag is True, default is 'no_fuse', otherwise `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ if self._tensorrt: - fuse_fn = 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' else: fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @@ -322,19 +328,16 @@ def reduce_range(self) -> bool: @ModifierProp() def quantize_linear_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm - training recovery + :return: if True, FakeQuantize ops will be run for output activations + of fully connected layers """ return self._quantize_linear_activations @ModifierProp() def quantize_conv_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of convolutional layers. + :return: if True, FakeQuantize ops will be run for output activations + of convolutional layers """ return self._quantize_conv_activations @@ -358,7 +361,7 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + weights. Default is None, which will quantize weights to 8 bits. 
""" return self._weight_bits @@ -543,6 +546,7 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) + # build list of layer types that should not quantize output activations to_remove_layer_name = [] if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) @@ -557,9 +561,16 @@ def _enable_module_qat(self, module: Module): if len(to_remove_layer_name) == 0: to_remove_layer_name = None + # fix for freezing batchnorm statistics when not fusing BN with convs. + # pytorch only supports freezing batchnorm statistics for fused modules. + # this fix wraps BN modules adding with a new module class that supports + # methods related to freezing/unfreezing BN statistics. configure_module_bn_wrappers(module) - # prepare each module / submodule for quantization + # set qconfig. + # if tensorrt flag is used, set activation and weights to symmetric + # quantization. + # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: _symmetric_activations = True _activation_dtype = torch.qint8 @@ -582,6 +593,8 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + + # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( @@ -596,13 +609,17 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig + # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) + + # Remove output quantization from appropriate modules if to_remove_layer_name: remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) @@ -611,6 +628,8 @@ def _enable_module_qat(self, module: Module): if self._exclude_module_types: to_exclude.extend(self._exclude_module_types) + # if exclude_batchnorm flag is used, add batch norm layers to list of + # modules to exclude qconfig if self._exclude_batchnorm: to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) From 26696f5457d54f01543d36189bae10fbd60f06d0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:42:27 -0400 Subject: [PATCH 063/218] Updated documentation to reflect changes. --- src/sparseml/pytorch/models/classification/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3a7a5169447 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,6 +140,10 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): + """ + Wrapper for the FloatFunctional class that enables QATWrapper used to + quantize the first input to the Add operation + """ def __init__(self, num_channels): super().__init__() if FloatFunctional: From c90694cb8b550b426ff3110b74e8c83ca26a825a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:52:15 -0400 Subject: [PATCH 064/218] Fixed default weights data type. 
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index bc9aeb6d58c..b3e47162c5e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -751,7 +751,7 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, From 681334c75f4c762a8f054c008157a537ef5af1d3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:02:48 -0400 Subject: [PATCH 065/218] Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 54 ++-- .../sparsification/quantization/helpers.py | 247 +++++++++--------- .../quantization/modifier_quantization.py | 44 +++- 3 files changed, 186 insertions(+), 159 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3a7a5169447..cd8b979c3ad 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -144,12 +145,13 @@ class _AddReLU(Module): Wrapper for the FloatFunctional class that enables QATWrapper used to quantize the first input to the Add operation """ + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -209,12 +211,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -325,12 +327,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -441,15 +443,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -483,10 +485,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: 
int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index b3e47162c5e..c2e21d30a16 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,7 +71,7 @@ else None ) -# + class BNWrapper(Module): """ Wraps BatchNormalization module to expose methods needed to enable @@ -77,6 +79,7 @@ class BNWrapper(Module): :param module: BatchNormalization module to be wrapped """ + def __init__(self, module: Module): super().__init__() self.bn = module @@ -213,16 +216,16 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -241,8 +244,10 @@ def from_module( activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. 
:return: QATWrapper object created using the given Module as the forward @@ -277,7 +282,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -315,26 +320,26 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -471,18 +476,18 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -540,29 +545,29 @@ def configure_module_bn_wrappers(module: Module): :param module: module to potentially wrap the submodules of """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if 
type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -589,7 +594,7 @@ def configure_module_qat_wrappers( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -654,7 +659,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -669,9 +674,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -695,7 +700,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -704,15 +709,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + 
symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -739,8 +744,13 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) + activation_observer = get_observer( + symmetric_activations, + activation_dtype, + activation_bits, + reduce_range, + activation_qconfig_kwargs, + ) if symmetric_weights is None: _symmetric_weights = True else: @@ -751,17 +761,23 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer( + _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + ) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) +def get_observer( + symmetric: Optional[bool], + dtype: Optional[torch.dtype], + bits: Optional[int], + reduce_range: bool, + qconfig_kwargs: Dict[str, Any], +): + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, @@ -793,7 +809,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -813,14 +829,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -855,14 +871,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and 
submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -899,17 +915,17 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -936,7 +952,7 @@ def prepare_embeddings_qat( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" if qconfig is None: if symmetric_weights is None: _symmetric_weights = False @@ -960,24 +976,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -2 ** (bits - 1) + quant_min = -(2 ** (bits - 1)) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 4f912b3d8bb..30e1aefbe15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -242,9 +242,15 @@ def model_fuse_fn_name(self) -> Union[str, None]: `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" if self._tensorrt: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + ) else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name + if self._model_fuse_fn_name + else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -365,7 +371,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -382,7 +387,10 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: + if ( + self._weight_qconfig_kwargs is not None + and "observer" in self._weight_qconfig_kwargs + ): kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver @@ -390,8 +398,6 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: else: return self._weight_qconfig_kwargs - - @ModifierProp() def num_calibration_steps(self) -> Optional[int]: """ @@ -532,7 +538,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -553,10 +559,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -591,7 +607,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # prepare each module / submodule for quantization @@ -607,7 +623,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # set quantization config (asymmetric activations, symmetric weights) @@ -631,7 +647,7 @@ def _enable_module_qat(self, module: Module): # if exclude_batchnorm flag is used, add batch norm layers to list of # modules to exclude qconfig if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude if self._exclude_module_types: From f9d882b64af134956260f71824bd46f919645ee1 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:53:05 -0400 Subject: [PATCH 066/218] Removed unused method --- .../sparsification/quantization/helpers.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 
c2e21d30a16..6c30789fbb7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -41,7 +41,6 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", - "get_updated_qconfig_kwargs", "fix_observer_quant_range", "freeze_bn_stats", "fuse_module_conv_bn_relus", @@ -975,36 +974,6 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} - - # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): - raise ValueError( - "Cannot override quant_max and quant_min when number of bits is set" - ) - - if bits: - if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - - qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - - return qconfig_kwargs - - def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() From e8743fba4bb542aabda8a7ebd8896a12ba27a82c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 067/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 147 +++++++++++------- .../quantization/modifier_quantization.py | 90 ++++++----- 2 files changed, 142 insertions(+), 95 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index dddd41326d2..e10224bbce7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,20 +31,21 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "get_updated_qconfig_kwargs", "fix_observer_quant_range", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -106,10 +106,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -141,7 +141,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +153,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - 
num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +285,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +331,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +383,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +398,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +424,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + 
activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +509,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +534,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +571,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +615,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -648,6 +648,37 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for bits + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min when number of bits is set" + ) + + if bits: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + + return qconfig_kwargs + + def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 15ad82299d9..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ 
b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,12 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -139,8 +141,11 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_activations: bool = True, + quantize_linear_output_activations: bool = False, + quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -168,8 +173,11 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_activations = quantize_linear_activations + self._quantize_linear_output_activations = quantize_linear_output_activations + self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits + self._weight_bits = weight_bits self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -306,7 +314,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_activations(self) -> bool: + def quantize_linear_output_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -314,7 +322,15 @@ def quantize_linear_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_activations + return self._quantize_linear_output_activations + + @ModifierProp() + def quantize_conv_output_activations(self) -> bool: + """ + :return: if False, FakeQuantize ops will not be run + for activations of convolutional layers. + """ + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -332,6 +348,15 @@ def activation_bits(self) -> Optional[int]: """ return self._activation_bits + @ModifierProp() + def weight_bits(self) -> Optional[int]: + """ + :return: Number of bits to be use for setting quant min/max values for + activations. Default is None, which will quantize activations to 8 bits. 
+ """ + return self._weight_bits + + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -391,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -499,12 +527,25 @@ def _enable_module_qat(self, module: Module): fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() + + to_remove_layer_name = [] + if not self._quantize_linear_output_activations: + to_remove_layer_name.extend(["Linear", "LinearReLu"]) + + if not self._quantize_conv_output_activations: + to_remove_layer_name.extend( + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + ) # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -512,7 +553,7 @@ def _enable_module_qat(self, module: Module): quant_module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -521,9 +562,7 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - - if not self._quantize_linear_activations: - remove_activation_qat_by_layer_name(quant_module, ["Linear"]) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types if self._exclude_module_types: @@ -536,7 +575,7 @@ def _enable_module_qat(self, module: Module): module, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -594,33 +633,10 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - activation_qconfig_kwargs = ( - self.activation_qconfig_kwargs.copy() - if self.activation_qconfig_kwargs - else {} - ) - - # update qconfig_kwargs for activation_bits - if self.activation_bits and ( - activation_qconfig_kwargs.get("quant_min") - or activation_qconfig_kwargs.get("quant_max") - ): - raise ValueError( - "Cannot override quant_max and quant_min with activation_bits enabled" - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) - if self.activation_bits: - quant_min = 0 - quant_max = 2 ** self.activation_bits - 1 - dtype = torch.quint8 - activation_qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - return 
activation_qconfig_kwargs + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 6ec4aa719190a53a4a305a9a8248ebbbd0dd0f94 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 068/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. --- .../pytorch/models/classification/resnet.py | 94 +++++++++---------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 40aef8a3c69..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,13 +41,11 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: FloatFunctional = None - __all__ = [ "ResNetSectionSettings", "ResNet", @@ -141,6 +139,23 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: return in_channels != out_channels or stride > 1 +class _AddReLU(Module): + def __init__(self): + super().__init__() + if FloatFunctional: + self.functional = FloatFunctional() + self.wrap_qat = True + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + else: + self.functional = ReLU(num_channels=out_channels, inplace=True) + + def forward(self, x, y): + if isinstance(self.functional, FloatFunctional): + return self.functional.add_relu(x, y) + else: + return self.functional(x + y) + + class _BasicBlock(Module): def __init__(self, in_channels: int, out_channels: int, stride: int = 1): super().__init__() @@ -164,11 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -181,12 +192,7 @@ def forward(self, inp: Tensor): out = self.bn2(out) identity_val = self.identity(inp) if self.identity is not None else inp - - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -199,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -230,11 +236,7 @@ def __init__( else None ) - self.add_relu = ( - FloatFunctional() - if FloatFunctional is not None - else ReLU(num_channels=out_channels, inplace=True) - ) + self.add_relu = _AddReLU() self.initialize() @@ -252,11 +254,7 @@ def forward(self, inp: Tensor): identity_val = self.identity(inp) if self.identity is not None else inp - if isinstance(self.add_relu, FloatFunctional): - out = self.add_relu.add_relu(out, identity_val) - else: - out += identity_val - out = self.add_relu(out) + out = self.add_relu(identity_val, out) return out @@ -323,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - 
groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -439,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -481,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 2368c17c5121f9160c08fa9cb93e81856ce1b24b Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 069/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e10224bbce7..ec69ded82c8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. 
Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 
@@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 67347f7b73f959f2739ad05ad791d786544ce7e8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 070/218] Minor fixes. Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 61 ++++----- .../sparsification/quantization/helpers.py | 123 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 112 insertions(+), 105 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if 
use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ec69ded82c8..2c1ac640d6e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - 
reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +536,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +573,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or 
[Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +617,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -649,17 +651,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From ad5f938896fde70c39c3a2243337536cd0866740 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 071/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 209 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++-- 2 files changed, 164 insertions(+), 82 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2c1ac640d6e..a44369550b1 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ 
+ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -535,10 +612,15 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -573,14 +655,14 @@ def 
fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -617,11 +699,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -651,10 +733,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. 
Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 73701a8b7162b685085e4ae42ddcd7e874ab3528 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 072/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a44369550b1..48ed0708eae 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From 82df8caad58687ae30a4f0a292deab7b825e0aba Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 073/218] Set BN fusing back as default. 
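
With this change a QuantizationModifier whose model_fuse_fn_name is left unset
resolves to "conv_bn_relus", so Conv-BN(-ReLU) fusing again runs before QAT and
recipes must opt out explicitly. Rough sketch of the intended behaviour (the
constructor arguments shown are hypothetical examples, not taken from this
patch):

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    modifier = QuantizationModifier(start_epoch=0.0)       # fuse fn left unset
    assert modifier.model_fuse_fn_name == "conv_bn_relus"  # default: fuse

    modifier = QuantizationModifier(start_epoch=0.0, model_fuse_fn_name="no_fuse")
    assert modifier.model_fuse_fn_name == "no_fuse"        # fusing skipped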
--- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 1a778bda379a97b3df76f38b898553636e662e1c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 074/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 48ed0708eae..71f6553fc44 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 600eaa84a9e23c74a275bdfee8d63fef0ce3bc40 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 075/218] Fixed custom freeze_bn_stats. 
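
Besides forwarding the BatchNorm attributes through BNWrapper, this patch gives
get_updated_qconfig_kwargs a mode argument so weights get a signed (symmetric)
range and activations an unsigned (asymmetric) one. Expected values, for
reference (a sketch of the behaviour added below, not an exhaustive spec):

    import torch
    from sparseml.pytorch.sparsification.quantization.helpers import (
        get_updated_qconfig_kwargs,
    )

    # activations: unsigned range, e.g. 4 bits -> [0, 15]
    get_updated_qconfig_kwargs(None, 4, "asymmetric")
    # -> {"quant_min": 0, "quant_max": 15, "dtype": torch.quint8}

    # weights: signed range, e.g. 4 bits -> [-8, 7]
    get_updated_qconfig_kwargs(None, 4, "symmetric")
    # -> {"quant_min": -8, "quant_max": 7, "dtype": torch.qint8}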
--- .../sparsification/quantization/helpers.py | 245 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 182 insertions(+), 109 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 71f6553fc44..e09c0e29690 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def 
from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -613,14 +667,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + 
inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -655,14 +709,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -699,11 +753,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -732,26 +786,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 86c1233b389f2fb8afa1fcd052fc05dc346bfb64 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 076/218] Temporary files for evaluating changes to graphs. 
--- sandbox/quantization_recipe.yaml | 7 +++ sandbox/quantization_test.py | 23 ++++++++ .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 3 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 sandbox/quantization_recipe.yaml create mode 100644 sandbox/quantization_test.py diff --git a/sandbox/quantization_recipe.yaml b/sandbox/quantization_recipe.yaml new file mode 100644 index 00000000000..411dd6f025a --- /dev/null +++ b/sandbox/quantization_recipe.yaml @@ -0,0 +1,7 @@ +quantization_modifiers: + - !QuantizationModifier + start_epoch: -1.0 + model_fuse_fn_name: no_fuse + submodules: + - input + - sections diff --git a/sandbox/quantization_test.py b/sandbox/quantization_test.py new file mode 100644 index 00000000000..ea6fba5acd5 --- /dev/null +++ b/sandbox/quantization_test.py @@ -0,0 +1,23 @@ +import torch +from sparseml.pytorch.utils import ModuleExporter +from sparseml.pytorch.models import ModelRegistry +from sparseml.pytorch.optim import ScheduledModifierManager + +model = ModelRegistry.create( + key='resnet50', + pretrained=False, + pretrained_dataset="imagenet", + num_classes=1000 +) + + +ScheduledModifierManager.from_yaml("quantization_recipe.yaml").apply(model, epoch=float("inf")) + +print(model) + +exporter = ModuleExporter(model, ".") +exporter.export_onnx( + torch.randn(1, 3, 224, 224), + "quantized_test.onnx", + convert_qat=False, +) \ No newline at end of file diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: 
List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 1c434c31243774b9a7f5ad596147005363aa8bf0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 077/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. --- .../sparsification/quantization/helpers.py | 207 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 134 insertions(+), 131 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e09c0e29690..57b919470e4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: 
Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> 
"torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -667,14 +673,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -709,14 +715,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == 
current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -753,11 +759,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -787,17 +793,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From b414593cd32c92788b149f722655a3c1b62172a5 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 078/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 57b919470e4..2ae713c16aa 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -761,9 +873,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -781,11 +899,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 694f646298a925b074c174895cc017ff8826db91 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 079/218] Included check to account for when weight_qconfig_kwatgs is None. 
--- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 50b115144a53da9b11aabd576c3e7ffb7ebc465c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 080/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 377 +++--------------- .../quantization/modifier_quantization.py | 130 ++---- 2 files changed, 87 insertions(+), 420 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2ae713c16aa..75d11c67c31 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,21 +32,20 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -70,106 +69,6 @@ ) -class BNWrapper(Module): - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return 
self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -208,15 +107,9 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, + reduce_range: bool = None, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -238,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -269,31 +150,6 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -308,15 +164,9 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): super().__init__() @@ -336,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - 
self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -456,15 +288,9 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -494,66 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` - - :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
Default is {} - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): - for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) - # recurse on child module - configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -578,45 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -642,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES + type(module) in QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): @@ -677,15 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -706,44 +458,41 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - - -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + activation_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, + ) + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, + reduce_range=reduce_range, ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -769,14 +518,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) - or ( # do not propagate default uint8 symmetric range - observer.qscheme == torch.per_tensor_symmetric - and fake_quantize.quant_min == 0 - and fake_quantize.quant_max == 255 - ) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -784,11 +528,6 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( module: Module, inplace: bool = True, @@ -873,15 +612,9 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - 
weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -899,28 +632,18 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): qconfig_kwargs = ( qconfig_kwargs.copy() if qconfig_kwargs @@ -937,15 +660,9 @@ def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): ) if bits: - if mode == "symmetric": - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,14 +143,13 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -176,9 +175,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -187,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -235,11 +233,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = 'no_fuse' - else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -268,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -336,7 +330,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_conv_output_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -379,15 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs - - + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -397,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -439,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -524,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -542,10 +522,16 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + to_remove_layer_name.extend(["Linear", "LinearReLu"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -554,47 +540,20 @@ def _enable_module_qat(self, module: Module): "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - configure_module_bn_wrappers(module) # prepare each module / submodule for quantization - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - qconfig = get_qat_qconfig( - 
symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -603,18 +562,9 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -623,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -688,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From 617f60845d420f3b35f212df9238e5a686208ccd Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 081/218] Added _Add_ReLU module that enables QATWrapper for quantization. 
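
The residual add followed by ReLU goes through FloatFunctional rather than a
standard quantizable layer, so _AddReLU opts into QATWrapper via the wrap_qat
attribute that configure_module_qat_wrappers looks for. A minimal sketch of the
opt-in pattern, assuming torch.nn.quantized.FloatFunctional is importable (the
module below guards this with a try/except); the class name and tensor shapes
are illustrative, not the exact _AddReLU added here:

    import torch
    from torch.nn import Module
    from torch.nn.quantized import FloatFunctional

    class AddReLU(Module):
        def __init__(self):
            super().__init__()
            self.functional = FloatFunctional()
            # opt in to QATWrapper: one quantized input stub, no output re-quantization
            self.wrap_qat = True
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}

        def forward(self, x, y):
            # in float mode this is simply relu(x + y)
            return self.functional.add_relu(x, y)

    out = AddReLU()(torch.randn(1, 8), torch.randn(1, 8))

During _enable_module_qat, configure_module_qat_wrappers replaces any submodule
with wrap_qat == True by QATWrapper.from_module, which reads qat_wrapper_kwargs
to decide how many input and output quant stubs to create.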
--- src/sparseml/pytorch/models/classification/resnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,14 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -236,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() From 5e2127895e3835ab772ee8f2f64b32a1df99f430 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 082/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 75d11c67c31..f28656f1712 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ 
class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From d98cb52792ee8ca862c28dc90dd5629f3b4fa8fd Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 083/218] Minor fixes. Style and quality fixes. 
--- .../pytorch/models/classification/resnet.py | 61 +++++---- .../sparsification/quantization/helpers.py | 129 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 115 insertions(+), 108 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index f28656f1712..ef4445a0d5f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as 
torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and 
hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -518,9 +520,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -529,9 +531,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -566,14 +568,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -610,11 +612,11 @@ def 
fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -644,17 +646,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From aeed0b990d539adda1595565f621b50f53592332 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 084/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 215 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++- 2 files changed, 167 insertions(+), 85 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ef4445a0d5f..c4f165d23ef 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + 
"torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -520,9 +597,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -530,10 +607,15 @@ def 
fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -568,14 +650,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -612,11 +694,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -646,10 +728,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From c640cd65630a658bfd0593102402c9fc7149b4ca Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 085/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index c4f165d23ef..64958570e2d 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if 
not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From 61e7a563658be0cad79f0d94a7b8a72e29432076 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 086/218] Set BN fusing back as default. --- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 118fbdb5982ffb77181f0acf8041770bdcb6d7d6 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 087/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 64958570e2d..a43d69d947b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 8db174c8d1d6640b50a46ba6569c8aad5a7ea163 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 088/218] Fixed custom freeze_bn_stats. --- .../sparsification/quantization/helpers.py | 251 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 185 insertions(+), 112 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a43d69d947b..6110a499b70 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + 
self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -597,9 +651,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -608,14 +662,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -650,14 +704,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -694,11 +748,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -727,26 +781,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) 
-> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) 
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From cfcf1cb8235d033fe4e170d2ec1a1eb87f6927ae Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 089/218] Temporary files for evaluating changes to graphs. --- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 1f63f3d608e55f6c3414e71fbc24d757551226dc Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 090/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
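
For reference, a minimal sketch of the dtype/bit-width range selection this change
centralizes (illustrative only; the helper added in this patch also resolves default
dtypes and returns the resolved dtype alongside the limits):

    import torch

    def compute_range(dtype=torch.quint8, bits=8):
        # signed (torch.qint8-style) ranges are symmetric around zero,
        # unsigned (torch.quint8-style) ranges start at zero
        if dtype == torch.qint8:
            return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
        return 0, (2 ** bits) - 1

    assert compute_range(torch.qint8, 8) == (-128, 127)
    assert compute_range(torch.quint8, 8) == (0, 255)
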
--- .../sparsification/quantization/helpers.py | 213 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 137 insertions(+), 134 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6110a499b70..8ae045de9e8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -651,9 +657,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -662,14 +668,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -704,14 +710,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): 
submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -748,11 +754,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -782,17 +788,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 84d851e499e17b0ec9c038adf7513910a7585663 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 091/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 8ae045de9e8..027c7514c32 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -756,9 +868,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -776,11 +894,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 97329f3f18dba478a893ca57818b6ac61f0640c5 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 092/218] Included check to account for when weight_qconfig_kwatgs is None. --- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 840a6bf8adcf2200f803fe0757128d63965f38f7 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 14:20:19 -0400 Subject: [PATCH 093/218] Modified argument names for backwards compatibility. 
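
For illustration, a hypothetical construction using the restored keyword names
(the class name and import path are assumed from this file; all other arguments
are left at their defaults):

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    modifier = QuantizationModifier(
        quantize_linear_activations=False,  # previously quantize_linear_output_activations
        quantize_conv_activations=False,    # previously quantize_conv_output_activations
    )
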
--- .../quantization/modifier_quantization.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..73a50e0f9c4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -141,8 +141,8 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_output_activations: bool = False, - quantize_conv_output_activations: bool = False, + quantize_linear_activations: bool = False, + quantize_conv_activations: bool = False, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, @@ -174,8 +174,8 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_output_activations = quantize_linear_output_activations - self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_linear_activations = quantize_linear_activations + self._quantize_conv_activations = quantize_conv_activations self._activation_bits = activation_bits self._weight_bits = weight_bits self._exclude_batchnorm = exclude_batchnorm @@ -320,7 +320,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_output_activations(self) -> bool: + def quantize_linear_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -328,15 +328,15 @@ def quantize_linear_output_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_output_activations + return self._quantize_linear_activations @ModifierProp() - def quantize_conv_output_activations(self) -> bool: + def quantize_conv_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_conv_output_activations + return self._quantize_conv_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -544,10 +544,10 @@ def _enable_module_qat(self, module: Module): module_fuse_fn(**self._model_fuse_fn_kwargs) to_remove_layer_name = [] - if not self._quantize_linear_output_activations: + if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) - if not self._quantize_conv_output_activations: + if not self._quantize_conv_activations: to_remove_layer_name.extend( ["Conv1d", "Conv2d", "Conv3d", "ConvBn1d", "ConvBn2d", "ConvBn3d", From 518552089c42ea16adfef362b247e0110b13e8f0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:51 -0400 Subject: [PATCH 094/218] Updated documentation to reflect changes. 
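
A short usage sketch of the documented behavior (keyword names are taken from the
updated signatures in helpers.py; the import path is assumed from this file's location):

    import torch

    from sparseml.pytorch.sparsification.quantization.helpers import get_qat_qconfig

    # documented defaults: asymmetric torch.quint8 activations and
    # symmetric torch.qint8 weights, 8 bits each
    default_qconfig = get_qat_qconfig()

    # TensorRT-style symmetric int8 quantization for activations and weights
    trt_qconfig = get_qat_qconfig(
        symmetric_activations=True,
        symmetric_weights=True,
        activation_dtype=torch.qint8,
        weight_dtype=torch.qint8,
    )
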
--- .../sparsification/quantization/helpers.py | 118 ++++++++++++------ 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 027c7514c32..bc9aeb6d58c 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -69,8 +69,14 @@ else None ) - +# class BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ def __init__(self, module: Module): super().__init__() self.bn = module @@ -220,14 +226,25 @@ def from_module( ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param reduce_range: if True, the quantization range will be reduced by one - bit. This may prevent overflow issues with model execution on certain - hardware. Default is None, will only override qat_wrapper_kwargs if set - to a bool value + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: QATWrapper object created using the given Module as the forward function. Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -293,6 +310,7 @@ def from_module( else weight_bits or qat_wrapper_kwargs["weight_bits"] ) + # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -516,19 +534,10 @@ def _load_qconfigs( def configure_module_bn_wrappers(module: Module): """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` + Wrap any BatchNormalization modules that are not fused with convolutions + with BNWrapper to enable freezing/unfreezing of BN statistics :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. 
Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required if not hasattr(module, 'freeze_bn_stats'): @@ -562,14 +571,25 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} - """ + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -605,6 +625,13 @@ def configure_module_qat_wrappers( def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. If None dtype is set to torch.quint8. + :param bits: number of bits. If None is set to 8. + :return: minimum limit, maximum limit, data type + """ dtype = dtype if dtype else torch.quint8 bits = bits if bits else 8 if dtype == torch.qint8: @@ -689,18 +716,24 @@ def get_qat_qconfig( ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - UINT8 quantization range with zero point set to 128. Otherwise activations - will use asymmetric quantization with any zero point. Default is False + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. :param symmetric_weights: if True, weights will have a symmetric - INT8 quantization range with zero point set to 0. Otherwise activations - will use asymmetric quantization with any zero point. Default is True + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. 
+ :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer @@ -885,14 +918,25 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param activation_qconfig_kwargs: additional kwargs for quantizing activations. - Default is {}. - :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. - Default is {}. + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False - """ + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ if qconfig is None: if symmetric_weights is None: _symmetric_weights = False From 74ad8efd854087490a5d53c194e61714e3c2ef93 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:57 -0400 Subject: [PATCH 095/218] Updated documentation to reflect changes. --- .../quantization/modifier_quantization.py | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 73a50e0f9c4..4f912b3d8bb 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -113,21 +113,26 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm training - recovery. Default is True + :param quantize_linear_activations: if True, FakeQuantize ops will be run + for output activations of fully connected layers. Default is False. 
+ :param quantize_conv_activations: if True, FakeQuantize ops will be run + for output activations of convolutional layers. Default is False. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. + :param weight_bits: Number of bits to use for setting quant min/max values for + weights. Default is None, which will quantize weights to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used + When None, the entire calibration_dataloader is used + :param exclude_batchnorm: If True, do not propagate quantization qconfigs to + batch-normalization modules :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. + :param tensorrt: if True, sets quantization configuration for compatibility with + explicit quantization as supported by TensorRT 8.2. """ def __init__( @@ -232,11 +237,12 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None to uses the default function + to performing QAT. None uses the default function. + If the tensorrt flag is True, the default is 'no_fuse', otherwise `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ if self._tensorrt: - fuse_fn = 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' else: fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @@ -322,19 +328,16 @@ def reduce_range(self) -> bool: @ModifierProp() def quantize_linear_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm - training recovery + :return: if True, FakeQuantize ops will be run for output activations + of fully connected layers """ return self._quantize_linear_activations @ModifierProp() def quantize_conv_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of convolutional layers. + :return: if True, FakeQuantize ops will be run for output activations + of convolutional layers """ return self._quantize_conv_activations @@ -358,7 +361,7 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + weights. Default is None, which will quantize weights to 8 bits.
""" return self._weight_bits @@ -543,6 +546,7 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) + # build list of layer types that should not quantize output activations to_remove_layer_name = [] if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) @@ -557,9 +561,16 @@ def _enable_module_qat(self, module: Module): if len(to_remove_layer_name) == 0: to_remove_layer_name = None + # fix for freezing batchnorm statistics when not fusing BN with convs. + # pytorch only supports freezing batchnorm statistics for fused modules. + # this fix wraps BN modules adding with a new module class that supports + # methods related to freezing/unfreezing BN statistics. configure_module_bn_wrappers(module) - # prepare each module / submodule for quantization + # set qconfig. + # if tensorrt flag is used, set activation and weights to symmetric + # quantization. + # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: _symmetric_activations = True _activation_dtype = torch.qint8 @@ -582,6 +593,8 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + + # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( @@ -596,13 +609,17 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig + # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) + + # Remove output quantization from appropriate modules if to_remove_layer_name: remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) @@ -611,6 +628,8 @@ def _enable_module_qat(self, module: Module): if self._exclude_module_types: to_exclude.extend(self._exclude_module_types) + # if exclude_batchnorm flag is used, add batch norm layers to list of + # modules to exclude qconfig if self._exclude_batchnorm: to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) From 6d57cb2f717db6d5d355b1f5f8e0964753b19592 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:42:27 -0400 Subject: [PATCH 096/218] Updated documentation to reflect changes. --- src/sparseml/pytorch/models/classification/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3a7a5169447 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,6 +140,10 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): + """ + Wrapper for the FloatFunctional class that enables QATWrapper used to + quantize the first input to the Add operation + """ def __init__(self, num_channels): super().__init__() if FloatFunctional: From 43bf76805a507d9f476766448deab478b269bf71 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:52:15 -0400 Subject: [PATCH 097/218] Fixed default weights data type. 
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index bc9aeb6d58c..b3e47162c5e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -751,7 +751,7 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, From afd1df8b492a1eed6c084cc2d9a51b6282ca32c2 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:02:48 -0400 Subject: [PATCH 098/218] Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 54 ++-- .../sparsification/quantization/helpers.py | 247 +++++++++--------- .../quantization/modifier_quantization.py | 44 +++- 3 files changed, 186 insertions(+), 159 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3a7a5169447..cd8b979c3ad 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -144,12 +145,13 @@ class _AddReLU(Module): Wrapper for the FloatFunctional class that enables QATWrapper used to quantize the first input to the Add operation """ + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -209,12 +211,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -325,12 +327,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -441,15 +443,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -483,10 +485,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: 
int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index b3e47162c5e..c2e21d30a16 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,7 +71,7 @@ else None ) -# + class BNWrapper(Module): """ Wraps BatchNormalization module to expose methods needed to enable @@ -77,6 +79,7 @@ class BNWrapper(Module): :param module: BatchNormalization module to be wrapped """ + def __init__(self, module: Module): super().__init__() self.bn = module @@ -213,16 +216,16 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -241,8 +244,10 @@ def from_module( activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. 
:return: QATWrapper object created using the given Module as the forward @@ -277,7 +282,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -315,26 +320,26 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -471,18 +476,18 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -540,29 +545,29 @@ def configure_module_bn_wrappers(module: Module): :param module: module to potentially wrap the submodules of """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if 
type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -589,7 +594,7 @@ def configure_module_qat_wrappers( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -654,7 +659,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -669,9 +674,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -695,7 +700,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -704,15 +709,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + 
symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -739,8 +744,13 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) + activation_observer = get_observer( + symmetric_activations, + activation_dtype, + activation_bits, + reduce_range, + activation_qconfig_kwargs, + ) if symmetric_weights is None: _symmetric_weights = True else: @@ -751,17 +761,23 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer( + _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + ) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) +def get_observer( + symmetric: Optional[bool], + dtype: Optional[torch.dtype], + bits: Optional[int], + reduce_range: bool, + qconfig_kwargs: Dict[str, Any], +): + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, @@ -793,7 +809,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -813,14 +829,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -855,14 +871,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and 
submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -899,17 +915,17 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -936,7 +952,7 @@ def prepare_embeddings_qat( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" if qconfig is None: if symmetric_weights is None: _symmetric_weights = False @@ -960,24 +976,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -2 ** (bits - 1) + quant_min = -(2 ** (bits - 1)) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 4f912b3d8bb..30e1aefbe15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -242,9 +242,15 @@ def model_fuse_fn_name(self) -> Union[str, None]: `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" if self._tensorrt: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + ) else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name + if self._model_fuse_fn_name + else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -365,7 +371,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -382,7 +387,10 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: + if ( + self._weight_qconfig_kwargs is not None + and "observer" in self._weight_qconfig_kwargs + ): kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver @@ -390,8 +398,6 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: else: return self._weight_qconfig_kwargs - - @ModifierProp() def num_calibration_steps(self) -> Optional[int]: """ @@ -532,7 +538,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -553,10 +559,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -591,7 +607,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # prepare each module / submodule for quantization @@ -607,7 +623,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # set quantization config (asymmetric activations, symmetric weights) @@ -631,7 +647,7 @@ def _enable_module_qat(self, module: Module): # if exclude_batchnorm flag is used, add batch norm layers to list of # modules to exclude qconfig if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude if self._exclude_module_types: From 44303e19ad82fec8a3a4b2b2fa30d4f78327bc06 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:53:05 -0400 Subject: [PATCH 099/218] Removed unused method --- .../sparsification/quantization/helpers.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 
c2e21d30a16..6c30789fbb7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -41,7 +41,6 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", - "get_updated_qconfig_kwargs", "fix_observer_quant_range", "freeze_bn_stats", "fuse_module_conv_bn_relus", @@ -975,36 +974,6 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} - - # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): - raise ValueError( - "Cannot override quant_max and quant_min when number of bits is set" - ) - - if bits: - if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - - qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - - return qconfig_kwargs - - def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() From 64b42d5ba291b7098909faeb9887575d8ff3c737 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 100/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 601 +++++------------- .../quantization/modifier_quantization.py | 231 +++---- 2 files changed, 226 insertions(+), 606 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6c30789fbb7..e10224bbce7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,22 +31,21 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -71,113 +69,6 @@ ) -class BNWrapper(Module): - """ - Wraps BatchNormalization module to expose methods needed to enable - freezing/unfreezing of statistics - - :param module: BatchNormalization module to be wrapped - """ - - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return 
self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -215,40 +106,21 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + :param reduce_range: if True, the quantization range will be reduced by one + bit. This may prevent overflow issues with model execution on certain + hardware. Default is None, will only override qat_wrapper_kwargs if set + to a bool value :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. Default is {} :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. - Default is torch.quint8. - :param weight_dtype: quantized weights data type. - Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. + weights. Default is {} :return: QATWrapper object created using the given Module as the forward function. 
Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -259,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -281,7 +141,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -290,55 +150,23 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - # Remove qconfig from wrapped layer to avoid duplicate quantization - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -358,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = 
activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -475,18 +285,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -516,57 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - Wrap any BatchNormalization modules that are not fused with convolutions - with BNWrapper to enable freezing/unfreezing of BN statistics - - :param module: module to potentially wrap the submodules of - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): - for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) - # recurse on child module - 
configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -575,25 +343,14 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + This may prevent overflow issues with model execution on certain hardware + Default is False :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. Default is {} :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + weights. Default is {} + """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -602,52 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - """ - compute quantization limits depending on data type and number of bits - - :param dtype: data type. If None dtype is set to torch.quint8. - :param bits: number of bits. If None is set to 8. 
- :return: minimum limit, maximum limit, data type - """ - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -658,7 +383,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -673,9 +398,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -699,7 +424,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -708,90 +433,66 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. + UINT8 quantization range with zero point set to 128. Otherwise activations + will use asymmetric quantization with any zero point. Default is False :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. + INT8 quantization range with zero point set to 0. Otherwise activations + will use asymmetric quantization with any zero point. Default is True :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + This may prevent overflow issues with model execution on certain hardware + Default is False :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
- :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. + weights. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer( - symmetric_activations, - activation_dtype, - activation_bits, - reduce_range, - activation_qconfig_kwargs, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer( - _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + activation_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, + reduce_range=reduce_range, ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, ) - - -def get_observer( - symmetric: Optional[bool], - dtype: Optional[torch.dtype], - bits: Optional[int], - reduce_range: bool, - qconfig_kwargs: Dict[str, Any], -): - qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, - ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -808,7 +509,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -817,9 +518,14 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) + or ( # do not propagate default uint8 symmetric range + observer.qscheme == torch.per_tensor_symmetric + and fake_quantize.quant_min == 0 + and fake_quantize.quant_max == 255 + ) ): continue observer.quant_min = fake_quantize.quant_min @@ -827,15 +533,10 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -870,14 +571,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -914,17 +615,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -933,47 +628,57 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. + :param activation_qconfig_kwargs: additional kwargs for quantizing activations. + Default is {}. + :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. + Default is {}. 
:param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False. - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + Default is False + """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for bits + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min when number of bits is set" + ) + + if bits: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + + return qconfig_kwargs + + def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 30e1aefbe15..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -113,26 +113,21 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if True, FakeQuantize ops will be run - for output activations of fully connected layers. Default is False. - :param quantize_conv_activations: if True, FakeQuantize ops will be run - for output activations of convolutional layers. Default is False. + :param quantize_linear_activations: if False, FakeQuantize ops will not be run + for activations of fully connected layers. this is important for quantizing + transformer based models such as BERT where the quantized MatMul outputs + are kept at 32 bits of precision and fake quantizing the outputs harm training + recovery. Default is True :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. - :param weight_bits: Number of bits to use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used - :param exclude_batchnorm: If True, do not propagate quantization qconfigs to - batch-normalization modules + When None, the entire calibration_dataloader is used :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param tenssorrt: if True sets quantization configuration for compatibility with - explict quantization as supported by TensorRT 8.2. + weights. 
""" def __init__( @@ -146,16 +141,15 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_activations: bool = False, - quantize_conv_activations: bool = False, + quantize_linear_output_activations: bool = False, + quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -179,11 +173,11 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_activations = quantize_linear_activations - self._quantize_conv_activations = quantize_conv_activations + self._quantize_linear_output_activations = quantize_linear_output_activations + self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -192,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -237,21 +230,10 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None sets to default function. - If tensorrt flag is True, default is 'no_fuse', otherwise + to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" - ) - else: - fuse_fn = ( - self._model_fuse_fn_name - if self._model_fuse_fn_name - else "conv_bn_relus" - ) - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -280,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -332,20 +314,23 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_activations(self) -> bool: + def quantize_linear_output_activations(self) -> bool: """ - :return: if True, FakeQuantize ops will be run for output activations - of fully connected layers + :return: if False, FakeQuantize ops will not be run + for activations of fully connected layers. 
this is important for quantizing + transformer based models such as BERT where the quantized MatMul outputs + are kept at 32 bits of precision and fake quantizing the outputs harm + training recovery """ - return self._quantize_linear_activations + return self._quantize_linear_output_activations @ModifierProp() - def quantize_conv_activations(self) -> bool: + def quantize_conv_output_activations(self) -> bool: """ - :return: if True, FakeQuantize ops will be run for output activations - of convolutional layers + :return: if False, FakeQuantize ops will not be run + for activations of convolutional layers. """ - return self._quantize_conv_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -367,10 +352,11 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -387,16 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if ( - self._weight_qconfig_kwargs is not None - and "observer" in self._weight_qconfig_kwargs - ): - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -406,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -448,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -533,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -551,105 +522,49 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, 
**self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - # build list of layer types that should not quantize output activations to_remove_layer_name = [] - if not self._quantize_linear_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + if not self._quantize_linear_output_activations: + to_remove_layer_name.extend(["Linear", "LinearReLu"]) - if not self._quantize_conv_activations: + if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - # fix for freezing batchnorm statistics when not fusing BN with convs. - # pytorch only supports freezing batchnorm statistics for fused modules. - # this fix wraps BN modules adding with a new module class that supports - # methods related to freezing/unfreezing BN statistics. - configure_module_bn_wrappers(module) - - # set qconfig. - # if tensorrt flag is used, set activation and weights to symmetric - # quantization. - # otherwise, use the default values set in get_qat_qconfig - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) - - # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) - # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig - # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - - # Remove output quantization from appropriate modules - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, 
to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - # if exclude_batchnorm flag is used, add batch norm layers to list of - # modules to exclude qconfig - if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -658,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -723,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From 74d7db5355d5acdd86e28589e6f04b4551b4a499 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 101/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. 
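
A minimal sketch of the opt-in mechanism _AddReLU relies on (not part of this
commit; the ToyAddReLU / ToyResidualBlock names and shapes are illustrative
assumptions). Any submodule that sets wrap_qat = True, plus optional
qat_wrapper_kwargs, is replaced by a QATWrapper when
configure_module_qat_wrappers() walks the model, so here only the first input
of the residual add is fake-quantized (num_inputs=1) and its output is left
untouched (num_outputs=0):

    import torch
    from torch.nn import Module
    from torch.nn.quantized import FloatFunctional

    from sparseml.pytorch.sparsification.quantization.helpers import (
        configure_module_qat_wrappers,
    )


    class ToyAddReLU(Module):
        def __init__(self):
            super().__init__()
            self.functional = FloatFunctional()
            # opting in: configure_module_qat_wrappers() swaps this module
            # for a QATWrapper built from the kwargs below
            self.wrap_qat = True
            # fake-quantize the first input to the add, skip the output
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}

        def forward(self, x, y):
            return self.functional.add_relu(x, y)


    class ToyResidualBlock(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(8, 8, 3, padding=1)
            self.add_relu = ToyAddReLU()

        def forward(self, x):
            return self.add_relu(self.conv(x), x)


    model = ToyResidualBlock()
    configure_module_qat_wrappers(model)
    # model.add_relu is now a QATWrapper that fake-quantizes its first input

Leaving num_outputs at 0 means no output fake-quantization is attached to the
wrapper, in line with this series' move away from quantizing output
activations.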
--- .../pytorch/models/classification/resnet.py | 66 +++++++++---------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index cd8b979c3ad..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -141,19 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - """ - Wrapper for the FloatFunctional class that enables QATWrapper used to - quantize the first input to the Add operation - """ - - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -185,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -211,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -242,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -327,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -443,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -485,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 71f25cd281d1a067231cbc220456c18064212bf8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 102/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. 
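
With this change a bare QuantizationModifier no longer fake-quantizes the
output activations of Linear and Conv layers (quantize_linear_output_activations
and quantize_conv_output_activations default to False), no longer fuses
Conv-BN-ReLU on its own (a model_fuse_fn_name of None now resolves to
'no_fuse'; the string checked for the built-in fusing is 'conv_bn_relus'), and
strips qconfigs from BatchNorm1d/2d/3d through the new exclude_batchnorm flag.
A hedged sketch of opting back into the previous behaviour, assuming the
scheduling arguments keep their defaults:

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    # every keyword below is flipped relative to the new defaults noted inline
    modifier = QuantizationModifier(
        model_fuse_fn_name="conv_bn_relus",       # None now resolves to "no_fuse"
        quantize_linear_output_activations=True,  # new default: False
        quantize_conv_output_activations=True,    # new default: False
        exclude_batchnorm=False,                  # new default: True
    )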
--- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e10224bbce7..ec69ded82c8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + 
to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 04c436fd2aa62d833fda3cf4206c63dc4ec6ce0a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 103/218] Minor fixes. Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 61 ++++----- .../sparsification/quantization/helpers.py | 123 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 112 insertions(+), 105 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git 
a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ec69ded82c8..2c1ac640d6e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ 
-383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +536,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +573,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +617,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: 
"torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -649,17 +651,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From b76d59e6b190a5849a63d8bb19fc9536bbf595a7 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 104/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 209 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++-- 2 files changed, 164 insertions(+), 82 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2c1ac640d6e..a44369550b1 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ 
+ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -535,10 +612,15 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -573,14 +655,14 @@ def 
fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -617,11 +699,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -651,10 +733,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. 
Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 81bc2d039ab6524689f976c83cdf5235d9f54dd4 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 105/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a44369550b1..48ed0708eae 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From a009d20e5ef665ee291e2692768623b70e8beded Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 106/218] Set BN fusing back as default. 
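
Context for this change: when no fuse function name is configured, the modifier now falls
back to Conv-BN(-ReLU) fusion instead of skipping fusion entirely. A minimal standalone
sketch of the new fallback behavior (the helper name resolve_fuse_fn_name is illustrative
only, not part of this patch):

def resolve_fuse_fn_name(model_fuse_fn_name=None):
    # After this patch the fallback is "conv_bn_relus" rather than "no_fuse",
    # so Conv-BN(-ReLU) fusion runs by default before QAT is enabled.
    return model_fuse_fn_name if model_fuse_fn_name else "conv_bn_relus"

assert resolve_fuse_fn_name(None) == "conv_bn_relus"
assert resolve_fuse_fn_name("no_fuse") == "no_fuse"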
--- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From b736ce745f2137429cf8ac1dfa05bc24a935c33f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 107/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 48ed0708eae..71f6553fc44 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From e56066e959cdb5312f142a4f6e21cc5292f7eaa4 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 108/218] Fixed custom freeze_bn_stats. 
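
Context for this fix: freeze_bn_stats is applied recursively through Module.apply (see
quant_module.apply(freeze_bn_stats) in the modifier), so it has to work for any module
that exposes a freeze_bn_stats() method, including the BNWrapper introduced earlier in
this series, rather than only for a fixed set of fused module types. A minimal sketch of
that duck-typed dispatch, using a toy wrapper (_ToyBNWrapper is a hypothetical stand-in,
not code from this patch):

import torch
from torch.nn import Module

def freeze_bn_stats(module: Module):
    # Duck-typed: anything exposing freeze_bn_stats() is frozen;
    # plain modules are left untouched.
    if hasattr(module, "freeze_bn_stats"):
        module.freeze_bn_stats()

class _ToyBNWrapper(Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(8)
        self.freeze_bn = False

    def freeze_bn_stats(self):
        self.freeze_bn = True
        return self

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), _ToyBNWrapper())
model.apply(freeze_bn_stats)  # only the wrapper defines freeze_bn_stats()
assert model[1].freeze_bn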
--- .../sparsification/quantization/helpers.py | 245 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 182 insertions(+), 109 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 71f6553fc44..e09c0e29690 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def 
from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -613,14 +667,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + 
inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -655,14 +709,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -699,11 +753,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -732,26 +786,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 75dd790df27472fcae059b60117b61e84fa1b9f0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 109/218] Temporary files for evaluating changes to graphs. 
--- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 4fd2ced7ea7787799feddc4be3b61a508dca9dc3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 110/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
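
With this change the quantization range is derived from the target dtype and bit width in
one place: signed qint8 (used for the symmetric TensorRT scheme) gets
[-2^(bits-1), 2^(bits-1) - 1], while unsigned quint8 gets [0, 2^bits - 1], e.g. [-128, 127]
and [0, 255] at 8 bits. A small self-contained restatement of that arithmetic for reference
(the in-tree compute_range helper evolves slightly over the next patches, so this is an
illustration rather than the exact final code):

import torch

def compute_range(dtype, bits):
    # signed range for qint8 (symmetric-friendly), unsigned range for quint8
    dtype = dtype if dtype else torch.quint8
    bits = bits if bits else 8
    if dtype == torch.qint8:
        quant_min = -(2 ** (bits - 1))
        quant_max = 2 ** (bits - 1) - 1
    else:  # torch.quint8
        quant_min = 0
        quant_max = 2 ** bits - 1
    return quant_min, quant_max, dtype

assert compute_range(torch.qint8, 8) == (-128, 127, torch.qint8)
assert compute_range(torch.quint8, 8) == (0, 255, torch.quint8)
assert compute_range(torch.qint8, 4) == (-8, 7, torch.qint8)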
--- .../sparsification/quantization/helpers.py | 207 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 134 insertions(+), 131 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e09c0e29690..57b919470e4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -667,14 +673,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -709,14 +715,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # 
[Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -753,11 +759,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -787,17 +793,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 29e32ede9fcdefa372e10612077355ebfaf21fec Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 111/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 57b919470e4..2ae713c16aa 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -761,9 +873,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -781,11 +899,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 7af19e350e5604bb927d70b1f3a26a3e14fbeb46 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 112/218] Included check to account for when weight_qconfig_kwatgs is None. 
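Without the guard, leaving weight_qconfig_kwargs at its default of None made the property raise a TypeError on the "observer" membership test. Below is a minimal sketch of the guarded lookup, assuming the existing minmaxobserver aliasing; the function name is illustrative and the real observer class is stood in by a string so the snippet runs without torch installed.

# Hedged sketch of the guarded property body (not the modifier itself).
from typing import Any, Dict, Optional


def resolve_weight_qconfig_kwargs(raw: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    # '"observer" in None' raises TypeError, so check for None first.
    if raw is not None and "observer" in raw:
        kwargs = raw.copy()
        if kwargs["observer"] == "minmaxobserver":
            # The real property substitutes torch.quantization.MinMaxObserver here.
            kwargs["observer"] = "MinMaxObserver-placeholder"
        return kwargs
    return raw


assert resolve_weight_qconfig_kwargs(None) is None
assert resolve_weight_qconfig_kwargs({"observer": "minmaxobserver"})["observer"] == "MinMaxObserver-placeholder"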
--- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From a6b2b688e27a82ff27d37657788a2117206177f8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 113/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 377 +++--------------- .../quantization/modifier_quantization.py | 130 ++---- 2 files changed, 87 insertions(+), 420 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2ae713c16aa..75d11c67c31 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,21 +32,20 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -70,106 +69,6 @@ ) -class BNWrapper(Module): - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return 
self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -208,15 +107,9 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, + reduce_range: bool = None, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -238,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -269,31 +150,6 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -308,15 +164,9 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): super().__init__() @@ -336,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - 
self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -456,15 +288,9 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -494,66 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` - - :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
Default is {} - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): - for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) - # recurse on child module - configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -578,45 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -642,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES + type(module) in QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): @@ -677,15 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -706,44 +458,41 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - - -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + activation_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, + ) + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, + reduce_range=reduce_range, ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -769,14 +518,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) - or ( # do not propagate default uint8 symmetric range - observer.qscheme == torch.per_tensor_symmetric - and fake_quantize.quant_min == 0 - and fake_quantize.quant_max == 255 - ) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -784,11 +528,6 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( module: Module, inplace: bool = True, @@ -873,15 +612,9 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - 
weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -899,28 +632,18 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): qconfig_kwargs = ( qconfig_kwargs.copy() if qconfig_kwargs @@ -937,15 +660,9 @@ def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): ) if bits: - if mode == "symmetric": - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,14 +143,13 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -176,9 +175,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -187,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -235,11 +233,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = 'no_fuse' - else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -268,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -336,7 +330,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_conv_output_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -379,15 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs - - + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -397,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -439,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -524,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -542,10 +522,16 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + to_remove_layer_name.extend(["Linear", "LinearReLu"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -554,47 +540,20 @@ def _enable_module_qat(self, module: Module): "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - configure_module_bn_wrappers(module) # prepare each module / submodule for quantization - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - qconfig = get_qat_qconfig( - 
symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -603,18 +562,9 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -623,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -688,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From cbc8551238543fd065a48098fd424ec2430725aa Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 114/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. 
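The residual add followed by ReLU in the ResNet blocks is factored into an _AddReLU module so the QAT tooling can quantize it: the module sets wrap_qat = True and supplies qat_wrapper_kwargs ({"num_inputs": 1, "num_outputs": 0} here), which configure_module_qat_wrappers uses to swap in a QATWrapper. Below is a minimal sketch of the pattern, assuming the FloatFunctional fallback used in the file; the class name and the plain ReLU fallback are illustrative, not the exact sparseml module.

# Hedged sketch of the wrap_qat pattern; mirrors the _AddReLU added in this patch.
from torch.nn import Module, ReLU

try:
    from torch.nn.quantized import FloatFunctional
except Exception:  # FloatFunctional unavailable on older torch builds
    FloatFunctional = None


class AddReLUSketch(Module):
    def __init__(self):
        super().__init__()
        if FloatFunctional:
            self.functional = FloatFunctional()
            # Flags read by configure_module_qat_wrappers when creating a QATWrapper.
            self.wrap_qat = True
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}
        else:
            self.functional = ReLU(inplace=True)

    def forward(self, x, y):
        if FloatFunctional is not None and isinstance(self.functional, FloatFunctional):
            # FloatFunctional.add_relu fuses the skip-connection add and the ReLU.
            return self.functional.add_relu(x, y)
        return self.functional(x + y)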
--- src/sparseml/pytorch/models/classification/resnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,14 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -236,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() From 475fd7e88c3aaeab18227cf97a12ecec77d16ce8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 115/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 75d11c67c31..f28656f1712 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ 
class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 0d0f757d4d6e11295fb00641670c5bc08c667fbd Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 116/218] Minor fixes. Style and quality fixes. 
--- .../pytorch/models/classification/resnet.py | 61 +++++---- .../sparsification/quantization/helpers.py | 129 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 115 insertions(+), 108 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index f28656f1712..ef4445a0d5f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as 
torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and 
hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -518,9 +520,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -529,9 +531,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -566,14 +568,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -610,11 +612,11 @@ def 
fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -644,17 +646,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 0e05c94f791e0125e2b5071b0701dd7e86e66a2e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 117/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 215 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++- 2 files changed, 167 insertions(+), 85 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ef4445a0d5f..c4f165d23ef 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + 
"torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -520,9 +597,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -530,10 +607,15 @@ def 
fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -568,14 +650,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -612,11 +694,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -646,10 +728,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 9dce86ef6439cf9e713d337df107b7b20b1e5446 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 118/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index c4f165d23ef..64958570e2d 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if 
not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From dd7e24b526cefc59d2d106a852aa5a5c7fff7dde Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 119/218] Set BN fusing back as default. --- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 0c9bef90018bec1e0e4361a0791f29789a26f5cd Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 120/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 64958570e2d..a43d69d947b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 1f6b56c8a3579822625ba28e5cb64c6aa4c90853 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 121/218] Fixed custom freeze_bn_stats. --- .../sparsification/quantization/helpers.py | 251 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 185 insertions(+), 112 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a43d69d947b..6110a499b70 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + 
self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -597,9 +651,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -608,14 +662,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -650,14 +704,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -694,11 +748,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -727,26 +781,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) 
-> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) 
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From aee8074a885e130af0faae48a17fd6cfbe18b722 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 122/218] Temporary files for evaluating changes to graphs. --- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 810158aa4e754732669f2889c2accd91b90083a0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 123/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
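For reference while reviewing this change, the integer range implied by a quantized dtype and bit width works out as in the short sketch below. It mirrors the compute_range helper introduced in this patch, but the function name and the explicit error branch here are illustrative only; the patched helper instead falls back to torch.quint8 and 8 bits when either argument is None.

import torch

def example_quant_range(dtype: torch.dtype, bits: int):
    # signed (symmetric-friendly) range, e.g. qint8 with 8 bits -> [-128, 127]
    if dtype == torch.qint8:
        return -(2 ** (bits - 1)), (2 ** (bits - 1)) - 1
    # unsigned (asymmetric) range, e.g. quint8 with 8 bits -> [0, 255]
    if dtype == torch.quint8:
        return 0, (2 ** bits) - 1
    raise ValueError(f"unsupported quantized dtype: {dtype}")

# With tensorrt=True the modifier requests symmetric qint8 activations:
assert example_quant_range(torch.qint8, 8) == (-128, 127)
# The default asymmetric activation observer keeps the quint8 range:
assert example_quant_range(torch.quint8, 8) == (0, 255)

Folding this computation into get_qat_qconfig lets a single call site pick the right quant_min/quant_max once both the dtype and the bit width are known, instead of callers pre-computing ranges with only partial information.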
--- .../sparsification/quantization/helpers.py | 213 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 137 insertions(+), 134 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6110a499b70..8ae045de9e8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -651,9 +657,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -662,14 +668,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -704,14 +710,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): 
submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -748,11 +754,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -782,17 +788,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 2782d7f344beb2e94fa6d17ef7c4415a6838e3b4 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 124/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 8ae045de9e8..027c7514c32 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -756,9 +868,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -776,11 +894,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From dafb4645a363b3811848eafcb428994752fe9966 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 125/218] Included check to account for when weight_qconfig_kwatgs is None. --- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 5fe18e5a7cf7d645bc33828d0fe12e0dd513918a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 14:20:19 -0400 Subject: [PATCH 126/218] Modified argument names for backwards compatibility. 
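This restores the shorter public names `quantize_linear_activations` and `quantize_conv_activations` in place of the `*_output_activations` variants introduced earlier in the series, so existing recipes and constructor calls keep working. A minimal usage sketch with the restored names; the import path and the `start_epoch` argument are assumptions here, and the values are purely illustrative:

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    modifier = QuantizationModifier(
        start_epoch=0.0,                    # assumed ScheduledModifier argument
        quantize_linear_activations=False,  # was quantize_linear_output_activations
        quantize_conv_activations=False,    # was quantize_conv_output_activations
    )
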
--- .../quantization/modifier_quantization.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..73a50e0f9c4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -141,8 +141,8 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_output_activations: bool = False, - quantize_conv_output_activations: bool = False, + quantize_linear_activations: bool = False, + quantize_conv_activations: bool = False, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, @@ -174,8 +174,8 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_output_activations = quantize_linear_output_activations - self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_linear_activations = quantize_linear_activations + self._quantize_conv_activations = quantize_conv_activations self._activation_bits = activation_bits self._weight_bits = weight_bits self._exclude_batchnorm = exclude_batchnorm @@ -320,7 +320,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_output_activations(self) -> bool: + def quantize_linear_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -328,15 +328,15 @@ def quantize_linear_output_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_output_activations + return self._quantize_linear_activations @ModifierProp() - def quantize_conv_output_activations(self) -> bool: + def quantize_conv_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_conv_output_activations + return self._quantize_conv_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -544,10 +544,10 @@ def _enable_module_qat(self, module: Module): module_fuse_fn(**self._model_fuse_fn_kwargs) to_remove_layer_name = [] - if not self._quantize_linear_output_activations: + if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) - if not self._quantize_conv_output_activations: + if not self._quantize_conv_activations: to_remove_layer_name.extend( ["Conv1d", "Conv2d", "Conv3d", "ConvBn1d", "ConvBn2d", "ConvBn3d", From 50e0773e491fa9830ebb4498caa0b65ea2c5111a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:51 -0400 Subject: [PATCH 127/218] Updated documentation to reflect changes. 
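The updated docstrings spell out how `activation_dtype`/`weight_dtype`, the bit widths, and the symmetric flags determine the fake-quantization range and zero point. For reference, the range arithmetic they describe (mirroring `compute_range` in this file, which additionally returns the resolved dtype) reduces to the standalone sketch below:

    import torch

    def quant_range(dtype=None, bits=None):
        # same fallbacks as compute_range: torch.quint8 and 8 bits
        dtype = dtype if dtype else torch.quint8
        bits = bits if bits else 8
        if dtype == torch.qint8:
            # signed range, e.g. -128..127 for 8 bits
            return -(2 ** (bits - 1)), (2 ** (bits - 1)) - 1
        # unsigned range, e.g. 0..255 for 8 bits
        return 0, (2 ** bits) - 1

    assert quant_range(torch.qint8, 8) == (-128, 127)
    assert quant_range(torch.quint8, 4) == (0, 15)
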
--- .../sparsification/quantization/helpers.py | 118 ++++++++++++------ 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 027c7514c32..bc9aeb6d58c 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -69,8 +69,14 @@ else None ) - +# class BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ def __init__(self, module: Module): super().__init__() self.bn = module @@ -220,14 +226,25 @@ def from_module( ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param reduce_range: if True, the quantization range will be reduced by one - bit. This may prevent overflow issues with model execution on certain - hardware. Default is None, will only override qat_wrapper_kwargs if set - to a bool value + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: QATWrapper object created using the given Module as the forward function. Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -293,6 +310,7 @@ def from_module( else weight_bits or qat_wrapper_kwargs["weight_bits"] ) + # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -516,19 +534,10 @@ def _load_qconfigs( def configure_module_bn_wrappers(module: Module): """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` + Wrap any BatchNormalization modules that are not fused with convolutions + with BNWrapper to enable freezing/unfreezing of BN statistics :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. 
Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required if not hasattr(module, 'freeze_bn_stats'): @@ -562,14 +571,25 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} - """ + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -605,6 +625,13 @@ def configure_module_qat_wrappers( def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. If None dtype is set to torch.quint8. + :param bits: number of bits. If None is set to 8. + :return: minimum limit, maximum limit, data type + """ dtype = dtype if dtype else torch.quint8 bits = bits if bits else 8 if dtype == torch.qint8: @@ -689,18 +716,24 @@ def get_qat_qconfig( ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - UINT8 quantization range with zero point set to 128. Otherwise activations - will use asymmetric quantization with any zero point. Default is False + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. :param symmetric_weights: if True, weights will have a symmetric - INT8 quantization range with zero point set to 0. Otherwise activations - will use asymmetric quantization with any zero point. Default is True + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. 
+ :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer @@ -885,14 +918,25 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param activation_qconfig_kwargs: additional kwargs for quantizing activations. - Default is {}. - :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. - Default is {}. + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False - """ + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ if qconfig is None: if symmetric_weights is None: _symmetric_weights = False From 77444f077ef98ffd9139c9f1038ac3f324235a5a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:57 -0400 Subject: [PATCH 128/218] Updated documentation to reflect changes. --- .../quantization/modifier_quantization.py | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 73a50e0f9c4..4f912b3d8bb 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -113,21 +113,26 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm training - recovery. Default is True + :param quantize_linear_activations: if True, FakeQuantize ops will be run + for output activations of fully connected layers. Default is False. 
+ :param quantize_conv_activations: if True, FakeQuantize ops will be run + for output activations of convolutional layers. Default is False. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. + :param weight_bits: Number of bits to use for setting quant min/max values for + weights. Default is None, which will quantize weights to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used + When None, the entire calibration_dataloader is used + :param exclude_batchnorm: If True, do not propagate quantization qconfigs to + batch-normalization modules :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. + :param tenssorrt: if True sets quantization configuration for compatibility with + explict quantization as supported by TensorRT 8.2. """ def __init__( @@ -232,11 +237,12 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None to uses the default function + to performing QAT. None sets to default function. + If tensorrt flag is True, default is 'no_fuse', otherwise `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ if self._tensorrt: - fuse_fn = 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' else: fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @@ -322,19 +328,16 @@ def reduce_range(self) -> bool: @ModifierProp() def quantize_linear_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm - training recovery + :return: if True, FakeQuantize ops will be run for output activations + of fully connected layers """ return self._quantize_linear_activations @ModifierProp() def quantize_conv_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of convolutional layers. + :return: if True, FakeQuantize ops will be run for output activations + of convolutional layers """ return self._quantize_conv_activations @@ -358,7 +361,7 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + weights. Default is None, which will quantize weights to 8 bits. 
""" return self._weight_bits @@ -543,6 +546,7 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) + # build list of layer types that should not quantize output activations to_remove_layer_name = [] if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) @@ -557,9 +561,16 @@ def _enable_module_qat(self, module: Module): if len(to_remove_layer_name) == 0: to_remove_layer_name = None + # fix for freezing batchnorm statistics when not fusing BN with convs. + # pytorch only supports freezing batchnorm statistics for fused modules. + # this fix wraps BN modules adding with a new module class that supports + # methods related to freezing/unfreezing BN statistics. configure_module_bn_wrappers(module) - # prepare each module / submodule for quantization + # set qconfig. + # if tensorrt flag is used, set activation and weights to symmetric + # quantization. + # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: _symmetric_activations = True _activation_dtype = torch.qint8 @@ -582,6 +593,8 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + + # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( @@ -596,13 +609,17 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig + # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) + + # Remove output quantization from appropriate modules if to_remove_layer_name: remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) @@ -611,6 +628,8 @@ def _enable_module_qat(self, module: Module): if self._exclude_module_types: to_exclude.extend(self._exclude_module_types) + # if exclude_batchnorm flag is used, add batch norm layers to list of + # modules to exclude qconfig if self._exclude_batchnorm: to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) From 9bebbf0284ce9248899df0c095eca70bf30e13c1 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:42:27 -0400 Subject: [PATCH 129/218] Updated documentation to reflect changes. --- src/sparseml/pytorch/models/classification/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3a7a5169447 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,6 +140,10 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): + """ + Wrapper for the FloatFunctional class that enables QATWrapper used to + quantize the first input to the Add operation + """ def __init__(self, num_channels): super().__init__() if FloatFunctional: From bdf80e84e44832854dbbcea28356053065510f11 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:52:15 -0400 Subject: [PATCH 130/218] Fixed default weights data type. 
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index bc9aeb6d58c..b3e47162c5e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -751,7 +751,7 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, From 095f139a7e66e1bbd107ea634898845f9c73c974 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:02:48 -0400 Subject: [PATCH 131/218] Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 54 ++-- .../sparsification/quantization/helpers.py | 247 +++++++++--------- .../quantization/modifier_quantization.py | 44 +++- 3 files changed, 186 insertions(+), 159 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3a7a5169447..cd8b979c3ad 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -144,12 +145,13 @@ class _AddReLU(Module): Wrapper for the FloatFunctional class that enables QATWrapper used to quantize the first input to the Add operation """ + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -209,12 +211,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -325,12 +327,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -441,15 +443,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -483,10 +485,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: 
int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index b3e47162c5e..c2e21d30a16 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,7 +71,7 @@ else None ) -# + class BNWrapper(Module): """ Wraps BatchNormalization module to expose methods needed to enable @@ -77,6 +79,7 @@ class BNWrapper(Module): :param module: BatchNormalization module to be wrapped """ + def __init__(self, module: Module): super().__init__() self.bn = module @@ -213,16 +216,16 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -241,8 +244,10 @@ def from_module( activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. 
:return: QATWrapper object created using the given Module as the forward @@ -277,7 +282,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -315,26 +320,26 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -471,18 +476,18 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -540,29 +545,29 @@ def configure_module_bn_wrappers(module: Module): :param module: module to potentially wrap the submodules of """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if 
type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -589,7 +594,7 @@ def configure_module_qat_wrappers( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -654,7 +659,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -669,9 +674,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -695,7 +700,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -704,15 +709,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + 
symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -739,8 +744,13 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) + activation_observer = get_observer( + symmetric_activations, + activation_dtype, + activation_bits, + reduce_range, + activation_qconfig_kwargs, + ) if symmetric_weights is None: _symmetric_weights = True else: @@ -751,17 +761,23 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer( + _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + ) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) +def get_observer( + symmetric: Optional[bool], + dtype: Optional[torch.dtype], + bits: Optional[int], + reduce_range: bool, + qconfig_kwargs: Dict[str, Any], +): + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, @@ -793,7 +809,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -813,14 +829,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -855,14 +871,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and 
submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -899,17 +915,17 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -936,7 +952,7 @@ def prepare_embeddings_qat( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" if qconfig is None: if symmetric_weights is None: _symmetric_weights = False @@ -960,24 +976,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -2 ** (bits - 1) + quant_min = -(2 ** (bits - 1)) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 4f912b3d8bb..30e1aefbe15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -242,9 +242,15 @@ def model_fuse_fn_name(self) -> Union[str, None]: `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" if self._tensorrt: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + ) else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name + if self._model_fuse_fn_name + else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -365,7 +371,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -382,7 +387,10 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: + if ( + self._weight_qconfig_kwargs is not None + and "observer" in self._weight_qconfig_kwargs + ): kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver @@ -390,8 +398,6 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: else: return self._weight_qconfig_kwargs - - @ModifierProp() def num_calibration_steps(self) -> Optional[int]: """ @@ -532,7 +538,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -553,10 +559,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -591,7 +607,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # prepare each module / submodule for quantization @@ -607,7 +623,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # set quantization config (asymmetric activations, symmetric weights) @@ -631,7 +647,7 @@ def _enable_module_qat(self, module: Module): # if exclude_batchnorm flag is used, add batch norm layers to list of # modules to exclude qconfig if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude if self._exclude_module_types: From afd7430f0400be7f25468b95217c5f33592853a3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:53:05 -0400 Subject: [PATCH 132/218] Removed unused method --- .../sparsification/quantization/helpers.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 
c2e21d30a16..6c30789fbb7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -41,7 +41,6 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", - "get_updated_qconfig_kwargs", "fix_observer_quant_range", "freeze_bn_stats", "fuse_module_conv_bn_relus", @@ -975,36 +974,6 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} - - # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): - raise ValueError( - "Cannot override quant_max and quant_min when number of bits is set" - ) - - if bits: - if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - - qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - - return qconfig_kwargs - - def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() From fb685272b363d7da4b2b7891e9c9d067c137294e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 31 Mar 2022 10:05:55 -0400 Subject: [PATCH 133/218] Removed testing files --- sandbox/quantization_recipe.yaml | 7 ------- sandbox/quantization_test.py | 23 ----------------------- 2 files changed, 30 deletions(-) delete mode 100644 sandbox/quantization_recipe.yaml delete mode 100644 sandbox/quantization_test.py diff --git a/sandbox/quantization_recipe.yaml b/sandbox/quantization_recipe.yaml deleted file mode 100644 index 411dd6f025a..00000000000 --- a/sandbox/quantization_recipe.yaml +++ /dev/null @@ -1,7 +0,0 @@ -quantization_modifiers: - - !QuantizationModifier - start_epoch: -1.0 - model_fuse_fn_name: no_fuse - submodules: - - input - - sections diff --git a/sandbox/quantization_test.py b/sandbox/quantization_test.py deleted file mode 100644 index ea6fba5acd5..00000000000 --- a/sandbox/quantization_test.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -from sparseml.pytorch.utils import ModuleExporter -from sparseml.pytorch.models import ModelRegistry -from sparseml.pytorch.optim import ScheduledModifierManager - -model = ModelRegistry.create( - key='resnet50', - pretrained=False, - pretrained_dataset="imagenet", - num_classes=1000 -) - - -ScheduledModifierManager.from_yaml("quantization_recipe.yaml").apply(model, epoch=float("inf")) - -print(model) - -exporter = ModuleExporter(model, ".") -exporter.export_onnx( - torch.randn(1, 3, 224, 224), - "quantized_test.onnx", - convert_qat=False, -) \ No newline at end of file From 7309134ddb80f5b387386911f24ca5afe1f22045 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 31 Mar 2022 10:12:29 -0400 Subject: [PATCH 134/218] Style and quality fixes. 
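The hunk below only re-indents a conditional inside `fix_observer_quant_range`, which copies a customized quant range from each `FakeQuantize` onto its inner observer. A minimal illustration of that propagation, written outside the helper with a purely illustrative 4-bit range (the `getattr` lookups are a defensive variant of the same check):

    import torch
    from torch import quantization as torch_quantization

    fake_quantize = torch_quantization.FakeQuantize(
        observer=torch_quantization.MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=15,                        # custom 4-bit activation range
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
    )
    observer = fake_quantize.activation_post_process
    # only copy the range when the observer does not have one of its own
    if (
        getattr(observer, "quant_min", None) is None
        or getattr(observer, "quant_max", None) is None
    ):
        observer.quant_min = fake_quantize.quant_min
        observer.quant_max = fake_quantize.quant_max
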
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6c30789fbb7..fa92e5fab46 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -817,9 +817,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min From a675813f77a6f10f49e18be4d3fdd6a40b6b5a9c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 21:01:33 -0400 Subject: [PATCH 135/218] Changed call to get_qat_qconfig to not specify symmetry and data type arguments for default case. --- .../quantization/modifier_quantization.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 30e1aefbe15..210e57df76b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -588,27 +588,25 @@ def _enable_module_qat(self, module: Module): # quantization. 
# otherwise, use the default values set in get_qat_qconfig if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 + qconfig = get_qat_qconfig( + symmetric_activations=True, + symmetric_weights=True, + reduce_range=self._reduce_range, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=torch.qint8, + weight_dtype=torch.qint8, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, + ) else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - - qconfig = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, - ) + qconfig = get_qat_qconfig( + reduce_range=self._reduce_range, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, + ) # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: From 2bbd5280ee7e7b2f30a3b14f0c04fe721cbe4fa4 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 136/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 601 +++++------------- .../quantization/modifier_quantization.py | 231 +++---- 2 files changed, 226 insertions(+), 606 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6c30789fbb7..e10224bbce7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,22 +31,21 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -71,113 +69,6 @@ ) -class BNWrapper(Module): - """ - Wraps BatchNormalization module to expose methods needed to enable - freezing/unfreezing of statistics - - :param module: BatchNormalization module to be wrapped - """ - - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = 
value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -215,40 +106,21 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + :param reduce_range: if True, the quantization range will be reduced by one + bit. This may prevent overflow issues with model execution on certain + hardware. Default is None, will only override qat_wrapper_kwargs if set + to a bool value :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. Default is {} :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. - Default is torch.quint8. - :param weight_dtype: quantized weights data type. - Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. + weights. 
Default is {} :return: QATWrapper object created using the given Module as the forward function. Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -259,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -281,7 +141,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -290,55 +150,23 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - # Remove qconfig from wrapped layer to avoid duplicate quantization - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -358,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = 
symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -475,18 +285,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -516,57 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - Wrap any BatchNormalization modules that are not fused with convolutions - with BNWrapper to enable freezing/unfreezing of BN statistics - - :param module: module to potentially wrap the submodules of - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): - for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, 
BNWrapper(child_module)) - # recurse on child module - configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -575,25 +343,14 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + This may prevent overflow issues with model execution on certain hardware + Default is False :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. Default is {} :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + weights. Default is {} + """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -602,52 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - """ - compute quantization limits depending on data type and number of bits - - :param dtype: data type. If None dtype is set to torch.quint8. - :param bits: number of bits. If None is set to 8. 
- :return: minimum limit, maximum limit, data type - """ - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -658,7 +383,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -673,9 +398,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -699,7 +424,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -708,90 +433,66 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. + UINT8 quantization range with zero point set to 128. Otherwise activations + will use asymmetric quantization with any zero point. Default is False :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. + INT8 quantization range with zero point set to 0. Otherwise activations + will use asymmetric quantization with any zero point. Default is True :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + This may prevent overflow issues with model execution on certain hardware + Default is False :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
- :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. + weights. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer( - symmetric_activations, - activation_dtype, - activation_bits, - reduce_range, - activation_qconfig_kwargs, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer( - _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + activation_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, + reduce_range=reduce_range, ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, ) - - -def get_observer( - symmetric: Optional[bool], - dtype: Optional[torch.dtype], - bits: Optional[int], - reduce_range: bool, - qconfig_kwargs: Dict[str, Any], -): - qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, - ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -808,7 +509,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -817,9 +518,14 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) + or ( # do not propagate default uint8 symmetric range + observer.qscheme == torch.per_tensor_symmetric + and fake_quantize.quant_min == 0 + and fake_quantize.quant_max == 255 + ) ): continue observer.quant_min = fake_quantize.quant_min @@ -827,15 +533,10 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -870,14 +571,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -914,17 +615,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -933,47 +628,57 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. + :param activation_qconfig_kwargs: additional kwargs for quantizing activations. + Default is {}. + :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. + Default is {}. 
:param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False. - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + Default is False + """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for bits + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min when number of bits is set" + ) + + if bits: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + + return qconfig_kwargs + + def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 30e1aefbe15..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -113,26 +113,21 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if True, FakeQuantize ops will be run - for output activations of fully connected layers. Default is False. - :param quantize_conv_activations: if True, FakeQuantize ops will be run - for output activations of convolutional layers. Default is False. + :param quantize_linear_activations: if False, FakeQuantize ops will not be run + for activations of fully connected layers. this is important for quantizing + transformer based models such as BERT where the quantized MatMul outputs + are kept at 32 bits of precision and fake quantizing the outputs harm training + recovery. Default is True :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. - :param weight_bits: Number of bits to use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used - :param exclude_batchnorm: If True, do not propagate quantization qconfigs to - batch-normalization modules + When None, the entire calibration_dataloader is used :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param tenssorrt: if True sets quantization configuration for compatibility with - explict quantization as supported by TensorRT 8.2. + weights. 
""" def __init__( @@ -146,16 +141,15 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_activations: bool = False, - quantize_conv_activations: bool = False, + quantize_linear_output_activations: bool = False, + quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -179,11 +173,11 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_activations = quantize_linear_activations - self._quantize_conv_activations = quantize_conv_activations + self._quantize_linear_output_activations = quantize_linear_output_activations + self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -192,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -237,21 +230,10 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None sets to default function. - If tensorrt flag is True, default is 'no_fuse', otherwise + to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" - ) - else: - fuse_fn = ( - self._model_fuse_fn_name - if self._model_fuse_fn_name - else "conv_bn_relus" - ) - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -280,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -332,20 +314,23 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_activations(self) -> bool: + def quantize_linear_output_activations(self) -> bool: """ - :return: if True, FakeQuantize ops will be run for output activations - of fully connected layers + :return: if False, FakeQuantize ops will not be run + for activations of fully connected layers. 
this is important for quantizing + transformer based models such as BERT where the quantized MatMul outputs + are kept at 32 bits of precision and fake quantizing the outputs harm + training recovery """ - return self._quantize_linear_activations + return self._quantize_linear_output_activations @ModifierProp() - def quantize_conv_activations(self) -> bool: + def quantize_conv_output_activations(self) -> bool: """ - :return: if True, FakeQuantize ops will be run for output activations - of convolutional layers + :return: if False, FakeQuantize ops will not be run + for activations of convolutional layers. """ - return self._quantize_conv_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -367,10 +352,11 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -387,16 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if ( - self._weight_qconfig_kwargs is not None - and "observer" in self._weight_qconfig_kwargs - ): - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -406,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -448,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -533,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -551,105 +522,49 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, 
**self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - # build list of layer types that should not quantize output activations to_remove_layer_name = [] - if not self._quantize_linear_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + if not self._quantize_linear_output_activations: + to_remove_layer_name.extend(["Linear", "LinearReLu"]) - if not self._quantize_conv_activations: + if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - # fix for freezing batchnorm statistics when not fusing BN with convs. - # pytorch only supports freezing batchnorm statistics for fused modules. - # this fix wraps BN modules adding with a new module class that supports - # methods related to freezing/unfreezing BN statistics. - configure_module_bn_wrappers(module) - - # set qconfig. - # if tensorrt flag is used, set activation and weights to symmetric - # quantization. - # otherwise, use the default values set in get_qat_qconfig - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) - - # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) - # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig - # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - - # Remove output quantization from appropriate modules - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, 
to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - # if exclude_batchnorm flag is used, add batch norm layers to list of - # modules to exclude qconfig - if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -658,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -723,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From f7e7374ba4fa86028d6a265e073c6fa4aac9b804 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 137/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. 
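The residual add-plus-ReLU is wrapped in a small module that opts into the
existing QATWrapper machinery instead of carrying its own qconfig. A condensed
sketch of the new block follows (the ReLU fallback used when FloatFunctional is
unavailable is left out here):

    from torch.nn import Module
    from torch.nn.quantized import FloatFunctional

    class _AddReLU(Module):
        def __init__(self):
            super().__init__()
            self.functional = FloatFunctional()
            # configure_module_qat_wrappers() replaces any module that sets
            # wrap_qat = True with a QATWrapper built from qat_wrapper_kwargs:
            # here, fake-quantize the first positional input to the add and
            # attach no fake quantization to the output.
            self.wrap_qat = True
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}

        def forward(self, x, y):
            # fused add + ReLU from torch's FloatFunctional
            return self.functional.add_relu(x, y)

Because num_outputs is 0, no extra FakeQuantize op is attached to the block
output; only the first input to the add operation is quantized.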
--- .../pytorch/models/classification/resnet.py | 66 +++++++++---------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index cd8b979c3ad..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -141,19 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - """ - Wrapper for the FloatFunctional class that enables QATWrapper used to - quantize the first input to the Add operation - """ - - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -185,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -211,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -242,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -327,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -443,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -485,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 8a71887a7e29349eb0dec43ef58bac990225690b Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 138/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. 
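With these defaults the modifier no longer fake-quantizes Linear/Conv output
activations and no longer fuses Conv-BN-ReLU blocks unless asked to. A rough
usage sketch, assuming the class is imported from the module path touched by
this patch and with illustrative flag values (start_epoch mirrors the old
sandbox recipe, not a new default):

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    # Opting back into the previous behaviour now has to be explicit.
    modifier = QuantizationModifier(
        start_epoch=-1.0,
        model_fuse_fn_name="conv_bn_relus",       # fuse Conv-BN-ReLU blocks again
        quantize_linear_output_activations=True,  # FakeQuantize on Linear outputs
        quantize_conv_output_activations=True,    # FakeQuantize on Conv outputs
        exclude_batchnorm=False,                  # keep qconfigs on BatchNorm layers
    )

With every argument left at its default the modifier behaves as
model_fuse_fn_name='no_fuse', leaves Conv/Linear output activations
unquantized, and strips qconfigs from BatchNorm modules.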
--- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e10224bbce7..ec69ded82c8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + 
to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 2e65700bf094d79d7ec23b9eb9204a561b74b3b2 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 139/218] Minor fixes. Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 61 ++++----- .../sparsification/quantization/helpers.py | 123 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 112 insertions(+), 105 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git 
a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ec69ded82c8..2c1ac640d6e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ 
-383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +536,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +573,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +617,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: 
"torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -649,17 +651,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From b100535fffb53749a12d93e684bcba58a4a6e331 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 140/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 209 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++-- 2 files changed, 164 insertions(+), 82 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2c1ac640d6e..a44369550b1 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ 
+ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -535,10 +612,15 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -573,14 +655,14 @@ def 
fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -617,11 +699,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -651,10 +733,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. 
Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 7c9707ba40c5382f9f7bbcc424a6b10afe314386 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 141/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a44369550b1..48ed0708eae 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From 4d8e3ea5199b1dfb42a1236b643bb5a4dc14d019 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 142/218] Set BN fusing back as default. 
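For clarity, a paraphrase of the fuse-function resolution after this change (reviewer sketch only; `resolve_and_fuse`, `model`, and `fuse_fn_name` are hypothetical stand-ins for the modifier's `model_fuse_fn_name` property and `_enable_module_qat`):

    from typing import Optional

    from torch.nn import Module

    from sparseml.pytorch.sparsification.quantization.helpers import (
        fuse_module_conv_bn_relus,
    )


    def resolve_and_fuse(model: Module, fuse_fn_name: Optional[str] = None) -> None:
        # None now resolves to the built-in Conv-BN-ReLU fuser instead of skipping fusing
        fuse_fn_name = fuse_fn_name or "conv_bn_relus"
        if fuse_fn_name == "conv_bn_relus":
            fuse_module_conv_bn_relus(model, inplace=True)
        elif fuse_fn_name != "no_fuse":
            # otherwise fall back to a custom fuse method defined on the model itself
            model_fuse_fn = getattr(model, fuse_fn_name, None)
            if model_fuse_fn is None or not callable(model_fuse_fn):
                raise ValueError(f"Module has no fuse function named {fuse_fn_name}")
            model_fuse_fn()
        # "no_fuse" still skips fusing entirely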
--- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 56d065dac537d94f269ffeb9fa56a966d5b947ac Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 143/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 48ed0708eae..71f6553fc44 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 508c753b5d004b92be8959ee7e65a81503075dbe Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 144/218] Fixed custom freeze_bn_stats. 
--- .../sparsification/quantization/helpers.py | 245 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 182 insertions(+), 109 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 71f6553fc44..e09c0e29690 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def 
from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -613,14 +667,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + 
inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -655,14 +709,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -699,11 +753,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -732,26 +786,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 810c22a3005cc07276785335df006dca027eb2ad Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 145/218] Temporary files for evaluating changes to graphs. 
--- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 9422c873a127dc48afd2cde4fb6886133e02c800 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 146/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
--- .../sparsification/quantization/helpers.py | 207 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 134 insertions(+), 131 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e09c0e29690..57b919470e4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -667,14 +673,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -709,14 +715,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # 
[Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -753,11 +759,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -787,17 +793,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 7d662752fe176d52dba7f18b7d1b1bb4d7e7ec9a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 147/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 57b919470e4..2ae713c16aa 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -761,9 +873,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -781,11 +899,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 15c552bc6cdc64359e9dffd8706842a2febc9f49 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 148/218] Included check to account for when weight_qconfig_kwatgs is None. 
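The new check matters because weight_qconfig_kwargs defaults to None, and evaluating `"observer" in None` raises a TypeError before any recipe kwargs can be applied. A minimal standalone sketch of the None-safe lookup (the helper name resolve_weight_qconfig_kwargs is illustrative; in the source this logic lives in the weight_qconfig_kwargs property of QuantizationModifier):

from typing import Any, Dict, Optional

from torch import quantization as torch_quantization


def resolve_weight_qconfig_kwargs(
    kwargs: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    # guard first: the modifier accepts weight_qconfig_kwargs=None
    if kwargs is not None and "observer" in kwargs:
        kwargs = kwargs.copy()  # never mutate the caller's dict
        if kwargs["observer"] == "minmaxobserver":
            # map the YAML-friendly string onto the torch observer class
            kwargs["observer"] = torch_quantization.MinMaxObserver
    return kwargs

With this guard, resolve_weight_qconfig_kwargs(None) simply returns None, while passing {"observer": "minmaxobserver"} swaps in torch.quantization.MinMaxObserver.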
--- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 34e7a8fad08a644419798b3e2fd6b7649ff1447d Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 149/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 377 +++--------------- .../quantization/modifier_quantization.py | 130 ++---- 2 files changed, 87 insertions(+), 420 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2ae713c16aa..75d11c67c31 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,21 +32,20 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -70,106 +69,6 @@ ) -class BNWrapper(Module): - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return 
self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -208,15 +107,9 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, + reduce_range: bool = None, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -238,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -269,31 +150,6 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -308,15 +164,9 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): super().__init__() @@ -336,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - 
self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -456,15 +288,9 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -494,66 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` - - :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
Default is {} - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): - for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) - # recurse on child module - configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -578,45 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -642,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES + type(module) in QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): @@ -677,15 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -706,44 +458,41 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - - -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + activation_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, + ) + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, + reduce_range=reduce_range, ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -769,14 +518,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) - or ( # do not propagate default uint8 symmetric range - observer.qscheme == torch.per_tensor_symmetric - and fake_quantize.quant_min == 0 - and fake_quantize.quant_max == 255 - ) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -784,11 +528,6 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( module: Module, inplace: bool = True, @@ -873,15 +612,9 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - 
weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -899,28 +632,18 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): qconfig_kwargs = ( qconfig_kwargs.copy() if qconfig_kwargs @@ -937,15 +660,9 @@ def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): ) if bits: - if mode == "symmetric": - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,14 +143,13 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -176,9 +175,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -187,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -235,11 +233,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = 'no_fuse' - else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -268,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -336,7 +330,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_conv_output_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -379,15 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs - - + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -397,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -439,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -524,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -542,10 +522,16 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + to_remove_layer_name.extend(["Linear", "LinearReLu"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -554,47 +540,20 @@ def _enable_module_qat(self, module: Module): "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - configure_module_bn_wrappers(module) # prepare each module / submodule for quantization - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - qconfig = get_qat_qconfig( - 
symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -603,18 +562,9 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -623,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -688,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From 29af80d60336ffabefd0f5908bf89c9b364e6c9a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 150/218] Added _Add_ReLU module that enables QATWrapper for quantization. 
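The wrapper opt-in works through two attributes that configure_module_qat_wrappers looks for on a submodule: wrap_qat and qat_wrapper_kwargs. A rough sketch of the pattern, assuming torch.nn.quantized.FloatFunctional is importable as resnet.py does; the class name AddReLUSketch and the type hints are illustrative, not the exact _AddReLU code:

import torch
from torch.nn import Module
from torch.nn.quantized import FloatFunctional


class AddReLUSketch(Module):
    def __init__(self):
        super().__init__()
        # float-mode container for the fused add + ReLU of the residual sum
        self.functional = FloatFunctional()
        # opt-in flag picked up by configure_module_qat_wrappers
        self.wrap_qat = True
        # request one input QuantStub and no output quantization from QATWrapper
        self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.functional.add_relu(x, y)

With num_inputs=1 and num_outputs=0 the resulting QATWrapper fake-quantizes an input to the residual add while leaving the block output unquantized.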
--- src/sparseml/pytorch/models/classification/resnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,14 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -236,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() From 5d4bebacf2cdc625702de260623511414fbcfaac Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 151/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 75d11c67c31..f28656f1712 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ 
class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 318fe315e4fb49d966425f5f701627a31fd7d58e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 152/218] Minor fixes. Style and quality fixes. 
--- .../pytorch/models/classification/resnet.py | 61 +++++---- .../sparsification/quantization/helpers.py | 129 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 115 insertions(+), 108 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index f28656f1712..ef4445a0d5f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as 
torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and 
hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -518,9 +520,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -529,9 +531,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -566,14 +568,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -610,11 +612,11 @@ def 
fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -644,17 +646,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From b8e563dc425590529962b67fd395f2bab30aefc6 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 153/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 215 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++- 2 files changed, 167 insertions(+), 85 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ef4445a0d5f..c4f165d23ef 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + 
"torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -520,9 +597,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -530,10 +607,15 @@ def 
fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -568,14 +650,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -612,11 +694,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -646,10 +728,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From b911a75680d2b6327e4a3c8a7f250b4fe1fbfea1 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 154/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index c4f165d23ef..64958570e2d 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if 
not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From e39e90cc3b921d49e4253a801834e39a13c1503d Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 155/218] Set BN fusing back as default. --- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 24363c61c62bc087d77ce0c0b0b5a8898fee63c3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 156/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 64958570e2d..a43d69d947b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 242e29b450299b3b834e9c5a8f8e8ea4e34902c9 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 157/218] Fixed custom freeze_bn_stats. --- .../sparsification/quantization/helpers.py | 251 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 185 insertions(+), 112 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a43d69d947b..6110a499b70 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + 
self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -597,9 +651,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -608,14 +662,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -650,14 +704,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -694,11 +748,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -727,26 +781,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) 
-> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) 
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From 3759f59c46a99e29425d480e82f58420e797e150 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 158/218] Temporary files for evaluating changes to graphs. --- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 0b677f4e208c96a912f3140c112b6aff16fb7d9f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 159/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
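The commit below (its subject's "get_qat_config_config" refers to get_qat_qconfig in the diff) derives quant_min/quant_max from the observer dtype and bit width instead of hard-coding 0..255 for activations and -128..127 for weights. The arithmetic it introduces is, in isolation — note the explicit error branch is added here only for illustration; the compute_range helper in the diff leaves the bounds unset for any other dtype:

import torch


def compute_range(dtype: torch.dtype, bits: int):
    # signed range for qint8-style observers, unsigned range for quint8
    if dtype == torch.qint8:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    if dtype == torch.quint8:
        return 0, 2 ** bits - 1
    raise ValueError(f"unsupported quantized dtype: {dtype}")


assert compute_range(torch.quint8, 8) == (0, 255)    # previously hard-coded activation range
assert compute_range(torch.qint8, 8) == (-128, 127)  # previously hard-coded weight range
assert compute_range(torch.qint8, 4) == (-8, 7)      # e.g. 4-bit weights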
--- .../sparsification/quantization/helpers.py | 213 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 137 insertions(+), 134 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6110a499b70..8ae045de9e8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -651,9 +657,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -662,14 +668,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -704,14 +710,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): 
submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -748,11 +754,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -782,17 +788,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From df5844db2dddc76f748403d7f6dec039ac528658 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 160/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 8ae045de9e8..027c7514c32 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -756,9 +868,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -776,11 +894,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 3a01361a51a0f9a2f6ab32fa4ab008d2f1bb7d11 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 161/218] Included check to account for when weight_qconfig_kwatgs is None. --- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 3ea4268f45a0f5c08089694b4cd71cd42a85dcfb Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 14:20:19 -0400 Subject: [PATCH 162/218] Modified argument names for backwards compatibility. 
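After this rename, recipes and code that construct the modifier with the pre-existing quantize_linear_activations / quantize_conv_activations names work again. A hedged usage sketch (argument values are illustrative only; start_epoch is assumed to keep its usual ScheduledModifier meaning and is not part of this diff):

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    modifier = QuantizationModifier(
        start_epoch=0.0,
        # renamed back from quantize_linear_output_activations /
        # quantize_conv_output_activations
        quantize_linear_activations=False,
        quantize_conv_activations=False,
        exclude_batchnorm=True,
        tensorrt=False,
    )
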
--- .../quantization/modifier_quantization.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..73a50e0f9c4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -141,8 +141,8 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_output_activations: bool = False, - quantize_conv_output_activations: bool = False, + quantize_linear_activations: bool = False, + quantize_conv_activations: bool = False, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, @@ -174,8 +174,8 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_output_activations = quantize_linear_output_activations - self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_linear_activations = quantize_linear_activations + self._quantize_conv_activations = quantize_conv_activations self._activation_bits = activation_bits self._weight_bits = weight_bits self._exclude_batchnorm = exclude_batchnorm @@ -320,7 +320,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_output_activations(self) -> bool: + def quantize_linear_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -328,15 +328,15 @@ def quantize_linear_output_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_output_activations + return self._quantize_linear_activations @ModifierProp() - def quantize_conv_output_activations(self) -> bool: + def quantize_conv_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_conv_output_activations + return self._quantize_conv_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -544,10 +544,10 @@ def _enable_module_qat(self, module: Module): module_fuse_fn(**self._model_fuse_fn_kwargs) to_remove_layer_name = [] - if not self._quantize_linear_output_activations: + if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) - if not self._quantize_conv_output_activations: + if not self._quantize_conv_activations: to_remove_layer_name.extend( ["Conv1d", "Conv2d", "Conv3d", "ConvBn1d", "ConvBn2d", "ConvBn3d", From cdf316e6ace50e358f8dd4f763b9ad6a75c53e7c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:51 -0400 Subject: [PATCH 163/218] Updated documentation to reflect changes. 
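As the updated docstrings below describe, the defaults remain asymmetric torch.quint8 activations with symmetric torch.qint8 weights, while the tensorrt path selects symmetric signed 8-bit quantization for both. A minimal sketch of the two configurations, assuming the helpers from this patch series are installed:

    import torch
    from sparseml.pytorch.sparsification.quantization.helpers import get_qat_qconfig

    # default QAT config: asymmetric quint8 activations, symmetric qint8 weights
    default_qconfig = get_qat_qconfig()

    # TensorRT-style explicit quantization: symmetric, signed 8-bit for both
    # activations and weights (what the modifier selects when tensorrt=True)
    trt_qconfig = get_qat_qconfig(
        symmetric_activations=True,
        symmetric_weights=True,
        activation_dtype=torch.qint8,
        weight_dtype=torch.qint8,
        activation_bits=8,
        weight_bits=8,
    )
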
--- .../sparsification/quantization/helpers.py | 118 ++++++++++++------ 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 027c7514c32..bc9aeb6d58c 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -69,8 +69,14 @@ else None ) - +# class BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ def __init__(self, module: Module): super().__init__() self.bn = module @@ -220,14 +226,25 @@ def from_module( ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param reduce_range: if True, the quantization range will be reduced by one - bit. This may prevent overflow issues with model execution on certain - hardware. Default is None, will only override qat_wrapper_kwargs if set - to a bool value + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: QATWrapper object created using the given Module as the forward function. Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -293,6 +310,7 @@ def from_module( else weight_bits or qat_wrapper_kwargs["weight_bits"] ) + # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -516,19 +534,10 @@ def _load_qconfigs( def configure_module_bn_wrappers(module: Module): """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` + Wrap any BatchNormalization modules that are not fused with convolutions + with BNWrapper to enable freezing/unfreezing of BN statistics :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. 
Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required if not hasattr(module, 'freeze_bn_stats'): @@ -562,14 +571,25 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} - """ + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -605,6 +625,13 @@ def configure_module_qat_wrappers( def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. If None dtype is set to torch.quint8. + :param bits: number of bits. If None is set to 8. + :return: minimum limit, maximum limit, data type + """ dtype = dtype if dtype else torch.quint8 bits = bits if bits else 8 if dtype == torch.qint8: @@ -689,18 +716,24 @@ def get_qat_qconfig( ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - UINT8 quantization range with zero point set to 128. Otherwise activations - will use asymmetric quantization with any zero point. Default is False + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. :param symmetric_weights: if True, weights will have a symmetric - INT8 quantization range with zero point set to 0. Otherwise activations - will use asymmetric quantization with any zero point. Default is True + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. 
+ :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer @@ -885,14 +918,25 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param activation_qconfig_kwargs: additional kwargs for quantizing activations. - Default is {}. - :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. - Default is {}. + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False - """ + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ if qconfig is None: if symmetric_weights is None: _symmetric_weights = False From f36cc4d258116961a7088118bb586f7aedddcf54 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:57 -0400 Subject: [PATCH 164/218] Updated documentation to reflect changes. --- .../quantization/modifier_quantization.py | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 73a50e0f9c4..4f912b3d8bb 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -113,21 +113,26 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm training - recovery. Default is True + :param quantize_linear_activations: if True, FakeQuantize ops will be run + for output activations of fully connected layers. Default is False. 
+ :param quantize_conv_activations: if True, FakeQuantize ops will be run + for output activations of convolutional layers. Default is False. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. + :param weight_bits: Number of bits to use for setting quant min/max values for + weights. Default is None, which will quantize weights to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used + When None, the entire calibration_dataloader is used + :param exclude_batchnorm: If True, do not propagate quantization qconfigs to + batch-normalization modules :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. + :param tenssorrt: if True sets quantization configuration for compatibility with + explict quantization as supported by TensorRT 8.2. """ def __init__( @@ -232,11 +237,12 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None to uses the default function + to performing QAT. None sets to default function. + If tensorrt flag is True, default is 'no_fuse', otherwise `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ if self._tensorrt: - fuse_fn = 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' else: fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @@ -322,19 +328,16 @@ def reduce_range(self) -> bool: @ModifierProp() def quantize_linear_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm - training recovery + :return: if True, FakeQuantize ops will be run for output activations + of fully connected layers """ return self._quantize_linear_activations @ModifierProp() def quantize_conv_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of convolutional layers. + :return: if True, FakeQuantize ops will be run for output activations + of convolutional layers """ return self._quantize_conv_activations @@ -358,7 +361,7 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + weights. Default is None, which will quantize weights to 8 bits. 
""" return self._weight_bits @@ -543,6 +546,7 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) + # build list of layer types that should not quantize output activations to_remove_layer_name = [] if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) @@ -557,9 +561,16 @@ def _enable_module_qat(self, module: Module): if len(to_remove_layer_name) == 0: to_remove_layer_name = None + # fix for freezing batchnorm statistics when not fusing BN with convs. + # pytorch only supports freezing batchnorm statistics for fused modules. + # this fix wraps BN modules adding with a new module class that supports + # methods related to freezing/unfreezing BN statistics. configure_module_bn_wrappers(module) - # prepare each module / submodule for quantization + # set qconfig. + # if tensorrt flag is used, set activation and weights to symmetric + # quantization. + # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: _symmetric_activations = True _activation_dtype = torch.qint8 @@ -582,6 +593,8 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + + # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( @@ -596,13 +609,17 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig + # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) + + # Remove output quantization from appropriate modules if to_remove_layer_name: remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) @@ -611,6 +628,8 @@ def _enable_module_qat(self, module: Module): if self._exclude_module_types: to_exclude.extend(self._exclude_module_types) + # if exclude_batchnorm flag is used, add batch norm layers to list of + # modules to exclude qconfig if self._exclude_batchnorm: to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) From 416e938fb136e138e7e9342f7f3be13ded163eea Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:42:27 -0400 Subject: [PATCH 165/218] Updated documentation to reflect changes. --- src/sparseml/pytorch/models/classification/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3a7a5169447 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,6 +140,10 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): + """ + Wrapper for the FloatFunctional class that enables QATWrapper used to + quantize the first input to the Add operation + """ def __init__(self, num_channels): super().__init__() if FloatFunctional: From f56d3304d1d508c8dff09e25b0de4dbef6a7237d Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:52:15 -0400 Subject: [PATCH 166/218] Fixed default weights data type. 
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index bc9aeb6d58c..b3e47162c5e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -751,7 +751,7 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, From 6a3fa1d1690a3035ed415f0414fa3df106cf19e5 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:02:48 -0400 Subject: [PATCH 167/218] Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 54 ++-- .../sparsification/quantization/helpers.py | 247 +++++++++--------- .../quantization/modifier_quantization.py | 44 +++- 3 files changed, 186 insertions(+), 159 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3a7a5169447..cd8b979c3ad 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -144,12 +145,13 @@ class _AddReLU(Module): Wrapper for the FloatFunctional class that enables QATWrapper used to quantize the first input to the Add operation """ + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -209,12 +211,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -325,12 +327,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -441,15 +443,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -483,10 +485,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: 
int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index b3e47162c5e..c2e21d30a16 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,7 +71,7 @@ else None ) -# + class BNWrapper(Module): """ Wraps BatchNormalization module to expose methods needed to enable @@ -77,6 +79,7 @@ class BNWrapper(Module): :param module: BatchNormalization module to be wrapped """ + def __init__(self, module: Module): super().__init__() self.bn = module @@ -213,16 +216,16 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -241,8 +244,10 @@ def from_module( activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. 
:return: QATWrapper object created using the given Module as the forward @@ -277,7 +282,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -315,26 +320,26 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -471,18 +476,18 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -540,29 +545,29 @@ def configure_module_bn_wrappers(module: Module): :param module: module to potentially wrap the submodules of """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if 
type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -589,7 +594,7 @@ def configure_module_qat_wrappers( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -654,7 +659,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -669,9 +674,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -695,7 +700,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -704,15 +709,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + 
symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -739,8 +744,13 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) + activation_observer = get_observer( + symmetric_activations, + activation_dtype, + activation_bits, + reduce_range, + activation_qconfig_kwargs, + ) if symmetric_weights is None: _symmetric_weights = True else: @@ -751,17 +761,23 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer( + _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + ) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) +def get_observer( + symmetric: Optional[bool], + dtype: Optional[torch.dtype], + bits: Optional[int], + reduce_range: bool, + qconfig_kwargs: Dict[str, Any], +): + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, @@ -793,7 +809,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -813,14 +829,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -855,14 +871,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and 
submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -899,17 +915,17 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -936,7 +952,7 @@ def prepare_embeddings_qat( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" if qconfig is None: if symmetric_weights is None: _symmetric_weights = False @@ -960,24 +976,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -2 ** (bits - 1) + quant_min = -(2 ** (bits - 1)) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 4f912b3d8bb..30e1aefbe15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -242,9 +242,15 @@ def model_fuse_fn_name(self) -> Union[str, None]: `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" if self._tensorrt: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + ) else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name + if self._model_fuse_fn_name + else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -365,7 +371,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -382,7 +387,10 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: + if ( + self._weight_qconfig_kwargs is not None + and "observer" in self._weight_qconfig_kwargs + ): kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver @@ -390,8 +398,6 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: else: return self._weight_qconfig_kwargs - - @ModifierProp() def num_calibration_steps(self) -> Optional[int]: """ @@ -532,7 +538,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -553,10 +559,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -591,7 +607,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # prepare each module / submodule for quantization @@ -607,7 +623,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # set quantization config (asymmetric activations, symmetric weights) @@ -631,7 +647,7 @@ def _enable_module_qat(self, module: Module): # if exclude_batchnorm flag is used, add batch norm layers to list of # modules to exclude qconfig if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude if self._exclude_module_types: From 011fce92a7261cd45df08a1ba13562a118b90d8f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:53:05 -0400 Subject: [PATCH 168/218] Removed unused method --- .../sparsification/quantization/helpers.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 
c2e21d30a16..6c30789fbb7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -41,7 +41,6 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", - "get_updated_qconfig_kwargs", "fix_observer_quant_range", "freeze_bn_stats", "fuse_module_conv_bn_relus", @@ -975,36 +974,6 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} - - # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): - raise ValueError( - "Cannot override quant_max and quant_min when number of bits is set" - ) - - if bits: - if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - - qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - - return qconfig_kwargs - - def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() From 30e61f3806414c2435f87f5e235ba41a4b88c4c3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 169/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 601 +++++------------- .../quantization/modifier_quantization.py | 231 +++---- 2 files changed, 226 insertions(+), 606 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6c30789fbb7..e10224bbce7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,22 +31,21 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", + "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -71,113 +69,6 @@ ) -class BNWrapper(Module): - """ - Wraps BatchNormalization module to expose methods needed to enable - freezing/unfreezing of statistics - - :param module: BatchNormalization module to be wrapped - """ - - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return 
self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -215,40 +106,21 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + :param reduce_range: if True, the quantization range will be reduced by one + bit. This may prevent overflow issues with model execution on certain + hardware. Default is None, will only override qat_wrapper_kwargs if set + to a bool value :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. Default is {} :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. - Default is torch.quint8. - :param weight_dtype: quantized weights data type. - Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. + weights. Default is {} :return: QATWrapper object created using the given Module as the forward function. 
Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -259,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -281,7 +141,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -290,55 +150,23 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - # Remove qconfig from wrapped layer to avoid duplicate quantization - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -358,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = 
activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -475,18 +285,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -516,57 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - Wrap any BatchNormalization modules that are not fused with convolutions - with BNWrapper to enable freezing/unfreezing of BN statistics - - :param module: module to potentially wrap the submodules of - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): - for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) - # recurse on child module - 
configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -575,25 +343,14 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + This may prevent overflow issues with model execution on certain hardware + Default is False :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. Default is {} :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + weights. Default is {} + """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -602,52 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - """ - compute quantization limits depending on data type and number of bits - - :param dtype: data type. If None dtype is set to torch.quint8. - :param bits: number of bits. If None is set to 8. 
- :return: minimum limit, maximum limit, data type - """ - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -658,7 +383,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -673,9 +398,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -699,7 +424,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -708,90 +433,66 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. + UINT8 quantization range with zero point set to 128. Otherwise activations + will use asymmetric quantization with any zero point. Default is False :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. + INT8 quantization range with zero point set to 0. Otherwise activations + will use asymmetric quantization with any zero point. Default is True :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. + This may prevent overflow issues with model execution on certain hardware + Default is False :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
- :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. + weights. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer( - symmetric_activations, - activation_dtype, - activation_bits, - reduce_range, - activation_qconfig_kwargs, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer( - _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + activation_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, + reduce_range=reduce_range, ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, ) - - -def get_observer( - symmetric: Optional[bool], - dtype: Optional[torch.dtype], - bits: Optional[int], - reduce_range: bool, - qconfig_kwargs: Dict[str, Any], -): - qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, - ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -808,7 +509,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -817,9 +518,14 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) + or ( # do not propagate default uint8 symmetric range + observer.qscheme == torch.per_tensor_symmetric + and fake_quantize.quant_min == 0 + and fake_quantize.quant_max == 255 + ) ): continue observer.quant_min = fake_quantize.quant_min @@ -827,15 +533,10 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -870,14 +571,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -914,17 +615,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -933,47 +628,57 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. + :param activation_qconfig_kwargs: additional kwargs for quantizing activations. + Default is {}. + :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. + Default is {}. 
:param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False. - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + Default is False + """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) + + # update qconfig_kwargs for bits + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): + raise ValueError( + "Cannot override quant_max and quant_min when number of bits is set" + ) + + if bits: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( + dict( + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + ) + ) + + return qconfig_kwargs + + def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 30e1aefbe15..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -113,26 +113,21 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if True, FakeQuantize ops will be run - for output activations of fully connected layers. Default is False. - :param quantize_conv_activations: if True, FakeQuantize ops will be run - for output activations of convolutional layers. Default is False. + :param quantize_linear_activations: if False, FakeQuantize ops will not be run + for activations of fully connected layers. this is important for quantizing + transformer based models such as BERT where the quantized MatMul outputs + are kept at 32 bits of precision and fake quantizing the outputs harm training + recovery. Default is True :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. - :param weight_bits: Number of bits to use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used - :param exclude_batchnorm: If True, do not propagate quantization qconfigs to - batch-normalization modules + When None, the entire calibration_dataloader is used :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param tenssorrt: if True sets quantization configuration for compatibility with - explict quantization as supported by TensorRT 8.2. + weights. 
""" def __init__( @@ -146,16 +141,15 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_activations: bool = False, - quantize_conv_activations: bool = False, + quantize_linear_output_activations: bool = False, + quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -179,11 +173,11 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_activations = quantize_linear_activations - self._quantize_conv_activations = quantize_conv_activations + self._quantize_linear_output_activations = quantize_linear_output_activations + self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -192,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -237,21 +230,10 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None sets to default function. - If tensorrt flag is True, default is 'no_fuse', otherwise + to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" - ) - else: - fuse_fn = ( - self._model_fuse_fn_name - if self._model_fuse_fn_name - else "conv_bn_relus" - ) - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -280,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -332,20 +314,23 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_activations(self) -> bool: + def quantize_linear_output_activations(self) -> bool: """ - :return: if True, FakeQuantize ops will be run for output activations - of fully connected layers + :return: if False, FakeQuantize ops will not be run + for activations of fully connected layers. 
this is important for quantizing + transformer based models such as BERT where the quantized MatMul outputs + are kept at 32 bits of precision and fake quantizing the outputs harm + training recovery """ - return self._quantize_linear_activations + return self._quantize_linear_output_activations @ModifierProp() - def quantize_conv_activations(self) -> bool: + def quantize_conv_output_activations(self) -> bool: """ - :return: if True, FakeQuantize ops will be run for output activations - of convolutional layers + :return: if False, FakeQuantize ops will not be run + for activations of convolutional layers. """ - return self._quantize_conv_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -367,10 +352,11 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -387,16 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if ( - self._weight_qconfig_kwargs is not None - and "observer" in self._weight_qconfig_kwargs - ): - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -406,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -448,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -533,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -551,105 +522,49 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, 
**self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - # build list of layer types that should not quantize output activations to_remove_layer_name = [] - if not self._quantize_linear_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + if not self._quantize_linear_output_activations: + to_remove_layer_name.extend(["Linear", "LinearReLu"]) - if not self._quantize_conv_activations: + if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - # fix for freezing batchnorm statistics when not fusing BN with convs. - # pytorch only supports freezing batchnorm statistics for fused modules. - # this fix wraps BN modules adding with a new module class that supports - # methods related to freezing/unfreezing BN statistics. - configure_module_bn_wrappers(module) - - # set qconfig. - # if tensorrt flag is used, set activation and weights to symmetric - # quantization. - # otherwise, use the default values set in get_qat_qconfig - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) - - # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) - # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig - # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - - # Remove output quantization from appropriate modules - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, 
to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - # if exclude_batchnorm flag is used, add batch norm layers to list of - # modules to exclude qconfig - if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -658,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -723,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From 757fad6aea9feac5bcfccc148da78bd857f5c2f7 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 170/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. 
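The wrap_qat convention this module relies on can be illustrated with a small standalone sketch; the ToyAddReLU name and tensor shapes below are made up for illustration and are not part of this diff:

    import torch
    from torch.nn.quantized import FloatFunctional

    class ToyAddReLU(torch.nn.Module):
        # same pattern as the new _AddReLU: advertise the module for
        # QATWrapper wrapping and request one quantized input stub and
        # no output stub
        def __init__(self):
            super().__init__()
            self.functional = FloatFunctional()
            self.wrap_qat = True
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}

        def forward(self, x, y):
            # fused add + ReLU; FloatFunctional lets torch.quantization
            # attach observers to this op during QAT
            return self.functional.add_relu(x, y)

    out = ToyAddReLU()(torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4))

During QAT setup, configure_module_qat_wrappers replaces any submodule with wrap_qat == True by a QATWrapper built from these kwargs, which is what fake-quantizes the first input to the residual add.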
--- .../pytorch/models/classification/resnet.py | 66 +++++++++---------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index cd8b979c3ad..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -141,19 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - """ - Wrapper for the FloatFunctional class that enables QATWrapper used to - quantize the first input to the Add operation - """ - - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -185,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -211,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -242,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -327,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -443,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -485,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 398ae3ed75451a0c0fc8d937a8654bb98e916042 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 171/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. 
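The effect of the new defaults can be sketched outside the modifier; the toy model below is illustrative, and a real run goes through QuantizationModifier rather than calling the helpers directly:

    import torch
    from torch import quantization as torch_quantization
    from sparseml.pytorch.sparsification.quantization.helpers import (
        get_qat_qconfig,
        remove_activation_qat_by_layer_name,
    )

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Linear(8, 4))
    model.qconfig = get_qat_qconfig()
    torch_quantization.propagate_qconfig_(model)

    # quantize_linear_output_activations / quantize_conv_output_activations
    # now default to False, which amounts to stripping the output observers
    # from these module types before prepare_qat
    remove_activation_qat_by_layer_name(model, ["Conv2d", "Linear"])

Weights are still fake-quantized; only the layer outputs stay at full precision.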
--- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e10224bbce7..ec69ded82c8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + 
to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From 4da2e8f2e38dbd4ad48737a529b85146ea5aeae3 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 172/218] Minor fixes. Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 61 ++++----- .../sparsification/quantization/helpers.py | 123 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 112 insertions(+), 105 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git 
a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ec69ded82c8..2c1ac640d6e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ 
-383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -534,9 +536,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -571,14 +573,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -615,11 +617,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: 
"torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -649,17 +651,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From a9f2b40ac6c9b02bc74794a801ce8d87ca66f86e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 173/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 209 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++-- 2 files changed, 164 insertions(+), 82 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2c1ac640d6e..a44369550b1 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ 
+ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -535,10 +612,15 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -573,14 +655,14 @@ def 
fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -617,11 +699,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -651,10 +733,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. 
Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From bbbc1b3fbe4649708dce268ddc8b72d0604404dc Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 174/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a44369550b1..48ed0708eae 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From 2817614a2bf391dfba89c4fdec5ec52366d7818e Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 175/218] Set BN fusing back as default. 
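With this change an unset model_fuse_fn_name resolves to 'conv_bn_relus' again, so 'no_fuse' has to be requested explicitly. A rough sketch of the resulting behavior (the start_epoch value is arbitrary and assumed to be a valid constructor argument):

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    default_fuse = QuantizationModifier(start_epoch=0.0)
    no_fuse = QuantizationModifier(start_epoch=0.0, model_fuse_fn_name="no_fuse")

    # default (None) maps back to the built-in Conv-BN-ReLU fusing helper
    assert default_fuse.model_fuse_fn_name == "conv_bn_relus"
    assert no_fuse.model_fuse_fn_name == "no_fuse"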
--- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 2fb23c049b5e6c916c577933a331f6a1b0e7e4a1 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 176/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 48ed0708eae..71f6553fc44 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 572bfbcc16f086d7d151654c587422774cae38ea Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 177/218] Fixed custom freeze_bn_stats. 
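Besides delegating the standard BatchNorm attributes through BNWrapper, this also splits the bit-width handling into symmetric (weight) and asymmetric (activation) ranges. The arithmetic implied by the helper can be sanity-checked in isolation, e.g. for 4 bits (values shown only to illustrate the ranges; dtype handling omitted):

    from sparseml.pytorch.sparsification.quantization.helpers import (
        get_updated_qconfig_kwargs,
    )

    # asymmetric (activations): unsigned range [0, 2**bits - 1]
    act = get_updated_qconfig_kwargs(None, 4, "asymmetric")
    assert (act["quant_min"], act["quant_max"]) == (0, 15)

    # symmetric (weights): signed range [-2**(bits - 1), 2**(bits - 1) - 1]
    wt = get_updated_qconfig_kwargs(None, 4, "symmetric")
    assert (wt["quant_min"], wt["quant_max"]) == (-8, 7)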
--- .../sparsification/quantization/helpers.py | 245 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 182 insertions(+), 109 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 71f6553fc44..e09c0e29690 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def 
from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -613,14 +667,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + 
inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -655,14 +709,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -699,11 +753,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -732,26 +786,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From c3975234f738b54ff6df66e3ddc35f11b0ede30a Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 178/218] Temporary files for evaluating changes to graphs. 
--- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From ebfcae8ad55c0956e67459dee08c0a5d173c6dfa Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 179/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
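Note on the resnet change above (not part of the patch): the block opts into QAT wrapping through the wrap_qat hook rather than being quantized directly; configure_module_qat_wrappers() replaces any submodule with wrap_qat == True by QATWrapper.from_module(), forwarding its qat_wrapper_kwargs. A minimal sketch of such an opt-in module, with an illustrative class name and fallback:

import torch
from torch.nn import Module

try:
    from torch.nn.quantized import FloatFunctional
except Exception:
    FloatFunctional = None


class SkipAddReLU(Module):
    # illustrative example of the wrap_qat opt-in used by the resnet change above
    def __init__(self):
        super().__init__()
        if FloatFunctional:
            self.functional = FloatFunctional()
            # picked up by configure_module_qat_wrappers -> QATWrapper.from_module
            self.wrap_qat = True
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}
        else:
            self.functional = None

    def forward(self, inp, out):
        if self.functional is not None:
            # fused add + ReLU that stays traceable for quantization
            return self.functional.add_relu(inp, out)
        return torch.nn.functional.relu(inp + out)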
--- .../sparsification/quantization/helpers.py | 207 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 134 insertions(+), 131 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index e09c0e29690..57b919470e4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -667,14 +673,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -709,14 +715,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # 
[Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -753,11 +759,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -787,17 +793,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From 1edd77da4021e4276a4d39fe237ac840e4581515 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 180/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 57b919470e4..2ae713c16aa 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -761,9 +873,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -781,11 +899,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From b26df339632b52f019820c5a574567375e5bc618 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 181/218] Included check to account for when weight_qconfig_kwatgs is None. 
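Note (commentary, not part of the patch): patch 180 above routes observer construction through get_qat_qconfig/get_observer and, when the tensorrt flag is set, forces symmetric signed ranges for both activations and weights. For orientation, that branch reduces to the following hand-built config using only the stock torch.quantization API; this is a sketch of the resulting observer settings, not the helper itself, and the function name is illustrative:

import torch
from torch import quantization as torch_quantization


def tensorrt_style_qconfig(bits: int = 8) -> "torch_quantization.QConfig":
    # symmetric, signed range on both sides: [-128, 127] for 8 bits, [-8, 7] for 4 bits;
    # the non-tensorrt path instead keeps quint8 with [0, 2**bits - 1] for activations
    quant_min = -(2 ** (bits - 1))
    quant_max = 2 ** (bits - 1) - 1
    fake_quantize = torch_quantization.FakeQuantize.with_args(
        observer=torch_quantization.MovingAverageMinMaxObserver,
        quant_min=quant_min,
        quant_max=quant_max,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=False,
    )
    return torch_quantization.QConfig(activation=fake_quantize, weight=fake_quantize)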
--- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From 2f85a4bdd8eea7bdbb2a76da928242575ce60014 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 12:14:39 -0500 Subject: [PATCH 182/218] Removed output quantization from conv layers --- .../sparsification/quantization/helpers.py | 377 +++--------------- .../quantization/modifier_quantization.py | 130 ++---- 2 files changed, 87 insertions(+), 420 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 2ae713c16aa..75d11c67c31 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,21 +32,20 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ + "QUANTIZABLE_MODULE_TYPES", "QATWrapper", - "configure_module_bn_wrappers", - "configure_module_default_qconfigs", "configure_module_qat_wrappers", + "configure_module_default_qconfigs", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", - "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] -_QUANTIZABLE_MODULE_TYPES = ( +QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -70,106 +69,6 @@ ) -class BNWrapper(Module): - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return 
self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -208,15 +107,9 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, + reduce_range: bool = None, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -238,18 +131,6 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -269,31 +150,6 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - - module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -308,15 +164,9 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): super().__init__() @@ -336,43 +186,25 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - 
self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -456,15 +288,9 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -494,66 +320,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, + symmetric_activations=(qconfig == "symmetric"), reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) return qconfigs -def configure_module_bn_wrappers(module: Module): - """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` - - :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. 
Default is {} - """ - # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): - for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) - # recurse on child module - configure_module_bn_wrappers(child_module) - - def configure_module_qat_wrappers( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -578,45 +359,20 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): - dtype = dtype if dtype else torch.quint8 - bits = bits if bits else 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2 ** bits) - 1 - - return quant_min, quant_max, dtype - - def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -642,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES + type(module) in QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): @@ -677,15 +433,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -706,44 +458,41 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype - - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + activation_qscheme = ( + torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine ) - - -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) - quant_min, quant_max, dtype = compute_range(dtype, bits) - observer_kwargs = dict( + activation_observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - qscheme=qscheme, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=activation_qscheme, reduce_range=reduce_range, ) - observer_kwargs.update(qconfig_kwargs or {}) - observer = torch_quantization.FakeQuantize.with_args( - **observer_kwargs, + activation_observer_kwargs.update(activation_qconfig_kwargs or {}) + activation_observer = torch_quantization.FakeQuantize.with_args( + **activation_observer_kwargs, + ) + weight_qscheme = ( + torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + ) + weight_observer_kwargs = dict( + observer=torch_quantization.MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=weight_qscheme, + reduce_range=reduce_range, ) - return observer + weight_observer_kwargs.update(weight_qconfig_kwargs or {}) + weight_observer = torch_quantization.FakeQuantize.with_args( + **weight_observer_kwargs, + ) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) def fix_observer_quant_range(module: Module): @@ -769,14 +518,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) - or ( # do not propagate default uint8 symmetric range - observer.qscheme == torch.per_tensor_symmetric - and fake_quantize.quant_min == 0 - and fake_quantize.quant_max == 255 - ) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -784,11 +528,6 @@ def fix_observer_quant_range(module: Module): observer.has_customized_qrange = True -def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): - module.freeze_bn_stats() - - def fuse_module_conv_bn_relus( module: Module, inplace: bool = True, @@ -873,15 +612,9 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - 
weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -899,28 +632,18 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, + symmetric_weights=False, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): +def get_updated_qconfig_kwargs(qconfig_kwargs, bits): qconfig_kwargs = ( qconfig_kwargs.copy() if qconfig_kwargs @@ -937,15 +660,9 @@ def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): ) if bits: - if mode == "symmetric": - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..79772790566 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,14 +47,14 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + QUANTIZABLE_MODULE_TYPES, add_quant_dequant, - configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, - freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, + get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -94,8 +94,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use + the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. 
After this point, quantized weights and zero points will @@ -143,14 +143,13 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, + quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, - exclude_batchnorm: bool = True, - exclude_module_types: Optional[List[str]] = None, + exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -176,9 +175,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits - self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -187,7 +186,6 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -235,11 +233,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - if self._tensorrt: - fuse_fn = 'no_fuse' - else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' - return fuse_fn + return self._model_fuse_fn_name @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -268,7 +262,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """print + """ :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -336,7 +330,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. 
""" - return self._quantize_conv_output_activations + return self._quantize_linear_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -379,15 +373,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: - kwargs = self._weight_qconfig_kwargs.copy() - if kwargs["observer"] == "minmaxobserver": - kwargs["observer"] = torch_quantization.MinMaxObserver - return kwargs - else: - return self._weight_qconfig_kwargs - - + return self._weight_qconfig_kwargs @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -397,15 +383,6 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps - @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: - """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - - """ - return self._tensorrt - def initialize( self, module: Module, @@ -439,7 +416,10 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if name in self._submodules: + if ( + type(submodule) in QUANTIZABLE_MODULE_TYPES + and name in self._submodules + ): self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -524,15 +504,15 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(freeze_bn_stats) + quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": + if ( + self._model_fuse_fn_name is not None + and self._model_fuse_fn_name != "no_fuse" + ): # module class fn module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -542,10 +522,16 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) + elif self._model_fuse_fn_name is None: # default auto fn + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + + activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() + weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + to_remove_layer_name.extend(["Linear", "LinearReLu"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -554,47 +540,20 @@ def _enable_module_qat(self, module: Module): "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) - if len(to_remove_layer_name) == 0: - to_remove_layer_name = None - - configure_module_bn_wrappers(module) # prepare each module / submodule for quantization - if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - qconfig = get_qat_qconfig( - 
symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -603,18 +562,9 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - if to_remove_layer_name: - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types - to_exclude = [] - if self._exclude_module_types: - to_exclude.extend(self._exclude_module_types) - - if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) - - self._exclude_module_types = to_exclude if self._exclude_module_types: self._strip_excluded_module_qconfigs(module) @@ -623,15 +573,9 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, + activation_qconfig_kwargs=activation_qconfig_kwargs, + weight_qconfig_kwargs=weight_qconfig_kwargs, ) # propagate custom quant min/max range from FakeQuantize to Observer objects @@ -688,6 +632,12 @@ def _calibrate(self, module): if module_training: module.train() + def _get_updated_activation_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + + def _get_updated_weight_qconfig_kwargs(self): + return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is not None From c0940dd728d124042d4e54f7b3f7f89b2d739d69 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:35:49 -0500 Subject: [PATCH 183/218] Added _Add_ReLU module that enables QATWrapper for quantizaiton. 
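The wrap_qat hook relied on here lets a module opt into QATWrapper-based quantization by declaring two attributes; configure_module_qat_wrappers then swaps the module for QATWrapper.from_module(module, **qat_wrapper_kwargs). A minimal sketch of the pattern (ToyAddReLU is a hypothetical module used only for illustration; the kwargs mirror what _AddReLU declares in the diff below):

    import torch
    from torch.nn import Module
    from torch.nn.quantized import FloatFunctional

    class ToyAddReLU(Module):
        def __init__(self):
            super().__init__()
            # FloatFunctional keeps the add+relu observable for QAT
            self.functional = FloatFunctional()
            # opting in: configure_module_qat_wrappers replaces this module with
            # QATWrapper.from_module(self, **self.qat_wrapper_kwargs)
            self.wrap_qat = True
            # quantize only the first input; num_outputs=0 adds no output quant
            # stub, leaving the output to downstream quantization
            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}

        def forward(self, x, y):
            return self.functional.add_relu(x, y)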
--- src/sparseml/pytorch/models/classification/resnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3112da7c2e1 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,14 +140,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self, num_channels): + def __init__(self): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: - self.functional = ReLU(num_channels=num_channels, inplace=True) + self.functional = ReLU(num_channels=out_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +179,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() @@ -236,7 +236,7 @@ def __init__( else None ) - self.add_relu = _AddReLU(out_channels) + self.add_relu = _AddReLU() self.initialize() From 3b30724030162eb6ce6d4ee65eb6a11f21f9a0cb Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:36:37 -0500 Subject: [PATCH 184/218] Removed quantization of output for linear and conv layers by default. Removed fusing of BN and ReLU by default. --- .../sparsification/quantization/helpers.py | 6 +-- .../quantization/modifier_quantization.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 75d11c67c31..f28656f1712 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -32,7 +32,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm __all__ = [ - "QUANTIZABLE_MODULE_TYPES", "QATWrapper", "configure_module_qat_wrappers", "configure_module_default_qconfigs", @@ -45,7 +44,7 @@ "prepare_embeddings_qat", ] -QUANTIZABLE_MODULE_TYPES = ( +_QUANTIZABLE_MODULE_TYPES = ( { # Conv based layers torch.nn.Conv1d, @@ -150,6 +149,7 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( @@ -398,7 +398,7 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in QUANTIZABLE_MODULE_TYPES + type(module) in _QUANTIZABLE_MODULE_TYPES and hasattr(module, "qconfig") and module.qconfig ): diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 79772790566..f914b1f2b91 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,7 +47,6 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( - QUANTIZABLE_MODULE_TYPES, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -94,8 +93,8 @@ 
class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as 'no_fuse' to skip module fusing. Leave None to use - the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' + to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -143,10 +142,10 @@ def __init__( reduce_range: bool = False, quantize_linear_output_activations: bool = False, quantize_conv_output_activations: bool = False, - quantize_add_input_activations: bool = True, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, + exclude_batchnorm: bool = True, exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, @@ -175,9 +174,9 @@ def __init__( self._reduce_range = reduce_range self._quantize_linear_output_activations = quantize_linear_output_activations self._quantize_conv_output_activations = quantize_conv_output_activations - self._quantize_add_input_activations = quantize_add_input_activations self._activation_bits = activation_bits self._weight_bits = weight_bits + self._exclude_batchnorm = exclude_batchnorm self._exclude_module_types = exclude_module_types self._modules_to_quantize = None @@ -233,7 +232,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - return self._model_fuse_fn_name + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + return fuse_fn @model_fuse_fn_name.setter def model_fuse_fn_name(self, value: Union[str, None]): @@ -416,10 +416,7 @@ def initialize( if self._submodules is not None: found_submodules = [] for name, submodule in module.named_modules(): - if ( - type(submodule) in QUANTIZABLE_MODULE_TYPES - and name in self._submodules - ): + if name in self._submodules: self._modules_to_quantize.append(_ModuleToQuantize(name, submodule)) found_submodules.append(name) if not len(found_submodules) == len(self._submodules): @@ -509,10 +506,10 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if ( - self._model_fuse_fn_name is not None - and self._model_fuse_fn_name != "no_fuse" - ): # module class fn + if self._model_fuse_fn_name == 'conv_bn_relus': + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) if module_fuse_fn is None or not callable(module_fuse_fn): raise ValueError( @@ -522,16 +519,13 @@ def _enable_module_qat(self, module: Module): ) ) module_fuse_fn(**self._model_fuse_fn_kwargs) - elif self._model_fuse_fn_name is None: # default auto fn - self._model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() to_remove_layer_name = [] if not self._quantize_linear_output_activations: - to_remove_layer_name.extend(["Linear", "LinearReLu"]) + to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( @@ -565,8 +559,15 @@ def _enable_module_qat(self, module: Module): remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types + to_exclude = [] if self._exclude_module_types: - self._strip_excluded_module_qconfigs(module) + to_exclude.extend(self._exclude_module_types) + + if self._exclude_batchnorm: + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + + self._exclude_module_types = to_exclude + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) From ea18f9386a72f6e446a6a27ff5bd6f1ae20841da Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 9 Mar 2022 21:40:31 -0500 Subject: [PATCH 185/218] Minor fixes. Style and quality fixes. 
--- .../pytorch/models/classification/resnet.py | 61 +++++---- .../sparsification/quantization/helpers.py | 129 +++++++++--------- .../quantization/modifier_quantization.py | 33 +++-- 3 files changed, 115 insertions(+), 108 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3112da7c2e1..be4182891d6 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -140,14 +141,14 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): - def __init__(self): + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: - self.functional = ReLU(num_channels=out_channels, inplace=True) + self.functional = ReLU(num_channels=num_channels, inplace=True) def forward(self, x, y): if isinstance(self.functional, FloatFunctional): @@ -179,7 +180,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1): else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -205,12 +206,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -236,7 +237,7 @@ def __init__( else None ) - self.add_relu = _AddReLU() + self.add_relu = _AddReLU(out_channels) self.initialize() @@ -321,12 +322,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -437,15 +438,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -479,10 +480,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index f28656f1712..ef4445a0d5f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as 
torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_qat_wrappers", @@ -105,10 +107,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -140,7 +142,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -153,20 +155,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -285,12 +287,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -331,10 +333,10 @@ def _load_qconfigs( def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -383,7 +385,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -398,9 +400,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and 
hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -424,7 +426,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -433,11 +435,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -509,7 +511,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -518,9 +520,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -529,9 +531,9 @@ def fix_observer_quant_range(module: Module): def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -566,14 +568,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -610,11 +612,11 @@ def 
fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -644,17 +646,10 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index f914b1f2b91..637bf7e52dd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -93,8 +93,8 @@ class QuantizationModifier(ScheduledModifier): :param submodules: List of submodule names to perform QAT on. Leave None to quantize entire model. Default is None :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as 'conv_bv_relus' - to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Default is None :param disable_quantization_observer_epoch: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will @@ -232,7 +232,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" return fuse_fn @model_fuse_fn_name.setter @@ -356,7 +356,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -506,7 +505,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': + if self._model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -529,10 +528,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) # prepare each module / submodule for quantization @@ -564,7 +573,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -634,7 +643,9 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits + ) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From 7475a5b7c0711657569ddab330fb87ea9ca4ca3c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:02:14 -0500 Subject: [PATCH 186/218] Added support to freezing bn stats. 
--- .../sparsification/quantization/helpers.py | 215 +++++++++++++----- .../quantization/modifier_quantization.py | 37 ++- 2 files changed, 167 insertions(+), 85 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index ef4445a0d5f..c4f165d23ef 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,16 +31,17 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", - "configure_module_qat_wrappers", + "configure_module_bn_wrappers", "configure_module_default_qconfigs", + "configure_module_qat_wrappers", "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", "get_updated_qconfig_kwargs", "fix_observer_quant_range", + "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", ] @@ -69,6 +69,54 @@ else None ) +_BN_MODULE_TYPES = ( + { + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + } + if nni # nni will always import if torch.quantization is available + else {} +) + + +class BNWrapper(Module): + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train() + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + +_BN_MODULE_TYPES.add(BNWrapper) + class QATWrapper(Module): """ @@ -107,10 +155,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -142,7 +190,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -155,20 +203,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + 
"torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -287,12 +335,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -332,11 +380,40 @@ def _load_qconfigs( return qconfigs +def configure_module_bn_wrappers(module: Module): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware + Default is False + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. Default is {} + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. 
Default is {} + """ + # wrap any children of the given module as a QATWrapper if required + if type(module) != BNWrapper: + for child_name, child_module in module.named_children(): + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) + # recurse on child module + configure_module_bn_wrappers(child_module) + + def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -385,7 +462,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -400,9 +477,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -426,7 +503,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -435,11 +512,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -511,7 +588,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -520,9 +597,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -530,10 +607,15 @@ def 
fix_observer_quant_range(module: Module): observer.has_customized_qrange = True +def freeze_bn_stats(module: Module): + if type(module) in _BN_MODULE_TYPES: + module.freeze_bn_stats() + + def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -568,14 +650,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -612,11 +694,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -646,10 +728,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 637bf7e52dd..7eed410b441 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -48,9 +48,11 @@ from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( add_quant_dequant, + configure_module_bn_wrappers, configure_module_default_qconfigs, configure_module_qat_wrappers, fix_observer_quant_range, + freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, get_updated_qconfig_kwargs, @@ -232,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' return fuse_fn @model_fuse_fn_name.setter @@ -262,7 +264,7 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: @disable_quantization_observer_epoch.setter def disable_quantization_observer_epoch(self, value: Union[float, None]): - """ + """print :params value: Epoch to disable updates to the module's quantization observers. After this point, quantized weights and zero points will not be updated. Set None to not disable observers during QAT @@ -356,6 +358,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -500,12 +503,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: - quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) + quant_module.apply(freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == "conv_bn_relus": + if self._model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -528,22 +531,14 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) + configure_module_bn_wrappers(module) + # prepare each module / submodule for quantization qconfig = get_qat_qconfig( reduce_range=self._reduce_range, @@ -573,7 +568,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude self._strip_excluded_module_qconfigs(module) @@ -643,9 +638,7 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits - ) + return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) def _get_updated_weight_qconfig_kwargs(self): return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) From cc0fc7cba6bb99ce8700bbaea4640996f5cb579f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 13:57:15 -0500 Subject: [PATCH 187/218] Added mode argument to wrapping of train function in BNWrapper --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index c4f165d23ef..64958570e2d 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -106,7 +106,7 @@ def reset_running_stats(self): def train(self, mode=True): if 
not self.freeze_bn: - self.bn.train() + self.bn.train(mode) return self def update_bn_stats(self): From 2eda233537f040f5e94e2e221530903c0219bdfc Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:08:20 -0500 Subject: [PATCH 188/218] Set BN fusing back as default. --- .../sparsification/quantization/modifier_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 7eed410b441..37307e38863 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,7 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -508,8 +508,8 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self._model_fuse_fn_name == 'conv_bn_relus': - self._model_fuse_fn_kwargs["inplace"] = True + if self.model_fuse_fn_name == 'conv_bn_relus': + self.model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From 1ed0ac17dba45bcd6a69e1a13e72edf9f7c57a33 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 10 Mar 2022 15:19:09 -0500 Subject: [PATCH 189/218] Set BN fusing back as default. --- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- .../sparsification/quantization/modifier_quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 64958570e2d..a43d69d947b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -397,7 +397,7 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) != BNWrapper: + if type(module) not in _BN_MODULE_TYPES: for child_name, child_module in module.named_children(): if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: setattr( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 37307e38863..2a35ebd2aaf 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -509,7 +509,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs if self.model_fuse_fn_name == 'conv_bn_relus': - self.model_fuse_fn_kwargs["inplace"] = True + self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": module_fuse_fn = getattr(module, self._model_fuse_fn_name, None) From c07e24a06a1192ff8259bc7d4e53e2a8579a84f5 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 11 Mar 2022 19:24:03 -0500 Subject: [PATCH 190/218] Fixed custom freeze_bn_stats. --- .../sparsification/quantization/helpers.py | 251 +++++++++++------- .../quantization/modifier_quantization.py | 46 +++- 2 files changed, 185 insertions(+), 112 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index a43d69d947b..6110a499b70 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,23 +71,6 @@ else None ) -_BN_MODULE_TYPES = ( - { - # Conv based layers - nni.ConvBn1d, - nni.ConvBn2d, - nni.ConvBn3d, - nni.ConvReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nni.ConvBnReLU1d, - nni.ConvBnReLU2d, - nni.ConvBnReLU3d, - } - if nni # nni will always import if torch.quantization is available - else {} -) - class BNWrapper(Module): def __init__(self, module: Module): @@ -93,6 +78,78 @@ def __init__(self, module: Module): self.bn = module self.freeze_bn = False + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + 
self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + def forward(self, x): return self.bn(x) @@ -115,9 +172,6 @@ def update_bn_stats(self): return self -_BN_MODULE_TYPES.add(BNWrapper) - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -155,10 +209,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -190,7 +244,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -203,20 +257,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -335,12 +389,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -397,23 +451,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if type(module) not in _BN_MODULE_TYPES: + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -462,7 +516,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -477,9 +531,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -503,7 +557,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -512,11 +566,11 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -588,7 +642,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -597,9 +651,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + 
fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -608,14 +662,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if type(module) in _BN_MODULE_TYPES: + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -650,14 +704,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -694,11 +748,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -727,26 +781,25 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) +def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 + if mode == "symmetric": + quant_min = -(2 ** (bits - 1)) + quant_max = 2 ** (bits - 1) - 1 + dtype = torch.qint8 + else: + quant_min = 0 + quant_max = 2 ** bits - 1 + dtype = torch.quint8 + qconfig_kwargs.update( dict( quant_min=quant_min, diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 2a35ebd2aaf..acbae885d71 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -234,7 +234,9 @@ def model_fuse_fn_name(self) 
-> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -332,7 +334,7 @@ def quantize_conv_output_activations(self) -> bool: :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_linear_output_activations + return self._quantize_conv_output_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -358,7 +360,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -504,11 +505,12 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) + # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -531,11 +533,23 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) + if len(to_remove_layer_name) == 0: + to_remove_layer_name = None configure_module_bn_wrappers(module) @@ -560,7 +574,8 @@ def _enable_module_qat(self, module: Module): configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) - remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) + if to_remove_layer_name: + remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) # remove qconfigs for module types in exclude_module_types to_exclude = [] @@ -568,10 +583,11 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude - self._strip_excluded_module_qconfigs(module) + if self._exclude_module_types: + self._strip_excluded_module_qconfigs(module) # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) @@ -638,10 +654,14 @@ def _calibrate(self, module): module.train() def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.activation_qconfig_kwargs, self.activation_bits) + return get_updated_qconfig_kwargs( + self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" + ) def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs(self.weight_qconfig_kwargs, self.weight_bits) + return get_updated_qconfig_kwargs( + self.weight_qconfig_kwargs, self.weight_bits, "symmetric" + ) 
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( From cb77f955c3628a2a98099d17a2e6808c6d97b9ec Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 14 Mar 2022 15:35:52 -0400 Subject: [PATCH 191/218] Temporary files for evaluating changes to graphs. --- .../pytorch/models/classification/resnet.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index be4182891d6..21611f211d7 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,7 +41,6 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU - try: from torch.nn.quantized import FloatFunctional except Exception: @@ -146,7 +145,7 @@ def __init__(self, num_channels): if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} + self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -206,12 +205,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -322,12 +321,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -438,15 +437,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -480,10 +479,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: int, + class_type: str, ): super().__init__() self.input = _Input() From 4299dba2c80753ba0a10078c0f8bed5b1711cf19 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 17 Mar 2022 11:51:50 -0400 Subject: [PATCH 192/218] Added support to tensorrt flag. Moved the computation of quantization range to get_qat_config_config where it has full information about data type. 
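
The quantization limits are now derived from the target data type and the
number of bits together, at the point where both are known. A minimal
standalone sketch of that mapping (illustrative only; the helper added in
this patch is compute_range and also returns the resolved dtype):

    import torch

    def quant_range(dtype: torch.dtype = torch.quint8, bits: int = 8):
        """Illustrative mapping from dtype and bit width to (quant_min, quant_max)."""
        if dtype == torch.qint8:
            # signed range, e.g. [-128, 127] for 8 bits
            return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
        # unsigned range, e.g. [0, 255] for 8 bits
        return 0, (2 ** bits) - 1

    # quant_range(torch.qint8, 8)  -> (-128, 127)
    # quant_range(torch.quint8, 4) -> (0, 15)

When the tensorrt flag is set, the modifier requests symmetric torch.qint8
activations and weights, so the signed branch above applies to both.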
--- .../sparsification/quantization/helpers.py | 213 ++++++++++-------- .../quantization/modifier_quantization.py | 58 ++--- 2 files changed, 137 insertions(+), 134 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6110a499b70..8ae045de9e8 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,7 +22,6 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU - try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -32,7 +31,6 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm - __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -209,10 +207,10 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - reduce_range: bool = None, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = None, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -244,7 +242,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -257,20 +255,20 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): super().__init__() @@ -389,12 +387,12 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -451,23 +449,23 @@ def configure_module_bn_wrappers(module: Module): weights. 
Default is {} """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, "freeze_bn_stats"): + if not hasattr(module, 'freeze_bn_stats'): for child_name, child_module in module.named_children(): - if type(child_module) in [ - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - ]: - setattr(module, child_name, BNWrapper(child_module)) + if type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: + setattr( + module, + child_name, + BNWrapper(child_module) + ) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -506,6 +504,17 @@ def configure_module_qat_wrappers( ) +def compute_range(dtype: torch.dtype, bits: int): + if dtype == torch.qint8: + quant_min = -2 ** (bits - 1) + quant_max = 2 ** (bits - 1) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = 2 ** bits - 1 + + return quant_min, quant_max + + def configure_module_default_qconfigs(module: Module): """ if any submodule of the given module has a configure_qconfig function, @@ -516,7 +525,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -531,9 +540,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -557,7 +566,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -566,11 +575,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + symmetric_activations: bool = False, + symmetric_weights: bool = True, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = torch.quint8, + weight_dtype: Optional[torch.dtype] = torch.qint8, + activation_bits: Optional[int] = 8, + weight_bits: Optional[int] = 8, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -591,42 +604,35 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. 
""" - activation_qscheme = ( - torch.per_tensor_symmetric if symmetric_activations else torch.per_tensor_affine - ) - activation_observer_kwargs = dict( - observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - dtype=torch.quint8, - qscheme=activation_qscheme, - reduce_range=reduce_range, - ) - activation_observer_kwargs.update(activation_qconfig_kwargs or {}) - activation_observer = torch_quantization.FakeQuantize.with_args( - **activation_observer_kwargs, + activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, + activation_qconfig_kwargs) + weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, ) - weight_qscheme = ( - torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine + + +def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): + qscheme = ( + torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - weight_observer_kwargs = dict( + quant_min, quant_max = compute_range(dtype, bits) + observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + dtype=dtype, + qscheme=qscheme, reduce_range=reduce_range, ) - - weight_observer_kwargs.update(weight_qconfig_kwargs or {}) - weight_observer = torch_quantization.FakeQuantize.with_args( - **weight_observer_kwargs, - ) - return torch_quantization.QConfig( - activation=activation_observer, - weight=weight_observer, + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, ) + return observer + def fix_observer_quant_range(module: Module): """ @@ -642,7 +648,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -651,9 +657,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min @@ -662,14 +668,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, "freeze_bn_stats"): + if hasattr(module, 'freeze_bn_stats'): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -704,14 +710,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): 
submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -748,11 +754,11 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -782,17 +788,24 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} + qconfig_kwargs = ( + qconfig_kwargs.copy() + if qconfig_kwargs + else {} + ) # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): + if bits and ( + qconfig_kwargs.get("quant_min") + or qconfig_kwargs.get("quant_max") + ): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) + quant_min = -2 ** (bits - 1) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index acbae885d71..5a5e1913b18 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -55,7 +55,6 @@ freeze_bn_stats, fuse_module_conv_bn_relus, get_qat_qconfig, - get_updated_qconfig_kwargs, prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) @@ -151,6 +150,7 @@ def __init__( exclude_module_types: Union[List[str], None] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + tensorrt: Optional[bool] = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -187,6 +187,7 @@ def __init__( self._bn_stats_frozen = False self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._tensorrt = tensorrt self._calibration_dataloader = None self._calibration_function = None @@ -234,9 +235,10 @@ def model_fuse_fn_name(self) -> Union[str, None]: to performing QAT. None to uses the default function `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - fuse_fn = ( - self._model_fuse_fn_name if self._model_fuse_fn_name else "conv_bn_relus" - ) + if self._tensorrt: + fuse_fn = 'no_fuse' + else: + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @model_fuse_fn_name.setter @@ -360,6 +362,7 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits + @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -505,12 +508,11 @@ def _check_quantization_update( if self._freeze_bn_stats_update_ready(epoch): for _, quant_module in self._modules_to_quantize: quant_module.apply(freeze_bn_stats) - # quant_module.apply(torch_intrinsic.qat.freeze_bn_stats) self._bn_stats_frozen = True def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == "conv_bn_relus": + if self.model_fuse_fn_name == 'conv_bn_relus': self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -524,29 +526,16 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) - activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs() - weight_qconfig_kwargs = self._get_updated_weight_qconfig_kwargs() - to_remove_layer_name = [] if not self._quantize_linear_output_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) if not self._quantize_conv_output_activations: to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] + ["Conv1d", "Conv2d", "Conv3d", + "ConvBn1d", "ConvBn2d", "ConvBn3d", + "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", + "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -554,10 +543,21 @@ def _enable_module_qat(self, module: Module): configure_module_bn_wrappers(module) # prepare each module / submodule for quantization + if self.tensorrt: + _symmetric_activations = True + _activations_dtype = torch.qint8 + else: + _symmetric_activations = False + _activations_dtype = torch.quint8 + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, reduce_range=self._reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=_activations_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) @@ -583,7 +583,7 @@ def _enable_module_qat(self, module: Module): to_exclude.extend(self._exclude_module_types) if self._exclude_batchnorm: - to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) + to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) self._exclude_module_types = to_exclude if self._exclude_module_types: @@ -653,16 +653,6 @@ def _calibrate(self, module): if module_training: module.train() - def _get_updated_activation_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.activation_qconfig_kwargs, self.activation_bits, "asymmetric" - ) - - def _get_updated_weight_qconfig_kwargs(self): - return get_updated_qconfig_kwargs( - self.weight_qconfig_kwargs, self.weight_bits, "symmetric" - ) - def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: return ( self._disable_quantization_observer_epoch is 
not None From b3d7340394b7274d8d7011167d1a82d3d9d58346 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Sun, 20 Mar 2022 11:42:14 -0400 Subject: [PATCH 193/218] Added support to TensorRT quantization --- .../sparsification/quantization/helpers.py | 166 ++++++++++++++++-- .../quantization/modifier_quantization.py | 61 +++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 8ae045de9e8..027c7514c32 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -208,9 +208,15 @@ class QATWrapper(Module): @staticmethod def from_module( module: Module, - reduce_range: bool = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -232,6 +238,18 @@ def from_module( else {} ) + qat_wrapper_kwargs["symmetric_activations"] = ( + symmetric_activations + if "symmetric_activations" not in qat_wrapper_kwargs + else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] + ) + + qat_wrapper_kwargs["symmetric_weights"] = ( + symmetric_weights or False + if "symmetric_weights" not in qat_wrapper_kwargs + else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] + ) + qat_wrapper_kwargs["reduce_range"] = ( reduce_range or False if "reduce_range" not in qat_wrapper_kwargs @@ -251,6 +269,30 @@ def from_module( else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] ) + qat_wrapper_kwargs["activation_dtype"] = ( + activation_dtype + if "activation_dtype" not in qat_wrapper_kwargs + else activation_dtype or qat_wrapper_kwargs["activation_dtype"] + ) + + qat_wrapper_kwargs["weight_dtype"] = ( + weight_dtype + if "weight_dtype" not in qat_wrapper_kwargs + else weight_dtype or qat_wrapper_kwargs["weight_dtype"] + ) + + qat_wrapper_kwargs["activation_bits"] = ( + activation_bits + if "activation_bits" not in qat_wrapper_kwargs + else activation_bits or qat_wrapper_kwargs["activation_bits"] + ) + + qat_wrapper_kwargs["weight_bits"] = ( + weight_bits + if "weight_bits" not in qat_wrapper_kwargs + else weight_bits or qat_wrapper_kwargs["weight_bits"] + ) + module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -266,9 +308,15 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -288,25 +336,43 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn + self._symmetric_activations = symmetric_activations + self._symmetric_weights = symmetric_weights self._reduce_range = reduce_range 
self._activation_qconfig_kwargs = activation_qconfig_kwargs self._weight_qconfig_kwargs = weight_qconfig_kwargs + self._activation_dtype = activation_dtype + self._weight_dtype = weight_dtype + self._activation_bits = activation_bits + self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, + symmetric_activations=self._symmetric_activations, + symmetric_weights=self._symmetric_weights, reduce_range=self._reduce_range, activation_qconfig_kwargs=self._activation_qconfig_kwargs, weight_qconfig_kwargs=self._weight_qconfig_kwargs, + activation_dtype=self._activation_dtype, + weight_dtype=self._weight_dtype, + activation_bits=self._activation_bits, + weight_bits=self._weight_bits, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -390,9 +456,15 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -422,11 +494,21 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) + if symmetric_activations is None: + _symmetric_activations = qconfig == "symmetric" + else: + _symmetric_activations = symmetric_activations + qconfigs[idx] = get_qat_qconfig( - symmetric_activations=(qconfig == "symmetric"), + symmetric_activations=_symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) return qconfigs @@ -463,9 +545,15 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Dict[str, Any] = {}, weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -490,29 +578,43 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + 
activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, + symmetric_activations=symmetric_activations, + symmetric_weights=symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) -def compute_range(dtype: torch.dtype, bits: int): +def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + dtype = dtype if dtype else torch.quint8 + bits = bits if bits else 8 if dtype == torch.qint8: - quant_min = -2 ** (bits - 1) - quant_max = 2 ** (bits - 1) - 1 + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 elif dtype == torch.quint8: quant_min = 0 - quant_max = 2 ** bits - 1 + quant_max = (2 ** bits) - 1 - return quant_min, quant_max + return quant_min, quant_max, dtype def configure_module_default_qconfigs(module: Module): @@ -575,15 +677,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: bool = False, - symmetric_weights: bool = True, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = torch.quint8, - weight_dtype: Optional[torch.dtype] = torch.qint8, - activation_bits: Optional[int] = 8, - weight_bits: Optional[int] = 8, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -606,18 +708,28 @@ def get_qat_qconfig( """ activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, activation_qconfig_kwargs) - weight_observer = get_observer(symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + if symmetric_weights is None: + _symmetric_weights = True + else: + _symmetric_weights = symmetric_weights + + if weight_dtype is None: + _weight_dtype = torch.qint8 + else: + _weight_dtype = weight_dtype + + weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: bool, dtype: torch.dtype, bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any]): +def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): qscheme = ( torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine ) - quant_min, quant_max = compute_range(dtype, bits) + quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -756,9 +868,15 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, reduce_range: bool = False, 
- activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -776,11 +894,21 @@ def prepare_embeddings_qat( Default is False """ if qconfig is None: + if symmetric_weights is None: + _symmetric_weights = False + else: + _symmetric_weights = symmetric_weights + qconfig = get_qat_qconfig( - symmetric_weights=False, + symmetric_activations=symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=reduce_range, activation_qconfig_kwargs=activation_qconfig_kwargs, weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_dtype=activation_dtype, + weight_dtype=weight_dtype, + activation_bits=activation_bits, + weight_bits=weight_bits, ) for submodule in module.modules(): if type(submodule) is Embedding: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5a5e1913b18..27c5a4c336e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -147,10 +147,10 @@ def __init__( weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, - exclude_module_types: Union[List[str], None] = None, + exclude_module_types: Optional[List[str]] = None, activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - tensorrt: Optional[bool] = False, + tensorrt: bool = False, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -379,7 +379,15 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - return self._weight_qconfig_kwargs + if "observer" in self._weight_qconfig_kwargs: + kwargs = self._weight_qconfig_kwargs.copy() + if kwargs["observer"] == "minmaxobserver": + kwargs["observer"] = torch_quantization.MinMaxObserver + return kwargs + else: + return self._weight_qconfig_kwargs + + @ModifierProp() def num_calibration_steps(self) -> Optional[int]: @@ -389,6 +397,15 @@ def num_calibration_steps(self) -> Optional[int]: """ return self._num_calibration_steps + @ModifierProp() + def tensorrt(self) -> Dict[str, Any]: + """ + :return: Dictionary with correct quant_min, quant_max, and dtype values + for activations + + """ + return self._tensorrt + def initialize( self, module: Module, @@ -545,17 +562,23 @@ def _enable_module_qat(self, module: Module): # prepare each module / submodule for quantization if self.tensorrt: _symmetric_activations = True - _activations_dtype = torch.qint8 + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - _symmetric_activations = False - _activations_dtype = torch.quint8 + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None qconfig = get_qat_qconfig( symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - 
activation_dtype=_activations_dtype, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, activation_bits=self.activation_bits, weight_bits=self.weight_bits ) @@ -563,9 +586,15 @@ def _enable_module_qat(self, module: Module): # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( quant_module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits ) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -594,9 +623,15 @@ def _enable_module_qat(self, module: Module): if self._quantize_embeddings: prepare_embeddings_qat( module, + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, reduce_range=self._reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, ) # propagate custom quant min/max range from FakeQuantize to Observer objects From 6976650b41ee7ca4bd5748ff33ba48331a2b76e8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 21 Mar 2022 19:16:26 -0400 Subject: [PATCH 194/218] Included check to account for when weight_qconfig_kwatgs is None. --- .../sparsification/quantization/modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 27c5a4c336e..a306f4d8e73 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -379,7 +379,7 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if "observer" in self._weight_qconfig_kwargs: + if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver From bdb1d217288252e16b88551e660409b0c8751398 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 14:20:19 -0400 Subject: [PATCH 195/218] Modified argument names for backwards compatibility. 
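
The shorter names restore the argument spelling used before the
output-activation rework, so existing recipes and scripts that pass
quantize_linear_activations or quantize_conv_activations keep working.
A hedged usage sketch (the class name and import path are assumptions,
not shown in this diff; only the keyword names come from the change):

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,  # assumed class name
    )

    # disable fake-quantization of linear and conv output activations
    modifier = QuantizationModifier(
        quantize_embeddings=True,
        quantize_linear_activations=False,
        quantize_conv_activations=False,
    )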
--- .../quantization/modifier_quantization.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a306f4d8e73..73a50e0f9c4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -141,8 +141,8 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_output_activations: bool = False, - quantize_conv_output_activations: bool = False, + quantize_linear_activations: bool = False, + quantize_conv_activations: bool = False, activation_bits: Optional[int] = None, weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, @@ -174,8 +174,8 @@ def __init__( self._freeze_bn_stats_epoch = freeze_bn_stats_epoch self._quantize_embeddings = quantize_embeddings self._reduce_range = reduce_range - self._quantize_linear_output_activations = quantize_linear_output_activations - self._quantize_conv_output_activations = quantize_conv_output_activations + self._quantize_linear_activations = quantize_linear_activations + self._quantize_conv_activations = quantize_conv_activations self._activation_bits = activation_bits self._weight_bits = weight_bits self._exclude_batchnorm = exclude_batchnorm @@ -320,7 +320,7 @@ def reduce_range(self) -> bool: return self._reduce_range @ModifierProp() - def quantize_linear_output_activations(self) -> bool: + def quantize_linear_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of fully connected layers. this is important for quantizing @@ -328,15 +328,15 @@ def quantize_linear_output_activations(self) -> bool: are kept at 32 bits of precision and fake quantizing the outputs harm training recovery """ - return self._quantize_linear_output_activations + return self._quantize_linear_activations @ModifierProp() - def quantize_conv_output_activations(self) -> bool: + def quantize_conv_activations(self) -> bool: """ :return: if False, FakeQuantize ops will not be run for activations of convolutional layers. """ - return self._quantize_conv_output_activations + return self._quantize_conv_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -544,10 +544,10 @@ def _enable_module_qat(self, module: Module): module_fuse_fn(**self._model_fuse_fn_kwargs) to_remove_layer_name = [] - if not self._quantize_linear_output_activations: + if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) - if not self._quantize_conv_output_activations: + if not self._quantize_conv_activations: to_remove_layer_name.extend( ["Conv1d", "Conv2d", "Conv3d", "ConvBn1d", "ConvBn2d", "ConvBn3d", From 491b071e5b85c8956a654b3dd429424d4e52ceb4 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:51 -0400 Subject: [PATCH 196/218] Updated documentation to reflect changes. 
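
With the expanded docstrings, the intended pairing is symmetric ranges with
a pre-specified zero point (torch.qint8 weights by default) versus
asymmetric torch.quint8 activations. A minimal sketch of calling the helper
with the newly documented keywords (import path assumed from this file's
location):

    import torch
    from sparseml.pytorch.sparsification.quantization.helpers import get_qat_qconfig

    # TensorRT-style settings: symmetric, signed 8-bit activations and weights
    qconfig = get_qat_qconfig(
        symmetric_activations=True,
        symmetric_weights=True,
        activation_dtype=torch.qint8,
        weight_dtype=torch.qint8,
        activation_bits=8,
        weight_bits=8,
    )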
--- .../sparsification/quantization/helpers.py | 118 ++++++++++++------ 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 027c7514c32..bc9aeb6d58c 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -69,8 +69,14 @@ else None ) - +# class BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ def __init__(self, module: Module): super().__init__() self.bn = module @@ -220,14 +226,25 @@ def from_module( ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param reduce_range: if True, the quantization range will be reduced by one - bit. This may prevent overflow issues with model execution on certain - hardware. Default is None, will only override qat_wrapper_kwargs if set - to a bool value + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: QATWrapper object created using the given Module as the forward function. Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -293,6 +310,7 @@ def from_module( else weight_bits or qat_wrapper_kwargs["weight_bits"] ) + # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) @@ -516,19 +534,10 @@ def _load_qconfigs( def configure_module_bn_wrappers(module: Module): """ - if any submodule of the given module has the attribute wrap_qat == True, - then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. - Other named kwargs to the QATWrapper constructor must be contained in a dictionary - under an attributed named `qat_wrapper_kwargs` + Wrap any BatchNormalization modules that are not fused with convolutions + with BNWrapper to enable freezing/unfreezing of BN statistics :param module: module to potentially wrap the submodules of - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. 
Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} """ # wrap any children of the given module as a QATWrapper if required if not hasattr(module, 'freeze_bn_stats'): @@ -562,14 +571,25 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. Default is {} + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. Default is {} - """ + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -605,6 +625,13 @@ def configure_module_qat_wrappers( def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. If None dtype is set to torch.quint8. + :param bits: number of bits. If None is set to 8. + :return: minimum limit, maximum limit, data type + """ dtype = dtype if dtype else torch.quint8 bits = bits if bits else 8 if dtype == torch.qint8: @@ -689,18 +716,24 @@ def get_qat_qconfig( ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric - UINT8 quantization range with zero point set to 128. Otherwise activations - will use asymmetric quantization with any zero point. Default is False + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. :param symmetric_weights: if True, weights will have a symmetric - INT8 quantization range with zero point set to 0. Otherwise activations - will use asymmetric quantization with any zero point. Default is True + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False + This may prevent overflow issues with model execution on certain hardware. + Default is False. :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. 
+ :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. :return: A QAT fake quantization config for symmetric weight quantization and asymmetric activation quantization. The difference between this and torch.quantization.default_qat_qconfig is that the activation observer @@ -885,14 +918,25 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param activation_qconfig_kwargs: additional kwargs for quantizing activations. - Default is {}. - :param weight_qconfig_kwargs: additional kwargs for quantizing the weights. - Default is {}. + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware. - Default is False - """ + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. Default is torch.quint8. + :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. """ if qconfig is None: if symmetric_weights is None: _symmetric_weights = False From e40b509799961648a8611ae69920f1d9da082546 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:40:57 -0400 Subject: [PATCH 197/218] Updated documentation to reflect changes. --- .../quantization/modifier_quantization.py | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 73a50e0f9c4..4f912b3d8bb 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -113,21 +113,26 @@ class QuantizationModifier(ScheduledModifier): :param reduce_range: if True, the quantization range will be reduced by one bit. This may prevent overflow issues with model execution on certain hardware Default is False - :param quantize_linear_activations: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm training - recovery. Default is True + :param quantize_linear_activations: if True, FakeQuantize ops will be run + for output activations of fully connected layers. Default is False. 
+ :param quantize_conv_activations: if True, FakeQuantize ops will be run + for output activations of convolutional layers. Default is False. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + activations. Default is None, which will quantize activations to 8 bits. + :param weight_bits: Number of bits to use for setting quant min/max values for + weights. Default is None, which will quantize weights to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used + When None, the entire calibration_dataloader is used + :param exclude_batchnorm: If True, do not propagate quantization qconfigs to + batch-normalization modules :param exclude_module_types: optional list of module class names to not propagate quantization configs to. Default is None :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. + activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. + weights. + :param tenssorrt: if True sets quantization configuration for compatibility with + explict quantization as supported by TensorRT 8.2. """ def __init__( @@ -232,11 +237,12 @@ def submodules(self, value: Union[List[str], None]): def model_fuse_fn_name(self) -> Union[str, None]: """ :return: Name of model function to fuse the model in place prior - to performing QAT. None to uses the default function + to performing QAT. None sets to default function. + If tensorrt flag is True, default is 'no_fuse', otherwise `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ if self._tensorrt: - fuse_fn = 'no_fuse' + fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' else: fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' return fuse_fn @@ -322,19 +328,16 @@ def reduce_range(self) -> bool: @ModifierProp() def quantize_linear_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of fully connected layers. this is important for quantizing - transformer based models such as BERT where the quantized MatMul outputs - are kept at 32 bits of precision and fake quantizing the outputs harm - training recovery + :return: if True, FakeQuantize ops will be run for output activations + of fully connected layers """ return self._quantize_linear_activations @ModifierProp() def quantize_conv_activations(self) -> bool: """ - :return: if False, FakeQuantize ops will not be run - for activations of convolutional layers. + :return: if True, FakeQuantize ops will be run for output activations + of convolutional layers """ return self._quantize_conv_activations @@ -358,7 +361,7 @@ def activation_bits(self) -> Optional[int]: def weight_bits(self) -> Optional[int]: """ :return: Number of bits to be use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + weights. Default is None, which will quantize weights to 8 bits. 
""" return self._weight_bits @@ -543,6 +546,7 @@ def _enable_module_qat(self, module: Module): ) module_fuse_fn(**self._model_fuse_fn_kwargs) + # build list of layer types that should not quantize output activations to_remove_layer_name = [] if not self._quantize_linear_activations: to_remove_layer_name.extend(["Linear", "LinearReLU"]) @@ -557,9 +561,16 @@ def _enable_module_qat(self, module: Module): if len(to_remove_layer_name) == 0: to_remove_layer_name = None + # fix for freezing batchnorm statistics when not fusing BN with convs. + # pytorch only supports freezing batchnorm statistics for fused modules. + # this fix wraps BN modules adding with a new module class that supports + # methods related to freezing/unfreezing BN statistics. configure_module_bn_wrappers(module) - # prepare each module / submodule for quantization + # set qconfig. + # if tensorrt flag is used, set activation and weights to symmetric + # quantization. + # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: _symmetric_activations = True _activation_dtype = torch.qint8 @@ -582,6 +593,8 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + + # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) configure_module_qat_wrappers( @@ -596,13 +609,17 @@ def _enable_module_qat(self, module: Module): activation_bits=self.activation_bits, weight_bits=self.weight_bits ) + # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig + # wrap all conv / linear blocks in with quantization observers torch_quantization.propagate_qconfig_(quant_module) configure_module_default_qconfigs(quant_module) add_quant_dequant(quant_module, name, module) + + # Remove output quantization from appropriate modules if to_remove_layer_name: remove_activation_qat_by_layer_name(quant_module, to_remove_layer_name) @@ -611,6 +628,8 @@ def _enable_module_qat(self, module: Module): if self._exclude_module_types: to_exclude.extend(self._exclude_module_types) + # if exclude_batchnorm flag is used, add batch norm layers to list of + # modules to exclude qconfig if self._exclude_batchnorm: to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) From 10134b8a0f0fcc870b76c542ac5dec4b285b440c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:42:27 -0400 Subject: [PATCH 198/218] Updated documentation to reflect changes. --- src/sparseml/pytorch/models/classification/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 21611f211d7..3a7a5169447 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -140,6 +140,10 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool: class _AddReLU(Module): + """ + Wrapper for the FloatFunctional class that enables QATWrapper used to + quantize the first input to the Add operation + """ def __init__(self, num_channels): super().__init__() if FloatFunctional: From 4ed31b407cb2389c969a75ec1b27ed46422801dd Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 16:52:15 -0400 Subject: [PATCH 199/218] Fixed default weights data type. 
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index bc9aeb6d58c..b3e47162c5e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -751,7 +751,7 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, From f5ec19f40f0676aaacefdcd0c7c82fa66a4b0b63 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:02:48 -0400 Subject: [PATCH 200/218] Style and quality fixes. --- .../pytorch/models/classification/resnet.py | 54 ++-- .../sparsification/quantization/helpers.py | 247 +++++++++--------- .../quantization/modifier_quantization.py | 44 +++- 3 files changed, 186 insertions(+), 159 deletions(-) diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 3a7a5169447..cd8b979c3ad 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -41,6 +41,7 @@ from sparseml.pytorch.models.registry import ModelRegistry from sparseml.pytorch.nn import ReLU + try: from torch.nn.quantized import FloatFunctional except Exception: @@ -144,12 +145,13 @@ class _AddReLU(Module): Wrapper for the FloatFunctional class that enables QATWrapper used to quantize the first input to the Add operation """ + def __init__(self, num_channels): super().__init__() if FloatFunctional: self.functional = FloatFunctional() self.wrap_qat = True - self.qat_wrapper_kwargs = {'num_inputs': 1, 'num_outputs': 0} + self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0} else: self.functional = ReLU(num_channels=num_channels, inplace=True) @@ -209,12 +211,12 @@ def initialize(self): class _BottleneckBlock(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -325,12 +327,12 @@ def initialize(self): class _BottleneckBlockV2(Module): def __init__( - self, - in_channels: int, - out_channels: int, - proj_channels: int, - stride: int = 1, - groups: int = 1, + self, + in_channels: int, + out_channels: int, + proj_channels: int, + stride: int = 1, + groups: int = 1, ): super().__init__() @@ -441,15 +443,15 @@ class ResNetSectionSettings(object): """ def __init__( - self, - num_blocks: int, - in_channels: int, - out_channels: int, - downsample: bool, - proj_channels: int = -1, - groups: int = 1, - use_se: bool = False, - version: int = 1, + self, + num_blocks: int, + in_channels: int, + out_channels: int, + downsample: bool, + proj_channels: int = -1, + groups: int = 1, + use_se: bool = False, + version: int = 1, ): if use_se: # TODO: add support for squeeze excite @@ -483,10 +485,10 @@ class ResNet(Module): """ def __init__( - self, - sec_settings: List[ResNetSectionSettings], - num_classes: int, - class_type: str, + self, + sec_settings: List[ResNetSectionSettings], + num_classes: 
int, + class_type: str, ): super().__init__() self.input = _Input() diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index b3e47162c5e..c2e21d30a16 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -22,6 +22,7 @@ import torch from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + try: import torch.nn.intrinsic as nni from torch import quantization as torch_quantization @@ -31,6 +32,7 @@ from sparseml.pytorch.nn import ReLU as ReLU_nm + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -69,7 +71,7 @@ else None ) -# + class BNWrapper(Module): """ Wraps BatchNormalization module to expose methods needed to enable @@ -77,6 +79,7 @@ class BNWrapper(Module): :param module: BatchNormalization module to be wrapped """ + def __init__(self, module: Module): super().__init__() self.bn = module @@ -213,16 +216,16 @@ class QATWrapper(Module): @staticmethod def from_module( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for @@ -241,8 +244,10 @@ def from_module( activations. :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. 
:return: QATWrapper object created using the given Module as the forward @@ -277,7 +282,7 @@ def from_module( activation_qconfig_kwargs if "activation_qconfig_kwargs" not in qat_wrapper_kwargs else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] + or qat_wrapper_kwargs["activation_qconfig_kwargs"] ) qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( @@ -315,26 +320,26 @@ def from_module( return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) def __init__( - self, - forward_fn: Callable[[Any], Any], - num_inputs: int = 1, - kwarg_input_names: List[str] = None, - num_outputs: int = 1, - input_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - output_qconfigs: Union[ - "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] - ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + self, + forward_fn: Callable[[Any], Any], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): super().__init__() @@ -471,18 +476,18 @@ def configure_qconfig(self): @staticmethod def _load_qconfigs( - name: str, - expected_len: int, - qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + name: str, + expected_len: int, + qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -540,29 +545,29 @@ def configure_module_bn_wrappers(module: Module): :param module: module to potentially wrap the submodules of """ # wrap any children of the given module as a QATWrapper if required - if not hasattr(module, 'freeze_bn_stats'): + if not hasattr(module, "freeze_bn_stats"): for child_name, child_module in module.named_children(): - if 
type(child_module) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: - setattr( - module, - child_name, - BNWrapper(child_module) - ) + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) def configure_module_qat_wrappers( - module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Dict[str, Any] = {}, + weight_qconfig_kwargs: Dict[str, Any] = {}, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -589,7 +594,7 @@ def configure_module_qat_wrappers( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -654,7 +659,7 @@ def configure_module_default_qconfigs(module: Module): """ for submodule in module.modules(): if hasattr(submodule, "configure_qconfig") and callable( - getattr(submodule, "configure_qconfig") + getattr(submodule, "configure_qconfig") ): submodule.configure_qconfig() @@ -669,9 +674,9 @@ def add_quant_dequant(module, name=None, parent_module=None): """ named_children = module.named_children() if ( - type(module) in _QUANTIZABLE_MODULE_TYPES - and hasattr(module, "qconfig") - and module.qconfig + type(module) in _QUANTIZABLE_MODULE_TYPES + and hasattr(module, "qconfig") + and module.qconfig ): if parent_module is not None and len(list(named_children)) <= 0: module = torch_quantization.QuantWrapper(module) @@ -695,7 +700,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ """ for submodule in module.modules(): if submodule.__class__.__name__ in layer_class_names and hasattr( - submodule, "qconfig" + submodule, "qconfig" ): submodule.qconfig = torch_quantization.QConfig( activation=torch.nn.Identity, @@ -704,15 +709,15 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + 
symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ) -> "torch.quantization.QConfig": """ :param symmetric_activations: if True, activations will have a symmetric @@ -739,8 +744,13 @@ def get_qat_qconfig( torch.quantization.default_qat_qconfig is that the activation observer will not have reduce_range enabled. """ - activation_observer = get_observer(symmetric_activations, activation_dtype, activation_bits, reduce_range, - activation_qconfig_kwargs) + activation_observer = get_observer( + symmetric_activations, + activation_dtype, + activation_bits, + reduce_range, + activation_qconfig_kwargs, + ) if symmetric_weights is None: _symmetric_weights = True else: @@ -751,17 +761,23 @@ def get_qat_qconfig( else: _weight_dtype = weight_dtype - weight_observer = get_observer(_symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs) + weight_observer = get_observer( + _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + ) return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, ) -def get_observer(symmetric: Optional[bool], dtype: Optional[torch.dtype], bits: Optional[int], reduce_range: bool, qconfig_kwargs: Dict[str, Any]): - qscheme = ( - torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - ) +def get_observer( + symmetric: Optional[bool], + dtype: Optional[torch.dtype], + bits: Optional[int], + reduce_range: bool, + qconfig_kwargs: Dict[str, Any], +): + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine quant_min, quant_max, dtype = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, @@ -793,7 +809,7 @@ def fix_observer_quant_range(module: Module): if isinstance(submodule, torch_quantization.FakeQuantize): fake_quantize = submodule elif hasattr(submodule, "activation_post_process") and isinstance( - submodule.activation_post_process, torch_quantization.FakeQuantize + submodule.activation_post_process, torch_quantization.FakeQuantize ): fake_quantize = submodule.activation_post_process else: @@ -813,14 +829,14 @@ def fix_observer_quant_range(module: Module): def freeze_bn_stats(module: Module): - if hasattr(module, 'freeze_bn_stats'): + if hasattr(module, "freeze_bn_stats"): module.freeze_bn_stats() def fuse_module_conv_bn_relus( - module: Module, - inplace: bool = True, - override_bn_subclasses_forward: Union[bool, str] = True, + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -855,14 +871,14 @@ def fuse_module_conv_bn_relus( for name, layer in module.named_modules(): submodule_name = ".".join(name.split(".")[:-1]) if ( - len(current_block) == 1 # [Conv2d] - and isinstance(layer, BatchNorm2d) - and submodule_name == current_block_submodule_name + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name ) or ( - len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] - and isinstance(layer, ReLU) - and not isinstance(current_block[-1], ReLU) - and 
submodule_name == current_block_submodule_name + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) @@ -899,17 +915,17 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( - module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + module: Module, + qconfig: "torch.quantization.QConfig" = None, + symmetric_activations: Optional[bool] = None, + symmetric_weights: Optional[bool] = None, + reduce_range: bool = False, + activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, + weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, + activation_dtype: Optional[torch.dtype] = None, + weight_dtype: Optional[torch.dtype] = None, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -936,7 +952,7 @@ def prepare_embeddings_qat( :param activation_dtype: quantized activation data type. Default is torch.quint8. :param weight_dtype: quantized weights data type. Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. """ + :param weight_bits: number of bits for weights. Default is 8.""" if qconfig is None: if symmetric_weights is None: _symmetric_weights = False @@ -960,24 +976,17 @@ def prepare_embeddings_qat( def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = ( - qconfig_kwargs.copy() - if qconfig_kwargs - else {} - ) + qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} # update qconfig_kwargs for bits - if bits and ( - qconfig_kwargs.get("quant_min") - or qconfig_kwargs.get("quant_max") - ): + if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): raise ValueError( "Cannot override quant_max and quant_min when number of bits is set" ) if bits: if mode == "symmetric": - quant_min = -2 ** (bits - 1) + quant_min = -(2 ** (bits - 1)) quant_max = 2 ** (bits - 1) - 1 dtype = torch.qint8 else: diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 4f912b3d8bb..30e1aefbe15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -242,9 +242,15 @@ def model_fuse_fn_name(self) -> Union[str, None]: `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" if self._tensorrt: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'no_fuse' + fuse_fn = ( + self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" + ) else: - fuse_fn = self._model_fuse_fn_name if self._model_fuse_fn_name else 'conv_bn_relus' + fuse_fn = ( + self._model_fuse_fn_name + if self._model_fuse_fn_name + else "conv_bn_relus" + ) return fuse_fn @model_fuse_fn_name.setter @@ -365,7 +371,6 @@ def weight_bits(self) -> Optional[int]: """ return self._weight_bits - @ModifierProp() def activation_qconfig_kwargs(self) -> Dict[str, Any]: """ @@ -382,7 +387,10 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: for weights """ - if self._weight_qconfig_kwargs is not None and "observer" in self._weight_qconfig_kwargs: + if ( + self._weight_qconfig_kwargs is not None + and "observer" in self._weight_qconfig_kwargs + ): kwargs = self._weight_qconfig_kwargs.copy() if kwargs["observer"] == "minmaxobserver": kwargs["observer"] = torch_quantization.MinMaxObserver @@ -390,8 +398,6 @@ def weight_qconfig_kwargs(self) -> Dict[str, Any]: else: return self._weight_qconfig_kwargs - - @ModifierProp() def num_calibration_steps(self) -> Optional[int]: """ @@ -532,7 +538,7 @@ def _check_quantization_update( def _enable_module_qat(self, module: Module): # fuse module Conv-BNs - if self.model_fuse_fn_name == 'conv_bn_relus': + if self.model_fuse_fn_name == "conv_bn_relus": self._model_fuse_fn_kwargs["inplace"] = True fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) elif self.model_fuse_fn_name != "no_fuse": @@ -553,10 +559,20 @@ def _enable_module_qat(self, module: Module): if not self._quantize_conv_activations: to_remove_layer_name.extend( - ["Conv1d", "Conv2d", "Conv3d", - "ConvBn1d", "ConvBn2d", "ConvBn3d", - "ConvReLU1d", "ConvReLU2d", "ConvReLU3d", - "ConvBnReLU1d", "ConvBnReLU2d", "ConvBnReLU3d"] + [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + ] ) if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -591,7 +607,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # prepare each module / submodule for quantization @@ -607,7 +623,7 @@ def _enable_module_qat(self, module: Module): activation_dtype=_activation_dtype, weight_dtype=_weight_dtype, activation_bits=self.activation_bits, - weight_bits=self.weight_bits + weight_bits=self.weight_bits, ) # set quantization config (asymmetric activations, symmetric weights) @@ -631,7 +647,7 @@ def _enable_module_qat(self, module: Module): # if exclude_batchnorm flag is used, add batch norm layers to list of # modules to exclude qconfig if self._exclude_batchnorm: - to_exclude.extend(['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']) + to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]) self._exclude_module_types = to_exclude if self._exclude_module_types: From cb9d75e202af7dddbc57105c632b5ef9af74dbc9 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Wed, 30 Mar 2022 17:53:05 -0400 Subject: [PATCH 201/218] Removed unused method --- .../sparsification/quantization/helpers.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 
c2e21d30a16..6c30789fbb7 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -41,7 +41,6 @@ "add_quant_dequant", "remove_activation_qat_by_layer_name", "get_qat_qconfig", - "get_updated_qconfig_kwargs", "fix_observer_quant_range", "freeze_bn_stats", "fuse_module_conv_bn_relus", @@ -975,36 +974,6 @@ def prepare_embeddings_qat( _prepare_qat_embedding(submodule, qconfig) -def get_updated_qconfig_kwargs(qconfig_kwargs, bits, mode): - qconfig_kwargs = qconfig_kwargs.copy() if qconfig_kwargs else {} - - # update qconfig_kwargs for bits - if bits and (qconfig_kwargs.get("quant_min") or qconfig_kwargs.get("quant_max")): - raise ValueError( - "Cannot override quant_max and quant_min when number of bits is set" - ) - - if bits: - if mode == "symmetric": - quant_min = -(2 ** (bits - 1)) - quant_max = 2 ** (bits - 1) - 1 - dtype = torch.qint8 - else: - quant_min = 0 - quant_max = 2 ** bits - 1 - dtype = torch.quint8 - - qconfig_kwargs.update( - dict( - quant_min=quant_min, - quant_max=quant_max, - dtype=dtype, - ) - ) - - return qconfig_kwargs - - def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): embedding.weight_fake_quant = qconfig.weight() From 60eed5cadf28728eefe2a51f85381e782911d5c8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 31 Mar 2022 10:05:55 -0400 Subject: [PATCH 202/218] Removed testing files --- sandbox/quantization_recipe.yaml | 7 ------- sandbox/quantization_test.py | 23 ----------------------- 2 files changed, 30 deletions(-) delete mode 100644 sandbox/quantization_recipe.yaml delete mode 100644 sandbox/quantization_test.py diff --git a/sandbox/quantization_recipe.yaml b/sandbox/quantization_recipe.yaml deleted file mode 100644 index 411dd6f025a..00000000000 --- a/sandbox/quantization_recipe.yaml +++ /dev/null @@ -1,7 +0,0 @@ -quantization_modifiers: - - !QuantizationModifier - start_epoch: -1.0 - model_fuse_fn_name: no_fuse - submodules: - - input - - sections diff --git a/sandbox/quantization_test.py b/sandbox/quantization_test.py deleted file mode 100644 index ea6fba5acd5..00000000000 --- a/sandbox/quantization_test.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -from sparseml.pytorch.utils import ModuleExporter -from sparseml.pytorch.models import ModelRegistry -from sparseml.pytorch.optim import ScheduledModifierManager - -model = ModelRegistry.create( - key='resnet50', - pretrained=False, - pretrained_dataset="imagenet", - num_classes=1000 -) - - -ScheduledModifierManager.from_yaml("quantization_recipe.yaml").apply(model, epoch=float("inf")) - -print(model) - -exporter = ModuleExporter(model, ".") -exporter.export_onnx( - torch.randn(1, 3, 224, 224), - "quantized_test.onnx", - convert_qat=False, -) \ No newline at end of file From c1c5e14ab17acd930250273819676d5de6cf47f8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Thu, 31 Mar 2022 10:12:29 -0400 Subject: [PATCH 203/218] Style and quality fixes. 
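As an illustrative aside (not part of this patch series), the get_updated_qconfig_kwargs helper removed above derived integer quantization limits from a bit width, and the same arithmetic lives on in compute_range. A self-contained sketch of that arithmetic, with a hypothetical function name:

    import torch

    def quant_range(bits: int = 8, signed: bool = True):
        # signed/symmetric ranges pair with torch.qint8, unsigned with torch.quint8
        if signed:
            return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1, torch.qint8
        return 0, 2 ** bits - 1, torch.quint8

    assert quant_range(8, signed=True) == (-128, 127, torch.qint8)
    assert quant_range(8, signed=False) == (0, 255, torch.quint8)
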
--- src/sparseml/pytorch/sparsification/quantization/helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 6c30789fbb7..fa92e5fab46 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -817,9 +817,9 @@ def fix_observer_quant_range(module: Module): # continue if fake_quantize quant range not set, or observer quant range is set observer = fake_quantize.activation_post_process if ( - fake_quantize.quant_min is None - or fake_quantize.quant_max is None - or (observer.quant_min is not None or observer.quant_max is not None) + fake_quantize.quant_min is None + or fake_quantize.quant_max is None + or (observer.quant_min is not None or observer.quant_max is not None) ): continue observer.quant_min = fake_quantize.quant_min From c0148729c04f4480ffd1eb4a7f670f6ec7a6f874 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 21:01:33 -0400 Subject: [PATCH 204/218] Changed call to get_qat_qconfig to not specify symmetry and data type arguments for default case. --- .../quantization/modifier_quantization.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 30e1aefbe15..210e57df76b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -588,27 +588,25 @@ def _enable_module_qat(self, module: Module): # quantization. # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 + qconfig = get_qat_qconfig( + symmetric_activations=True, + symmetric_weights=True, + reduce_range=self._reduce_range, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=torch.qint8, + weight_dtype=torch.qint8, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, + ) else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - - qconfig = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, - ) + qconfig = get_qat_qconfig( + reduce_range=self._reduce_range, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, + ) # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: From a32490215c914e50743dffe1d2e7b1b1be6126f8 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 21:05:59 -0400 Subject: [PATCH 205/218] Changed default number of activation and weight bits from None to 8. 
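As an illustrative aside (not part of this patch series), here is an assumed usage sketch of get_qat_qconfig with the keyword signature it has at this point in the series (patch 200, before the dataclass refactor in patch 208). It reproduces the TensorRT branch added in patch 204: fully symmetric int8 for both activations and weights. The import path is inferred from the file location and is an assumption.

    import torch
    # import path inferred from
    # src/sparseml/pytorch/sparsification/quantization/helpers.py
    from sparseml.pytorch.sparsification.quantization.helpers import get_qat_qconfig

    trt_qconfig = get_qat_qconfig(
        symmetric_activations=True,
        symmetric_weights=True,
        reduce_range=False,
        activation_dtype=torch.qint8,
        weight_dtype=torch.qint8,
        activation_bits=8,
        weight_bits=8,
    )
    # the modifier assigns the resulting qconfig to each module before calling
    # torch.quantization.propagate_qconfig_ on it
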
--- .../sparsification/quantization/modifier_quantization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 210e57df76b..5fd66c58359 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -118,9 +118,9 @@ class QuantizationModifier(ScheduledModifier): :param quantize_conv_activations: if True, FakeQuantize ops will be run for output activations of convolutional layers. Default is False. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + activations. Default is 8. :param weight_bits: Number of bits to use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + weights. Default is 8. :param num_calibration_steps: Number of steps to run post training calibration for. When None, the entire calibration_dataloader is used :param exclude_batchnorm: If True, do not propagate quantization qconfigs to @@ -148,8 +148,8 @@ def __init__( reduce_range: bool = False, quantize_linear_activations: bool = False, quantize_conv_activations: bool = False, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + activation_bits: int = 8, + weight_bits: int = 8, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, exclude_module_types: Optional[List[str]] = None, From 38fb750859c3692b8c24f10d5af06b66fb790225 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 21:12:48 -0400 Subject: [PATCH 206/218] Revert "Changed default number of activation and weight bits from None to 8." This reverts commit 95e966ed929fa3512331a73667d5ba2ac3d594b1. --- .../sparsification/quantization/modifier_quantization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5fd66c58359..210e57df76b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -118,9 +118,9 @@ class QuantizationModifier(ScheduledModifier): :param quantize_conv_activations: if True, FakeQuantize ops will be run for output activations of convolutional layers. Default is False. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is 8. + activations. Default is None, which will quantize activations to 8 bits. :param weight_bits: Number of bits to use for setting quant min/max values for - weights. Default is 8. + weights. Default is None, which will quantize weights to 8 bits. :param num_calibration_steps: Number of steps to run post training calibration for. 
When None, the entire calibration_dataloader is used :param exclude_batchnorm: If True, do not propagate quantization qconfigs to @@ -148,8 +148,8 @@ def __init__( reduce_range: bool = False, quantize_linear_activations: bool = False, quantize_conv_activations: bool = False, - activation_bits: int = 8, - weight_bits: int = 8, + activation_bits: Optional[int] = None, + weight_bits: Optional[int] = None, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, exclude_module_types: Optional[List[str]] = None, From 28c9422fdfdcad9aa18b8a1808a15be256788183 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 21:12:59 -0400 Subject: [PATCH 207/218] Revert "Changed call to get_qat_qconfig to not specify symmetry and data type arguments for default case." This reverts commit a675813f77a6f10f49e18be4d3fdd6a40b6b5a9c. --- .../quantization/modifier_quantization.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 210e57df76b..30e1aefbe15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -588,25 +588,27 @@ def _enable_module_qat(self, module: Module): # quantization. # otherwise, use the default values set in get_qat_qconfig if self.tensorrt: - qconfig = get_qat_qconfig( - symmetric_activations=True, - symmetric_weights=True, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=torch.qint8, - weight_dtype=torch.qint8, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, - ) + _symmetric_activations = True + _activation_dtype = torch.qint8 + _symmetric_weights = True + _weight_dtype = torch.qint8 else: - qconfig = get_qat_qconfig( - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, - ) + _symmetric_activations = None + _activation_dtype = None + _symmetric_weights = None + _weight_dtype = None + + qconfig = get_qat_qconfig( + symmetric_activations=_symmetric_activations, + symmetric_weights=_symmetric_weights, + reduce_range=self._reduce_range, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + activation_dtype=_activation_dtype, + weight_dtype=_weight_dtype, + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, + ) # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: From d687ea8a9937133666d9a5c2cd5111c124b7e61b Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 21:58:14 -0400 Subject: [PATCH 208/218] Lumped qconfig properties into a dataclass. 
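As an illustrative aside (not part of this patch series), here is an assumed usage sketch of the QConfigProperties dataclass introduced in this patch, which bundles the settings that were previously passed as separate keyword arguments. It assumes the activation dtype field is meant to read activation_dtype: torch.dtype = torch.quint8 (the hunk below shows activation_dtype: torch.quint8, apparently missing the type annotation and default), and the import path is inferred from the file location.

    import torch
    # import path inferred; QConfigProperties may not be re-exported in __all__
    from sparseml.pytorch.sparsification.quantization.helpers import (
        QConfigProperties,
        get_qat_qconfig,
    )

    qprops = QConfigProperties(
        reduce_range=False,
        activation_dtype=torch.quint8,  # assumes field default torch.dtype = torch.quint8
        weight_dtype=torch.qint8,
        activation_bits=8,
        weight_bits=8,
    )
    # the symmetric_* setters only take effect while the private fields are None
    qprops.symmetric_activations = True
    qprops.symmetric_weights = True
    # after this patch, get_qat_qconfig consumes the properties object directly,
    # as shown later in _load_qconfigs
    qconfig = get_qat_qconfig(qprops)
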
--- .../sparsification/quantization/helpers.py | 411 +++++------------- 1 file changed, 107 insertions(+), 304 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index fa92e5fab46..c83fa57ad5c 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -31,7 +31,7 @@ torch_quantization = None from sparseml.pytorch.nn import ReLU as ReLU_nm - +from dataclasses import dataclass, field __all__ = [ "QATWrapper", @@ -71,6 +71,69 @@ ) +@dataclass +class QConfigProperties: + """ + Dataclass that stores properties needed to define qconfig objects. + Default values set here. + + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. + """ + _symmetric_activations: Optional[bool] = None + _symmetric_weights: Optional[bool] = None + reduce_range: bool = False + activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) + weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) + activation_dtype: torch.quint8 + weight_dtype: torch.dtype = torch.qint8 + activation_bits: int = 8 + weight_bits: int = 8 + + @property + def symmetric_activations(self) -> bool: + if self._symmetric_activations: + return self._symmetric_activations + else: + return False + + @symmetric_activations.setter + def symmetric_activations(self, value: bool): + if self._symmetric_activations is None: + self._symmetric_activations = value + + @property + def symmetric_weights(self) -> bool: + if self._symmetric_weights: + return self._symmetric_weights + else: + return True + + @symmetric_weights.setter + def symmetric_weights(self, value: bool): + if self._symmetric_weights is None: + self._symmetric_weights = value + + class BNWrapper(Module): """ Wraps BatchNormalization module to expose methods needed to enable @@ -204,51 +267,16 @@ class QATWrapper(Module): QConfig for each output. Instead of a QConfig objects, the string 'asymmetric' or 'symmetric' may be used to use default UINT8 asymmetric and symmetric quantization respectively - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware - Default is False - :param activation_qconfig_kwargs: Additional kwargs for quantization of activations. - Default is {} - :param weight_qconfig_kwargs: Additional kwargs for quantization of weights. 
- Default is {} + :param qproperties: properties used to define QConfig. """ @staticmethod def from_module( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + qproperties: QConfigProperties, ) -> "QATWrapper": """ :param module: torch Module to create a QATWrapper for - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. - Default is torch.quint8. - :param weight_dtype: quantized weights data type. - Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. :return: QATWrapper object created using the given Module as the forward function. 
Will attempt to find any other named parameter of the QATWrapper constructor from the attributes of the given Module @@ -259,68 +287,14 @@ def from_module( else {} ) - qat_wrapper_kwargs["symmetric_activations"] = ( - symmetric_activations - if "symmetric_activations" not in qat_wrapper_kwargs - else symmetric_activations or qat_wrapper_kwargs["symmetric_activations"] - ) - - qat_wrapper_kwargs["symmetric_weights"] = ( - symmetric_weights or False - if "symmetric_weights" not in qat_wrapper_kwargs - else symmetric_weights or qat_wrapper_kwargs["symmetric_weights"] - ) - - qat_wrapper_kwargs["reduce_range"] = ( - reduce_range or False - if "reduce_range" not in qat_wrapper_kwargs - else reduce_range or qat_wrapper_kwargs["reduce_range"] - ) - - qat_wrapper_kwargs["activation_qconfig_kwargs"] = ( - activation_qconfig_kwargs - if "activation_qconfig_kwargs" not in qat_wrapper_kwargs - else activation_qconfig_kwargs - or qat_wrapper_kwargs["activation_qconfig_kwargs"] - ) - - qat_wrapper_kwargs["weight_qconfig_kwargs"] = ( - weight_qconfig_kwargs - if "weight_qconfig_kwargs" not in qat_wrapper_kwargs - else weight_qconfig_kwargs or qat_wrapper_kwargs["weight_qconfig_kwargs"] - ) - - qat_wrapper_kwargs["activation_dtype"] = ( - activation_dtype - if "activation_dtype" not in qat_wrapper_kwargs - else activation_dtype or qat_wrapper_kwargs["activation_dtype"] - ) - - qat_wrapper_kwargs["weight_dtype"] = ( - weight_dtype - if "weight_dtype" not in qat_wrapper_kwargs - else weight_dtype or qat_wrapper_kwargs["weight_dtype"] - ) - - qat_wrapper_kwargs["activation_bits"] = ( - activation_bits - if "activation_bits" not in qat_wrapper_kwargs - else activation_bits or qat_wrapper_kwargs["activation_bits"] - ) - - qat_wrapper_kwargs["weight_bits"] = ( - weight_bits - if "weight_bits" not in qat_wrapper_kwargs - else weight_bits or qat_wrapper_kwargs["weight_bits"] - ) - # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None - return QATWrapper(forward_fn=module, **qat_wrapper_kwargs) + return QATWrapper(forward_fn=module, qproperties=qproperties, **qat_wrapper_kwargs) def __init__( self, forward_fn: Callable[[Any], Any], + qproperties: QConfigProperties, num_inputs: int = 1, kwarg_input_names: List[str] = None, num_outputs: int = 1, @@ -330,15 +304,6 @@ def __init__( output_qconfigs: Union[ "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] ] = "asymmetric", - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, ): super().__init__() @@ -358,43 +323,18 @@ def __init__( num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) self.forward_fn = forward_fn - self._symmetric_activations = symmetric_activations - self._symmetric_weights = symmetric_weights - self._reduce_range = reduce_range - self._activation_qconfig_kwargs = activation_qconfig_kwargs - self._weight_qconfig_kwargs = weight_qconfig_kwargs - self._activation_dtype = activation_dtype - self._weight_dtype = weight_dtype - self._activation_bits = activation_bits - self._weight_bits = weight_bits self.input_qconfigs = self._load_qconfigs( name="input_qconfigs", expected_len=num_input_quant_stubs, qconfigs=input_qconfigs, - 
symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self._activation_qconfig_kwargs, - weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, + qproperties=qproperties, ) self.output_qconfigs = self._load_qconfigs( name="output_qconfigs", expected_len=num_outputs, qconfigs=output_qconfigs, - symmetric_activations=self._symmetric_activations, - symmetric_weights=self._symmetric_weights, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self._activation_qconfig_kwargs, - weight_qconfig_kwargs=self._weight_qconfig_kwargs, - activation_dtype=self._activation_dtype, - weight_dtype=self._weight_dtype, - activation_bits=self._activation_bits, - weight_bits=self._weight_bits, + qproperties=qproperties, ) self.input_quant_stubs = torch.nn.ModuleList( @@ -478,15 +418,7 @@ def _load_qconfigs( name: str, expected_len: int, qconfigs: Union["QConfig", str, List["QConfig"]], # noqa: F821 - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + qproperties: QConfigProperties, ): if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)): raise ValueError( @@ -516,22 +448,9 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - if symmetric_activations is None: - _symmetric_activations = qconfig == "symmetric" - else: - _symmetric_activations = symmetric_activations - - qconfigs[idx] = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=symmetric_weights, - reduce_range=reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, - ) + qproperties.symmetric_activations = qconfig == "symmetric" + + qconfigs[idx] = get_qat_qconfig(qproperties) return qconfigs @@ -558,15 +477,7 @@ def configure_module_bn_wrappers(module: Module): def configure_module_qat_wrappers( module: Module, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Dict[str, Any] = {}, - weight_qconfig_kwargs: Dict[str, Any] = {}, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + qproperties: QConfigProperties, ): """ if any submodule of the given module has the attribute wrap_qat == True, @@ -575,25 +486,8 @@ def configure_module_qat_wrappers( under an attributed named `qat_wrapper_kwargs` :param module: module to potentially wrap the submodules of - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. 
- :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + :param qproperties: properties used to define QConfig. + """ # wrap any children of the given module as a QATWrapper if required for child_name, child_module in module.named_children(): if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: @@ -602,41 +496,24 @@ def configure_module_qat_wrappers( child_name, QATWrapper.from_module( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, - reduce_range=reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, + qproperties=qproperties, ), ) # recurse on child module configure_module_qat_wrappers( module=child_module, - symmetric_activations=symmetric_activations, - symmetric_weights=symmetric_weights, - reduce_range=reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, + qproperties=qproperties, ) -def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): +def compute_range(dtype: torch.dtype, bits: int): """ compute quantization limits depending on data type and number of bits - :param dtype: data type. If None dtype is set to torch.quint8. - :param bits: number of bits. If None is set to 8. - :return: minimum limit, maximum limit, data type + :param dtype: data type. + :param bits: number of bits. 
+ :return: minimum limit, maximum limit """ - dtype = dtype if dtype else torch.quint8 bits = bits if bits else 8 if dtype == torch.qint8: quant_min = -(2 ** (bits - 1)) @@ -645,7 +522,7 @@ def compute_range(dtype: Optional[torch.dtype], bits: Optional[int]): quant_min = 0 quant_max = (2 ** bits) - 1 - return quant_min, quant_max, dtype + return quant_min, quant_max def configure_module_default_qconfigs(module: Module): @@ -707,62 +584,27 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ ) -def get_qat_qconfig( - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, +def get_qat_qconfig(qproperties: QConfigProperties ) -> "torch.quantization.QConfig": """ - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8. - :return: A QAT fake quantization config for symmetric weight quantization and - asymmetric activation quantization. The difference between this and - torch.quantization.default_qat_qconfig is that the activation observer - will not have reduce_range enabled. + :param qproperties: properties used to define QConfig. 
""" activation_observer = get_observer( - symmetric_activations, - activation_dtype, - activation_bits, - reduce_range, - activation_qconfig_kwargs, + qproperties.symmetric_activations, + qproperties.activation_dtype, + qproperties.activation_bits, + qproperties.reduce_range, + qproperties.activation_qconfig_kwargs, ) - if symmetric_weights is None: - _symmetric_weights = True - else: - _symmetric_weights = symmetric_weights - - if weight_dtype is None: - _weight_dtype = torch.qint8 - else: - _weight_dtype = weight_dtype weight_observer = get_observer( - _symmetric_weights, _weight_dtype, weight_bits, False, weight_qconfig_kwargs + qproperties.symmetric_weights, + qproperties.weight_dtype, + qproperties.weight_bits, + False, + qproperties.weight_qconfig_kwargs ) + return torch_quantization.QConfig( activation=activation_observer, weight=weight_observer, @@ -770,14 +612,14 @@ def get_qat_qconfig( def get_observer( - symmetric: Optional[bool], - dtype: Optional[torch.dtype], - bits: Optional[int], + symmetric: bool, + dtype: torch.dtype, + bits: int, reduce_range: bool, qconfig_kwargs: Dict[str, Any], ): qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine - quant_min, quant_max, dtype = compute_range(dtype, bits) + quant_min, quant_max = compute_range(dtype, bits) observer_kwargs = dict( observer=torch_quantization.MovingAverageMinMaxObserver, quant_min=quant_min, @@ -915,16 +757,8 @@ def fuse_module_conv_bn_relus( def prepare_embeddings_qat( module: Module, - qconfig: "torch.quantization.QConfig" = None, - symmetric_activations: Optional[bool] = None, - symmetric_weights: Optional[bool] = None, - reduce_range: bool = False, - activation_qconfig_kwargs: Optional[Dict[str, Any]] = None, - weight_qconfig_kwargs: Optional[Dict[str, Any]] = None, - activation_dtype: Optional[torch.dtype] = None, - weight_dtype: Optional[torch.dtype] = None, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + qproperties: QConfigProperties, + qconfig: Optional["torch.quantization.QConfig"] = None, ): """ adds a fake quantize call to the weights of any Embedding modules in the given @@ -933,42 +767,11 @@ def prepare_embeddings_qat( :param module: module to run QAT for the embeddings of :param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8 asymmetric range - :param symmetric_activations: if True, activations will have a symmetric - quantization range with a pre-specified zero point - (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). - Default is False. - :param symmetric_weights: if True, weights will have a symmetric - quantization range with a pre-specified zero point - (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). - Default is True. - :param reduce_range: if True, the quantization range will be reduced by one bit. - This may prevent overflow issues with model execution on certain hardware. - Default is False. - :param activation_qconfig_kwargs: Additional kwargs for quantization of - activations. - :param weight_qconfig_kwargs: Additional kwargs for quantization of - weights. - :param activation_dtype: quantized activation data type. Default is torch.quint8. - :param weight_dtype: quantized weights data type. Default is torch.qint8. - :param activation_bits: number of bits for activations. Default is 8. - :param weight_bits: number of bits for weights. Default is 8.""" + :param qproperties: properties used to define QConfig. 
+ """ if qconfig is None: - if symmetric_weights is None: - _symmetric_weights = False - else: - _symmetric_weights = symmetric_weights - - qconfig = get_qat_qconfig( - symmetric_activations=symmetric_activations, - symmetric_weights=_symmetric_weights, - reduce_range=reduce_range, - activation_qconfig_kwargs=activation_qconfig_kwargs, - weight_qconfig_kwargs=weight_qconfig_kwargs, - activation_dtype=activation_dtype, - weight_dtype=weight_dtype, - activation_bits=activation_bits, - weight_bits=weight_bits, - ) + qproperties.symmetric_weights = False + qconfig = get_qat_qconfig(qproperties) for submodule in module.modules(): if type(submodule) is Embedding: _prepare_qat_embedding(submodule, qconfig) From 94509ce16a384a9ff142fd11c35966684a25dc6b Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 22:36:36 -0400 Subject: [PATCH 209/218] Lumped qconfig properties into a dataclass. --- .../sparsification/quantization/helpers.py | 39 +++++++-- .../quantization/modifier_quantization.py | 79 ++++--------------- 2 files changed, 47 insertions(+), 71 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index c83fa57ad5c..555d6936d15 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -30,9 +30,11 @@ nni = None torch_quantization = None -from sparseml.pytorch.nn import ReLU as ReLU_nm from dataclasses import dataclass, field +from sparseml.pytorch.nn import ReLU as ReLU_nm + + __all__ = [ "QATWrapper", "configure_module_bn_wrappers", @@ -45,6 +47,25 @@ "freeze_bn_stats", "fuse_module_conv_bn_relus", "prepare_embeddings_qat", + "QConfigProperties", + "LINEAR_ACTIVATION_NAMES", + "CONV_ACTIVATION_NAMES", +] + +LINEAR_ACTIVATION_NAMES = ["Linear", "LinearReLU"] +CONV_ACTIVATION_NAMES = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", ] _QUANTIZABLE_MODULE_TYPES = ( @@ -99,15 +120,16 @@ class QConfigProperties: :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. 
""" + _symmetric_activations: Optional[bool] = None _symmetric_weights: Optional[bool] = None reduce_range: bool = False - activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) - weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) - activation_dtype: torch.quint8 + activation_dtype: torch.dtype = torch.quint8 weight_dtype: torch.dtype = torch.qint8 activation_bits: int = 8 weight_bits: int = 8 + activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) + weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) @property def symmetric_activations(self) -> bool: @@ -289,7 +311,9 @@ def from_module( # Remove qconfig from wrapped layer to avoid duplicate quantization module.qconfig = None - return QATWrapper(forward_fn=module, qproperties=qproperties, **qat_wrapper_kwargs) + return QATWrapper( + forward_fn=module, qproperties=qproperties, **qat_wrapper_kwargs + ) def __init__( self, @@ -584,8 +608,7 @@ def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[ ) -def get_qat_qconfig(qproperties: QConfigProperties -) -> "torch.quantization.QConfig": +def get_qat_qconfig(qproperties: QConfigProperties) -> "torch.quantization.QConfig": """ :param qproperties: properties used to define QConfig. """ @@ -602,7 +625,7 @@ def get_qat_qconfig(qproperties: QConfigProperties qproperties.weight_dtype, qproperties.weight_bits, False, - qproperties.weight_qconfig_kwargs + qproperties.weight_qconfig_kwargs, ) return torch_quantization.QConfig( diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 30e1aefbe15..a9190382b30 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -47,6 +47,9 @@ from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.sparsification.quantization.helpers import ( + CONV_ACTIVATION_NAMES, + LINEAR_ACTIVATION_NAMES, + QConfigProperties, add_quant_dequant, configure_module_bn_wrappers, configure_module_default_qconfigs, @@ -555,25 +558,11 @@ def _enable_module_qat(self, module: Module): # build list of layer types that should not quantize output activations to_remove_layer_name = [] if not self._quantize_linear_activations: - to_remove_layer_name.extend(["Linear", "LinearReLU"]) + to_remove_layer_name.extend(LINEAR_ACTIVATION_NAMES) if not self._quantize_conv_activations: - to_remove_layer_name.extend( - [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvBn1d", - "ConvBn2d", - "ConvBn3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - ] - ) + to_remove_layer_name.extend(CONV_ACTIVATION_NAMES) + if len(to_remove_layer_name) == 0: to_remove_layer_name = None @@ -586,45 +575,20 @@ def _enable_module_qat(self, module: Module): # set qconfig. # if tensorrt flag is used, set activation and weights to symmetric # quantization. 
- # otherwise, use the default values set in get_qat_qconfig + # otherwise, use the default values set in QConfigProperties + qproperties = QConfigProperties() if self.tensorrt: - _symmetric_activations = True - _activation_dtype = torch.qint8 - _symmetric_weights = True - _weight_dtype = torch.qint8 - else: - _symmetric_activations = None - _activation_dtype = None - _symmetric_weights = None - _weight_dtype = None - - qconfig = get_qat_qconfig( - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, - ) + qproperties.symmetric_activations = True + qproperties.activation_dtype = torch.qint8 + qproperties.symmetric_weights = True + qproperties.weight_dtype = torch.qint8 + + qconfig = get_qat_qconfig(qproperties) # prepare each module / submodule for quantization for name, quant_module in self._modules_to_quantize: # wrap any modules with wrap_qat set to True as QATWrapper(s) - configure_module_qat_wrappers( - quant_module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, - ) + configure_module_qat_wrappers(quant_module, qproperties) # set quantization config (asymmetric activations, symmetric weights) quant_module.qconfig = qconfig @@ -656,18 +620,7 @@ def _enable_module_qat(self, module: Module): # set modules with proper qconfigs to QAT mode torch_quantization.prepare_qat(module, inplace=True) if self._quantize_embeddings: - prepare_embeddings_qat( - module, - symmetric_activations=_symmetric_activations, - symmetric_weights=_symmetric_weights, - reduce_range=self._reduce_range, - activation_qconfig_kwargs=self.activation_qconfig_kwargs, - weight_qconfig_kwargs=self.weight_qconfig_kwargs, - activation_dtype=_activation_dtype, - weight_dtype=_weight_dtype, - activation_bits=self.activation_bits, - weight_bits=self.weight_bits, - ) + prepare_embeddings_qat(module, qproperties) # propagate custom quant min/max range from FakeQuantize to Observer objects fix_observer_quant_range(module) From 9786433135defa0f481553a6eab978a09382530f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 4 Apr 2022 22:36:54 -0400 Subject: [PATCH 210/218] Lumped qconfig properties into a dataclass. 
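
The dataclass keeps every knob that was previously threaded through the helper
signatures (symmetric_activations, symmetric_weights, reduce_range, dtypes,
bit widths, and the extra qconfig kwargs) in one place. A minimal usage sketch,
illustrative only, assuming the QConfigProperties API from this refactor is
importable from sparseml.pytorch.sparsification.quantization:

    import torch

    from sparseml.pytorch.sparsification.quantization import (
        QConfigProperties,
        get_qat_qconfig,
    )

    # defaults: asymmetric uint8 activations, symmetric int8 weights, 8 bits each
    qproperties = QConfigProperties()
    default_qconfig = get_qat_qconfig(qproperties)

    # a TensorRT-style override now mutates one object instead of nine kwargs
    qproperties.symmetric_activations = True
    qproperties.activation_dtype = torch.qint8
    qproperties.symmetric_weights = True
    qproperties.weight_dtype = torch.qint8
    trt_qconfig = get_qat_qconfig(qproperties)
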
--- sandbox/quantization_recipe.yaml | 7 +++++++ sandbox/quantization_test.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 sandbox/quantization_recipe.yaml create mode 100644 sandbox/quantization_test.py diff --git a/sandbox/quantization_recipe.yaml b/sandbox/quantization_recipe.yaml new file mode 100644 index 00000000000..411dd6f025a --- /dev/null +++ b/sandbox/quantization_recipe.yaml @@ -0,0 +1,7 @@ +quantization_modifiers: + - !QuantizationModifier + start_epoch: -1.0 + model_fuse_fn_name: no_fuse + submodules: + - input + - sections diff --git a/sandbox/quantization_test.py b/sandbox/quantization_test.py new file mode 100644 index 00000000000..ea6fba5acd5 --- /dev/null +++ b/sandbox/quantization_test.py @@ -0,0 +1,23 @@ +import torch +from sparseml.pytorch.utils import ModuleExporter +from sparseml.pytorch.models import ModelRegistry +from sparseml.pytorch.optim import ScheduledModifierManager + +model = ModelRegistry.create( + key='resnet50', + pretrained=False, + pretrained_dataset="imagenet", + num_classes=1000 +) + + +ScheduledModifierManager.from_yaml("quantization_recipe.yaml").apply(model, epoch=float("inf")) + +print(model) + +exporter = ModuleExporter(model, ".") +exporter.export_onnx( + torch.randn(1, 3, 224, 224), + "quantized_test.onnx", + convert_qat=False, +) \ No newline at end of file From 6922bf98a19ed94de18c82ac4ecc0710ed25aed2 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Tue, 5 Apr 2022 13:03:10 -0400 Subject: [PATCH 211/218] Resetting conv and linear activation flags to True. --- .../quantization/modifier_quantization.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index a9190382b30..8a5d4d60795 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -117,13 +117,13 @@ class QuantizationModifier(ScheduledModifier): This may prevent overflow issues with model execution on certain hardware Default is False :param quantize_linear_activations: if True, FakeQuantize ops will be run - for output activations of fully connected layers. Default is False. + for output activations of fully connected layers. Default is True. :param quantize_conv_activations: if True, FakeQuantize ops will be run - for output activations of convolutional layers. Default is False. + for output activations of convolutional layers. Default is True. :param activation_bits: Number of bits to use for setting quant min/max values for - activations. Default is None, which will quantize activations to 8 bits. + activations. Default 8. :param weight_bits: Number of bits to use for setting quant min/max values for - weights. Default is None, which will quantize weights to 8 bits. + weights. Default is 8. :param num_calibration_steps: Number of steps to run post training calibration for. 
When None, the entire calibration_dataloader is used :param exclude_batchnorm: If True, do not propagate quantization qconfigs to @@ -149,10 +149,10 @@ def __init__( model_fuse_fn_kwargs: Dict[str, Any] = None, quantize_embeddings: bool = True, reduce_range: bool = False, - quantize_linear_activations: bool = False, - quantize_conv_activations: bool = False, - activation_bits: Optional[int] = None, - weight_bits: Optional[int] = None, + quantize_linear_activations: bool = True, + quantize_conv_activations: bool = True, + activation_bits: int = 8, + weight_bits: int = 8, num_calibration_steps: Optional[int] = None, exclude_batchnorm: bool = True, exclude_module_types: Optional[List[str]] = None, @@ -340,7 +340,10 @@ def quantize_linear_activations(self) -> bool: :return: if True, FakeQuantize ops will be run for output activations of fully connected layers """ - return self._quantize_linear_activations + if self._tensorrt: + return False + else: + return self._quantize_linear_activations @ModifierProp() def quantize_conv_activations(self) -> bool: @@ -348,7 +351,10 @@ def quantize_conv_activations(self) -> bool: :return: if True, FakeQuantize ops will be run for output activations of convolutional layers """ - return self._quantize_conv_activations + if self._tensorrt: + return False + else: + return self._quantize_conv_activations @ModifierProp() def exclude_module_types(self) -> Union[List[str], None]: @@ -557,10 +563,10 @@ def _enable_module_qat(self, module: Module): # build list of layer types that should not quantize output activations to_remove_layer_name = [] - if not self._quantize_linear_activations: + if not self.quantize_linear_activations: to_remove_layer_name.extend(LINEAR_ACTIVATION_NAMES) - if not self._quantize_conv_activations: + if not self.quantize_conv_activations: to_remove_layer_name.extend(CONV_ACTIVATION_NAMES) if len(to_remove_layer_name) == 0: @@ -576,7 +582,13 @@ def _enable_module_qat(self, module: Module): # if tensorrt flag is used, set activation and weights to symmetric # quantization. # otherwise, use the default values set in QConfigProperties - qproperties = QConfigProperties() + qproperties = QConfigProperties( + activation_bits=self.activation_bits, + weight_bits=self.weight_bits, + activation_qconfig_kwargs=self.activation_qconfig_kwargs, + weight_qconfig_kwargs=self.weight_qconfig_kwargs, + reduce_range=self.reduce_range, + ) if self.tensorrt: qproperties.symmetric_activations = True qproperties.activation_dtype = torch.qint8 From e7e47be2d303485c22eeb6fda2026e67f88a342c Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 8 Apr 2022 11:35:12 -0400 Subject: [PATCH 212/218] Renamed class BNWrapper as _BNWrapper. 
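
The wrapper becomes an implementation detail of configure_module_bn_wrappers
rather than part of the public surface. A minimal sketch of how the wrapped
BatchNorm is still reached, illustrative only and assuming the helper wraps
BatchNorm children exactly as in the hunks of this patch:

    import torch

    from sparseml.pytorch.sparsification.quantization.helpers import (
        configure_module_bn_wrappers,
    )

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3),
        torch.nn.BatchNorm2d(8),
    )
    configure_module_bn_wrappers(model)

    # the private wrapper forwards the usual BN attributes and adds the
    # freeze/unfreeze hooks used during QAT
    model[1].freeze_bn_stats()   # stop updating running mean/var
    model[1].update_bn_stats()   # resume updating running statistics
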
--- .../sparsification/quantization/helpers.py | 216 +++++++++--------- 1 file changed, 108 insertions(+), 108 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 555d6936d15..b5c54ab309c 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -156,113 +156,6 @@ def symmetric_weights(self, value: bool): self._symmetric_weights = value -class BNWrapper(Module): - """ - Wraps BatchNormalization module to expose methods needed to enable - freezing/unfreezing of statistics - - :param module: BatchNormalization module to be wrapped - """ - - def __init__(self, module: Module): - super().__init__() - self.bn = module - self.freeze_bn = False - - @property - def running_mean(self): - return self.bn.running_mean - - @running_mean.setter - def running_mean(self, value): - self.bn.running_mean = value - - @property - def running_var(self): - return self.bn.running_var - - @running_var.setter - def running_var(self, value): - self.bn.running_var = value - - @property - def weight(self): - return self.bn.weight - - @weight.setter - def weight(self, value): - self.bn.weight = value - - @property - def bias(self): - return self.bn.bias - - @bias.setter - def bias(self, value): - self.bn.bias = value - - @property - def gamma(self): - return self.bn.gamma - - @gamma.setter - def gamma(self, value): - self.bn.gamma = value - - @property - def beta(self): - return self.bn.beta - - @beta.setter - def beta(self, value): - self.bn.beta = value - - @property - def num_batches_tracked(self): - return self.bn.num_batches_tracked - - @num_batches_tracked.setter - def num_batches_tracked(self, value): - self.bn.num_batches_tracked = value - - @property - def eps(self): - return self.bn.eps - - @eps.setter - def eps(self, value): - self.bn.eps = value - - @property - def momentum(self): - return self.bn.momentum - - @momentum.setter - def momentum(self, value): - self.bn.momentum = value - - def forward(self, x): - return self.bn(x) - - def freeze_bn_stats(self): - self.freeze_bn = True - self.bn.training = False - return self - - def reset_running_stats(self): - self.bn.reset_running_stats() - - def train(self, mode=True): - if not self.freeze_bn: - self.bn.train(mode) - return self - - def update_bn_stats(self): - self.freeze_bn = False - self.bn.training = True - return self - - class QATWrapper(Module): """ Wraps inputs and outputs of a Module or function with QuantStubs for @@ -494,7 +387,7 @@ def configure_module_bn_wrappers(module: Module): torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, ]: - setattr(module, child_name, BNWrapper(child_module)) + setattr(module, child_name, _BNWrapper(child_module)) # recurse on child module configure_module_bn_wrappers(child_module) @@ -834,3 +727,110 @@ def _wrap_bn_sub_class(bn_subclass, override_forward=True): batch_norm.forward = bn_subclass.forward del bn_subclass return batch_norm + + +class _BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ + + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + 
return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train(mode) + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self From ffd819506378f4e419c6fc03835d1502da7405ee Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 8 Apr 2022 12:43:07 -0400 Subject: [PATCH 213/218] Added logging messages for when tensorrt forces overriding of configs. --- .../quantization/modifier_quantization.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 8a5d4d60795..90365512325 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -244,7 +244,8 @@ def model_fuse_fn_name(self) -> Union[str, None]: If tensorrt flag is True, default is 'no_fuse', otherwise `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. 
""" - if self._tensorrt: + if self.tensorrt: + _LOGGER.info("Overriding model_fuse_fn_name to False because tensorrt flag is True.") fuse_fn = ( self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" ) @@ -340,7 +341,8 @@ def quantize_linear_activations(self) -> bool: :return: if True, FakeQuantize ops will be run for output activations of fully connected layers """ - if self._tensorrt: + if self.tensorrt: + _LOGGER.info("Overriding quantize_linear_activations to False because tensorrt flag is True.") return False else: return self._quantize_linear_activations @@ -351,7 +353,8 @@ def quantize_conv_activations(self) -> bool: :return: if True, FakeQuantize ops will be run for output activations of convolutional layers """ - if self._tensorrt: + if self.tensorrt: + _LOGGER.info("Overriding quantize_conv_activations to False because tensorrt flag is True.") return False else: return self._quantize_conv_activations @@ -416,11 +419,10 @@ def num_calibration_steps(self) -> Optional[int]: return self._num_calibration_steps @ModifierProp() - def tensorrt(self) -> Dict[str, Any]: + def tensorrt(self) -> bool: """ - :return: Dictionary with correct quant_min, quant_max, and dtype values - for activations - + :return: boolean. When set to True overrides quantization configs + to be compatible with TensorRT. """ return self._tensorrt @@ -590,6 +592,7 @@ def _enable_module_qat(self, module: Module): reduce_range=self.reduce_range, ) if self.tensorrt: + _LOGGER.info("Overriding quantization scheme to symmetric int8 for both weights and activations because tensorrt flag is True.") qproperties.symmetric_activations = True qproperties.activation_dtype = torch.qint8 qproperties.symmetric_weights = True From 808e4293cb89c0b2b81608c1bb3a7aea8fdad2f0 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 8 Apr 2022 12:45:08 -0400 Subject: [PATCH 214/218] Style and quality fixes. --- .../quantization/modifier_quantization.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 90365512325..5f588fa27aa 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -245,7 +245,9 @@ def model_fuse_fn_name(self) -> Union[str, None]: `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. """ if self.tensorrt: - _LOGGER.info("Overriding model_fuse_fn_name to False because tensorrt flag is True.") + _LOGGER.info( + "Overriding model_fuse_fn_name to False because tensorrt flag is True." + ) fuse_fn = ( self._model_fuse_fn_name if self._model_fuse_fn_name else "no_fuse" ) @@ -342,7 +344,10 @@ def quantize_linear_activations(self) -> bool: of fully connected layers """ if self.tensorrt: - _LOGGER.info("Overriding quantize_linear_activations to False because tensorrt flag is True.") + _LOGGER.info( + "Overriding quantize_linear_activations to False " + "because tensorrt flag is True." + ) return False else: return self._quantize_linear_activations @@ -354,7 +359,10 @@ def quantize_conv_activations(self) -> bool: of convolutional layers """ if self.tensorrt: - _LOGGER.info("Overriding quantize_conv_activations to False because tensorrt flag is True.") + _LOGGER.info( + "Overriding quantize_conv_activations to False " + "because tensorrt flag is True." 
+ ) return False else: return self._quantize_conv_activations @@ -592,7 +600,10 @@ def _enable_module_qat(self, module: Module): reduce_range=self.reduce_range, ) if self.tensorrt: - _LOGGER.info("Overriding quantization scheme to symmetric int8 for both weights and activations because tensorrt flag is True.") + _LOGGER.info( + "Overriding quantization scheme to symmetric int8 " + "for both weights and activations because tensorrt flag is True." + ) qproperties.symmetric_activations = True qproperties.activation_dtype = torch.qint8 qproperties.symmetric_weights = True From 1a579099c8a33ea4dd3ba5421aa84dd5bd9efe41 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Fri, 8 Apr 2022 12:51:42 -0400 Subject: [PATCH 215/218] ConvInteger quantization conversion for quant refactor (#644) * ConvInteger quantization conversion for quant refactor * [quantization-refactor] mark/propagate conv export mode (#672) * batch norm fold with existing bias param bug fix --- .../quantization/modifier_quantization.py | 6 + .../quantization/quantize_qat_export.py | 251 +++++++++++++++--- src/sparseml/pytorch/utils/exporter.py | 10 +- 3 files changed, 227 insertions(+), 40 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 5f588fa27aa..648c93e35a4 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -654,6 +654,12 @@ def _enable_module_qat(self, module: Module): self._qat_enabled = True self._calibrate_if_possible(module) + # mark export mode for module Conv layers + module.export_with_qlinearconv = self._quantize_conv_activations + if hasattr(module, "module"): + # for DP/DDP unwrapping + module.module.export_with_qlinearconv = self._quantize_conv_activations + def _calibrate_if_possible(self, module): if self.num_calibration_steps == 0 and self._calibration_dataloader: warnings.warn( diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py index 1ed660bdd3a..24a2bdff87b 100644 --- a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py +++ b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py @@ -63,7 +63,13 @@ _QUANTIZE_OP_NAMES = ["QuantizeLinear", "DequantizeLinear"] -_QLINEAR_OP_NAMES = ["QLinearConv", "QLinearMatMul", "QLinearAdd"] +KEEP_QUANT_INPUT_OPS = [ + "Add", + "ConvInteger", + "MatMulInteger, " "QLinearConv", + "QLinearMatMul", + "QLinearAdd", +] def get_quantization_params( @@ -160,7 +166,10 @@ def _fold_conv_bn_bias(model: ModelProto, conv_node: NodeProto, bn_node: NodePro folded_bias = folded_bias.astype(numpy.float32) bias_name = conv_node.name + ".bias" - conv_node.input.append(bias_name) + if len(conv_node.input) > 2: + conv_node.input[2] = bias_name + else: + conv_node.input.append(bias_name) update_model_param(model, bias_name, folded_bias) # forward conv output to bn children @@ -656,9 +665,9 @@ def _convert_quantizable_matmul(model: ModelProto): ) -def _add_quantized_matmul_add_ops( +def _add_quantized_conv_matmul_add_ops( model: ModelProto, - matmul_node: NodeProto, + node: NodeProto, input_quantize_node: NodeProto, weight_quantize_node: NodeProto, input_quantize_params: QuantizationParams, @@ -670,10 +679,11 @@ def _add_quantized_matmul_add_ops( output_quantize_node: Optional[NodeProto] = None, 
output_dequantize_node: Optional[NodeProto] = None, ): - # helper function for conversion of qat parameterized gemms/matmuls to - # matmul integer add blocks. Adds new quantized ops to graph, does not + # helper function for conversion of qat parameterized gemms, matmuls, + # or convs to conv/matmul integer add blocks. + # Adds new quantized ops to graph, does not # perform any checks or deletions (should be called by the operator main - # conversion function + # conversion function) # quantize weight quantized_weight = _quantize_array( @@ -683,31 +693,46 @@ def _add_quantized_matmul_add_ops( ) if transpose_weight: quantized_weight = quantized_weight.transpose() - quantized_weight_name = "{}.weight_quantized".format(matmul_node.name) + quantized_weight_name = "{}.weight_quantized".format(node.name) quantized_weight_initializer = numpy_helper.from_array( quantized_weight, name=quantized_weight_name ) model.graph.initializer.append(quantized_weight_initializer) - # MatMulInteger - # get matmulinteger inputs and outputs - matmul_integer_inputs = [ - input_quantize_node.input[0], # A matrix (replaces previous dequant node) - quantized_weight_name, # B matrix (quantized weight) - input_quantize_node.input[2], # a_zero_point - weight_quantize_node.input[2], # b_zero_point + # MatMulInteger/ConvInteger + # get inputs and outputs + integer_op_inputs = [ + input_quantize_node.input[0], # input matrix (replaces previous dequant node) + quantized_weight_name, # quantized weight + input_quantize_node.input[2], # input zero point + weight_quantize_node.input[2], # weight zero point ] - matmul_integer_output = "{}_quant".format(matmul_node.output[0]) - matmul_integer_name = "{}_quant".format(matmul_node.name) - - # create qmatmul node and add it to graph - matmul_integer_node = onnx.helper.make_node( - "MatMulInteger", - matmul_integer_inputs, - [matmul_integer_output], - matmul_integer_name, - ) - model.graph.node.append(matmul_integer_node) + integer_op_output = "{}_quant".format(node.output[0]) + integer_op_name = "{}_quant".format(node.name) + + # create MatMulInteger/ConvInteger node and add it to graph + if node.op_type == "Conv": + # get conv attributes as kwargs + conv_kwargs = {} + for attribute in node.attribute: + conv_kwargs.update(_attribute_to_kwarg(attribute)) + + # create ConvInteger node and add it to graph + integer_op_node = onnx.helper.make_node( + "ConvInteger", + integer_op_inputs, + [integer_op_output], + integer_op_name, + **conv_kwargs, + ) + else: + integer_op_node = onnx.helper.make_node( + "MatMulInteger", + integer_op_inputs, + [integer_op_output], + integer_op_name, + ) + model.graph.node.append(integer_op_node) # Add bias + zero point correction # quantize bias @@ -717,6 +742,9 @@ def _add_quantized_matmul_add_ops( quantized_bias = _quantize_array( bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32 ) + if node.op_type == "Conv" and len(quantized_bias.shape) == 1: + # reshape for bias add broadcasting + quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1) quantized_bias_name = "{}.bias_quantized".format(bias_add_name) quantized_bias_initializer = numpy_helper.from_array( @@ -739,11 +767,11 @@ def _add_quantized_matmul_add_ops( # get INT32 Add inputs and outputs quant_add_inputs = [ - matmul_integer_output, # MatMul integer outputs (INT32) + integer_op_output, # MatMul/Conv integer outputs (INT32) quantized_bias_name, # Quantized bias (INT32) ] - quant_add_name = "{}_bias_add_quant".format(matmul_node.name) + quant_add_name = 
"{}_bias_add_quant".format(node.name) quant_add_output = ( output_quantize_node.output[0] if output_quantize_node @@ -875,9 +903,9 @@ def _convert_quantizable_gemm_no_activations(model: ModelProto): _LOGGER.debug(f"Matched quantizable Gemm weight and bias: {gemm_node.name}") # Conversion - _add_quantized_matmul_add_ops( + _add_quantized_conv_matmul_add_ops( model=model, - matmul_node=gemm_node, + node=gemm_node, input_quantize_node=input_quantize_node, weight_quantize_node=weight_quantize_node, input_quantize_params=input_quantize_params, @@ -1021,9 +1049,9 @@ def _convert_quantizable_matmul_and_add(model: ModelProto): _LOGGER.debug(f"Matched quantizable MatMul weight and bias: {matmul_node.name}") # Conversion - _add_quantized_matmul_add_ops( + _add_quantized_conv_matmul_add_ops( model=model, - matmul_node=matmul_node, + node=matmul_node, input_quantize_node=input_quantize_node, weight_quantize_node=weight_quantize_node, input_quantize_params=input_quantize_params, @@ -1071,6 +1099,130 @@ def _convert_quantizable_matmul_and_add(model: ModelProto): graph.delete_unused_initializers() +def _convert_quantizable_conv_integer(model: ModelProto): + """ + A pass for converting a Conv op with kernel whose activations + are not necessarily quantized into a ConvInteger followed by + a bias add and cast to FP32 + + | Starting with: + | INPUT QuantizeLinear (with constant kernel) + | | | + | QuantizeLinear DequantizeLinear + | | | + | DequantizeLinear | + | | | + | Conv (with bias) + | | + | OUTPUT + | We end up converting to: + | INPUT + | | + | QuantizeLinear + | | + | ConvInteger (with constant uint8 kernel) + | | + | Add (constant bias + zero point correction) + | | + | Cast (INT32 -> FP32) + | | + | Mul (Rescale from bias scale) + | | + | OUTPUT + """ + + conversion_count = 0 + conv_nodes = [n for n in model.graph.node if n.op_type in ["Conv"]] + orig_conv_weight_name_to_node_ids = defaultdict(list) + for conv_node in conv_nodes: + if len(conv_node.input) != 3: + # this function currently only converts Conv nodes with bias param + # (i.e. 
from folded batch norm value) + continue + + graph = ONNXGraph(model) + + ############# + # Matching + ############# + weight_dequantize_node = graph.get_node_single_parent(conv_node, 1) + if ( + not weight_dequantize_node + or weight_dequantize_node.op_type != "DequantizeLinear" + ): + continue + weight_quantize_node = graph.get_node_single_parent(weight_dequantize_node, 0) + if not weight_quantize_node or weight_quantize_node.op_type != "QuantizeLinear": + continue + + input_quantize_node = graph.get_node_single_parent(conv_node, 0) + if ( + not input_quantize_node + or input_quantize_node.op_type not in _QUANTIZE_OP_NAMES + ): + continue + + input_quantize_params = get_quantization_params( + model, input_quantize_node, include_target=False + ) + weight_quantize_params = get_quantization_params( + model, weight_quantize_node, include_target=True + ) + if weight_quantize_params.target is None: + # weight initializer not included + continue + if input_quantize_node.op_type != "DequantizeLinear": + continue + + bias_initializer = graph.get_init_by_name(conv_node.input[2]) + if bias_initializer is None: + _LOGGER.debug(f"Unable to find bias initializer: {conv_node.input[2]}") + continue + + _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}") + + # Conversion + _add_quantized_conv_matmul_add_ops( + model=model, + node=conv_node, + input_quantize_node=input_quantize_node, + weight_quantize_node=weight_quantize_node, + input_quantize_params=input_quantize_params, + weight_quantize_params=weight_quantize_params, + bias_initializer=bias_initializer, + bias_add_name="{}_bias_add".format(conv_node.name), + target_output=conv_node.output[0], + transpose_weight=False, + ) + orig_conv_weight_name_to_node_ids[input_quantize_node.input[0]].append( + "{}_quant".format(conv_node.output[0]) + ) + + # Cleanup + # delete folded quantization ops + delete_quant_node(model, weight_dequantize_node, keep_params=False) + delete_quant_node(model, weight_quantize_node, keep_params=True) + + # only delete input node if the conv is the only child + current_graph = ONNXGraph(model) + if len(current_graph.get_node_children(input_quantize_node)) == 1: + delete_quant_node(model, input_quantize_node, keep_params=True) + + # delete original Conv node + remove_node_and_params_from_graph(model, conv_node, keep_params=None) + + conversion_count += 1 + + if conv_nodes: + _LOGGER.info( + f"Converted {conversion_count} quantizable Conv ops with weight and bias " + "to ConvInteger and Add" + ) + _reduce_qconv_shared_weights(model, orig_conv_weight_name_to_node_ids) + graph = ONNXGraph(model) + graph.delete_unused_initializers() + + def _reduce_qconv_shared_weights( model: ModelProto, orig_qconv_weight_name_to_node_ids: Dict[str, List[NodeProto]] ): @@ -1080,10 +1232,17 @@ def _reduce_qconv_shared_weights( continue qconv_nodes = [graph.get_node_by_output_id(id) for id in node_ids] - if any(node.op_type != "QLinearConv" for node in qconv_nodes): + if any( + node.op_type not in ["QLinearConv", "ConvInteger"] for node in qconv_nodes + ): continue - weights = [graph.get_init_by_name(node.input[3]) for node in qconv_nodes] + weights = [ + graph.get_init_by_name( + node.input[3 if node.op_type == "QLinearConv" else 1] + ) + for node in qconv_nodes + ] if any(weight is None for weight in weights): continue @@ -1095,14 +1254,15 @@ def _reduce_qconv_shared_weights( weights[0], name=f"{weight_name}_quantized" ) for node in qconv_nodes: - node.input[3] = shared_weight.name + target_dim = 3 if node.op_type == 
"QLinearConv" else 1 + node.input[target_dim] = shared_weight.name model.graph.initializer.append(shared_weight) graph.update() graph.delete_unused_initializers() -def _convert_quantizable_ops(model: ModelProto): +def _convert_quantizable_ops(model: ModelProto, convert_qlinearconv: bool): quantizable_nodes = [n for n in model.graph.node if n.op_type in ["Conv", "Gemm"]] orig_qconv_weight_name_to_node_ids = defaultdict(list) for quantizable_node in quantizable_nodes: @@ -1123,7 +1283,7 @@ def _convert_quantizable_ops(model: ModelProto): if not output_quant or output_quant.op_type not in _QUANTIZE_OP_NAMES: continue - if quantizable_node.op_type == "Conv": + if convert_qlinearconv and quantizable_node.op_type == "Conv": weight_name = weight_quant.input[0] qconv_node = _convert_quantizable_conv( model, @@ -1299,7 +1459,10 @@ def _cleanup_unused_quants(model: ModelProto): ) dequant_children = graph.get_node_children(dequant_node) for child in dequant_children: - if isinstance(child, onnx.NodeProto) and child.op_type in _QLINEAR_OP_NAMES: + # check if any dequant children depend on having QDQ inputs + if isinstance(child, onnx.NodeProto) and ( + child.op_type in KEEP_QUANT_INPUT_OPS + ): removable = False if not removable: continue @@ -1323,11 +1486,16 @@ def quantize_torch_qat_export( model: Union[ModelProto, str], output_file_path: Union[str, None] = None, inplace: bool = True, + use_qlinearconv: bool = False, ) -> ModelProto: """ :param model: The model to convert, or a file path to it :param output_file_path: File path to save the converted model to :param inplace: If true, does conversion of model in place. Default is true + :param use_qlinearconv: Set True to use legacy QLinearConv format instead + of ConvInteger. QLinearConv requires output activations be quantized + in the quantization recipe. (This was the default behavior prior to + sparseml 0.12). Default is False :return: Converts a model exported from a torch QAT session from a QAT graph with fake quantize ops surrounding operations to a quantized graph with quantized operations. 
All quantized Convs and FC inputs and outputs be surrounded by @@ -1345,7 +1513,12 @@ def quantize_torch_qat_export( _delete_repeated_qat_blocks(model) _convert_quantizable_matmul(model) _convert_quantizable_matmul_and_add(model) - _convert_quantizable_ops(model) + + # only convert to either ConvInteger or QLinearConv (legacy) + if not use_qlinearconv: + _convert_quantizable_conv_integer(model) + _convert_quantizable_ops(model, convert_qlinearconv=use_qlinearconv) + _convert_quantizable_gemm_no_activations(model) _quantize_qat_embedding(model) quantize_resnet_identity_add_inputs(model) diff --git a/src/sparseml/pytorch/utils/exporter.py b/src/sparseml/pytorch/utils/exporter.py index b00ba190cda..a987287aba7 100644 --- a/src/sparseml/pytorch/utils/exporter.py +++ b/src/sparseml/pytorch/utils/exporter.py @@ -498,7 +498,15 @@ def export_onnx( quantize_torch_qat_export, ) - quantize_torch_qat_export(model=file_path, output_file_path=file_path) + use_qlinearconv = hasattr(module, "export_with_qlinearconv") and ( + module.export_with_qlinearconv + ) + + quantize_torch_qat_export( + model=file_path, + output_file_path=file_path, + use_qlinearconv=use_qlinearconv, + ) if skip_input_quantize: try: From 194fb16be776237401f3f2d9b4a422de202acf53 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Fri, 8 Apr 2022 14:15:15 -0400 Subject: [PATCH 216/218] Quantization Refactor Tests (#685) --- .../quantization/modifier_quantization.py | 8 ++ .../quantization/test_helpers.py | 11 +-- .../test_modifier_quantization.py | 75 +++++++++++++------ 3 files changed, 68 insertions(+), 26 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 648c93e35a4..5185d7cd464 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -375,6 +375,14 @@ def exclude_module_types(self) -> Union[List[str], None]: """ return self._exclude_module_types + @ModifierProp() + def exclude_batchnorm(self) -> bool: + """ + :return: if True, do not propagate quantization qconfigs to + batch-normalization modules + """ + return self._exclude_batchnorm + @ModifierProp() def activation_bits(self) -> Optional[int]: """ diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py index 9d8fa1c734e..14c8556fb2f 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_helpers.py @@ -20,6 +20,7 @@ from sparseml.pytorch.models import mobilenet, resnet50 from sparseml.pytorch.sparsification.quantization import ( QATWrapper, + QConfigProperties, add_quant_dequant, configure_module_default_qconfigs, configure_module_qat_wrappers, @@ -87,7 +88,7 @@ def test_configure_module_qat_wrappers(): assert not _module_has_quant_stubs(module) - configure_module_qat_wrappers(module) + configure_module_qat_wrappers(module, QConfigProperties()) qat_matmul = module.module.module @@ -122,7 +123,7 @@ def _assert_module_quant_stub_configs_exist(module, should_exist): reason="torch quantization not available", ) def test_configure_module_default_qconfigs(): - module = QATWrapper.from_module(_QATMatMul()) + module = QATWrapper.from_module(_QATMatMul(), QConfigProperties()) _assert_module_quant_stub_configs_exist(module, False) 
configure_module_default_qconfigs(module) @@ -142,7 +143,7 @@ def test_configure_module_default_qconfigs(): reason="torch quantization not available", ) def test_configure_module_default_qconfigs_no_config(): - module = QATWrapper.from_module(_QATMatMul()) + module = QATWrapper.from_module(_QATMatMul(), QConfigProperties()) _assert_module_quant_stub_configs_exist(module, False) module.configure_qconfig = None @@ -191,7 +192,7 @@ def test_add_quant_dequant(model_lambda, num_quantizable_ops): reason="torch quantization not available", ) def test_get_qat_qconfig(): - assert isinstance(get_qat_qconfig(), torch_quantization.QConfig) + assert isinstance(get_qat_qconfig(QConfigProperties()), torch_quantization.QConfig) def _get_fake_conv_relus(num_blocks=1): @@ -273,7 +274,7 @@ def test_prepare_embeddings_qat(): # check that fake quant observer is properly added assert not hasattr(module.module, "weight_fake_quant") - prepare_embeddings_qat(module) + prepare_embeddings_qat(module, QConfigProperties()) assert hasattr(module.module, "weight_fake_quant") # check that the observer is updated on embedding forward pass diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py index a7b2551b75e..c19cf1c42d8 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py @@ -18,7 +18,7 @@ from torch.nn import Conv2d, Identity, Linear from sparseml.pytorch.sparsification import QuantizationModifier -from tests.sparseml.pytorch.helpers import LinearNet, create_optim_sgd +from tests.sparseml.pytorch.helpers import ConvNet, LinearNet, create_optim_sgd from tests.sparseml.pytorch.optim.test_modifier import ScheduledModifierTest @@ -33,6 +33,7 @@ except Exception: torch_quantization = None + QUANTIZATION_MODIFIERS = [ lambda: QuantizationModifier( start_epoch=0.0, @@ -40,6 +41,18 @@ freeze_bn_stats_epoch=3.0, ), lambda: QuantizationModifier(start_epoch=2.0, submodules=["seq"]), + lambda: QuantizationModifier( + start_epoch=0.0, + quantize_linear_activations=False, + quantize_conv_activations=False, + ), + lambda: QuantizationModifier( + start_epoch=0.0, + activation_bits=4, + ), +] + +QUANTIZATION_MODIFIERS_LINEAR = QUANTIZATION_MODIFIERS + [ lambda: QuantizationModifier(start_epoch=2.0, submodules=["seq.fc1"]), lambda: QuantizationModifier( start_epoch=2.0, submodules=["seq.fc1", "seq.block1.fc2"] @@ -49,14 +62,6 @@ submodules=["seq.fc1", "seq.block1.fc2"], reduce_range=True, ), - lambda: QuantizationModifier( - start_epoch=0.0, - quantize_linear_activations=False, - ), - lambda: QuantizationModifier( - start_epoch=0.0, - activation_bits=4, - ), lambda: QuantizationModifier( start_epoch=0.0, exclude_module_types=["Linear"], @@ -90,16 +95,13 @@ def _test_quantizable_module(module, qat_expected, modifier): ) if module.qconfig.activation is not Identity: assert module.qconfig.activation.p.keywords["reduce_range"] == reduce_range - if modifier.activation_bits is not None: - expected_quant_min = 0 - expected_quant_max = (1 << modifier.activation_bits) - 1 + if modifier.activation_bits is not None: + expected_quant_min = 0 + expected_quant_max = (1 << modifier.activation_bits) - 1 + activation_quant_properties = module.qconfig.activation.p.keywords - assert ( - module.qconfig.activation.p.keywords["quant_min"] == expected_quant_min - ) - assert ( - 
module.qconfig.activation.p.keywords["quant_max"] == expected_quant_max - ) + assert activation_quant_properties["quant_min"] == expected_quant_min + assert activation_quant_properties["quant_max"] == expected_quant_max if isinstance(module, Linear): assert isinstance(module.activation_post_process, Identity) == ( @@ -143,8 +145,17 @@ def _test_qat_applied(modifier, model): torch_quantization is None, reason="torch quantization not available", ) -@pytest.mark.parametrize("modifier_lambda", QUANTIZATION_MODIFIERS, scope="function") -@pytest.mark.parametrize("model_lambda", [LinearNet], scope="function") +@pytest.mark.parametrize( + "modifier_lambda,model_lambda", + list( + zip( + QUANTIZATION_MODIFIERS_LINEAR, + [LinearNet] * len(QUANTIZATION_MODIFIERS_LINEAR), + ) + ) + + list(zip(QUANTIZATION_MODIFIERS, [ConvNet] * len(QUANTIZATION_MODIFIERS))), + scope="function", +) @pytest.mark.parametrize("optim_lambda", [create_optim_sgd], scope="function") class TestQuantizationModifierImpl(ScheduledModifierTest): def test_lifecycle( @@ -152,7 +163,7 @@ def test_lifecycle( modifier_lambda, model_lambda, optim_lambda, - test_steps_per_epoch, # noqa: F811 + test_steps_per_epoch, # noqa: F811, ): modifier = modifier_lambda() model = model_lambda() @@ -221,10 +232,13 @@ def test_quantization_modifier_yaml(): quantize_embeddings = False reduce_range = True quantize_linear_activations = False + quantize_conv_activations = False num_calibration_steps = 2 exclude_module_types = ["LayerNorm", "Tanh"] activation_bits = 4 averaging_constant = 0.05 + tensorrt = False + exclude_batchnorm = False activation_qconfig_kwargs = dict( averaging_constant=averaging_constant, ) @@ -238,10 +252,13 @@ def test_quantization_modifier_yaml(): quantize_embeddings: {quantize_embeddings} reduce_range: {reduce_range} quantize_linear_activations: {quantize_linear_activations} + quantize_conv_activations: {quantize_conv_activations} num_calibration_steps: {num_calibration_steps} exclude_module_types: {exclude_module_types} activation_bits: {activation_bits} activation_qconfig_kwargs: {activation_qconfig_kwargs} + tensorrt: {tensorrt} + exclude_batchnorm: {exclude_batchnorm} """ yaml_modifier = QuantizationModifier.load_obj( yaml_str @@ -258,10 +275,13 @@ def test_quantization_modifier_yaml(): quantize_embeddings=quantize_embeddings, reduce_range=reduce_range, quantize_linear_activations=quantize_linear_activations, + quantize_conv_activations=quantize_conv_activations, activation_bits=activation_bits, num_calibration_steps=num_calibration_steps, exclude_module_types=exclude_module_types, activation_qconfig_kwargs=activation_qconfig_kwargs, + tensorrt=tensorrt, + exclude_batchnorm=exclude_batchnorm, ) assert isinstance(yaml_modifier, QuantizationModifier) @@ -305,6 +325,11 @@ def test_quantization_modifier_yaml(): == serialized_modifier.quantize_linear_activations == obj_modifier.quantize_linear_activations ) + assert ( + yaml_modifier.quantize_conv_activations + == serialized_modifier.quantize_conv_activations + == obj_modifier.quantize_conv_activations + ) assert ( yaml_modifier.activation_bits == serialized_modifier.activation_bits @@ -325,3 +350,11 @@ def test_quantization_modifier_yaml(): == serialized_modifier.activation_qconfig_kwargs == obj_modifier.activation_qconfig_kwargs ) + assert ( + yaml_modifier.tensorrt == serialized_modifier.tensorrt == obj_modifier.tensorrt + ) + assert ( + yaml_modifier.exclude_batchnorm + == serialized_modifier.exclude_batchnorm + == obj_modifier.exclude_batchnorm + ) From 
5f74e3127bf9be447821bcebb30718e460643971 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 8 Apr 2022 15:12:36 -0400 Subject: [PATCH 217/218] rebase import fix --- .../sparsification/quantization/test_modifier_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py index c19cf1c42d8..64e1ae9d5ba 100644 --- a/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py +++ b/tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py @@ -19,7 +19,7 @@ from sparseml.pytorch.sparsification import QuantizationModifier from tests.sparseml.pytorch.helpers import ConvNet, LinearNet, create_optim_sgd -from tests.sparseml.pytorch.optim.test_modifier import ScheduledModifierTest +from tests.sparseml.pytorch.sparsification.test_modifier import ScheduledModifierTest from tests.sparseml.pytorch.helpers import ( # noqa isort:skip From 5796f4fde9d62141cf49ce0ff2e0f57c5548b3dd Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 8 Apr 2022 15:30:37 -0400 Subject: [PATCH 218/218] update manager serialization test cases for new quantization params --- tests/sparseml/pytorch/optim/test_manager.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/sparseml/pytorch/optim/test_manager.py b/tests/sparseml/pytorch/optim/test_manager.py index 415b703a34f..abefbbb5f4e 100644 --- a/tests/sparseml/pytorch/optim/test_manager.py +++ b/tests/sparseml/pytorch/optim/test_manager.py @@ -455,12 +455,18 @@ start_epoch: 0.0 - !QuantizationModifier + activation_bits: 8 end_epoch: 52 + exclude_batchnorm: True + model_fuse_fn_name: conv_bn_relus + quantize_conv_activations: True quantize_embeddings: True quantize_linear_activations: True reduce_range: False start_epoch: 50 submodules: ['model.0'] + tensorrt: False + weight_bits: 8 - !SetLearningRateModifier constant_logging: False @@ -484,12 +490,18 @@ update_frequency: -1 - !QuantizationModifier + activation_bits: 8 end_epoch: -1.0 + exclude_batchnorm: True + model_fuse_fn_name: conv_bn_relus + quantize_conv_activations: True quantize_embeddings: True quantize_linear_activations: True reduce_range: False start_epoch: 102 submodules: ['model.0'] + tensorrt: False + weight_bits: 8 - !SetLearningRateModifier constant_logging: False
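# ---------------------------------------------------------------------------
# NOTE (editorial sketch, not part of the patch): the serialized recipes above
# now carry the new QuantizationModifier fields (activation_bits, weight_bits,
# exclude_batchnorm, model_fuse_fn_name, quantize_conv_activations, tensorrt).
# Below is a minimal sketch of constructing the modifier with these options,
# using only keyword arguments exercised elsewhere in this patch series; exact
# defaults and any omitted parameters may differ in the library itself.
from sparseml.pytorch.sparsification import QuantizationModifier

modifier = QuantizationModifier(
    start_epoch=50.0,
    submodules=["model.0"],
    quantize_embeddings=True,
    quantize_linear_activations=True,
    quantize_conv_activations=True,
    exclude_batchnorm=True,
    reduce_range=False,
    activation_bits=8,
    tensorrt=False,
)

# The fake-quantize range implied by activation_bits matches the assertion in
# _test_quantizable_module: an n-bit unsigned range of [0, 2**n - 1].
activation_bits = 8
expected_quant_min = 0
expected_quant_max = (1 << activation_bits) - 1  # 255 for 8 bits
# ---------------------------------------------------------------------------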