diff --git a/setup.py b/setup.py
index f3d2ec580f4..db5f911734d 100644
--- a/setup.py
+++ b/setup.py
@@ -70,6 +70,7 @@
     "flake8==3.9.2",
     "isort==5.8.0",
     "m2r2~=0.2.7",
+    "mistune==0.8.4",
     "myst-parser~=0.14.0",
     "rinohtype~=0.4.2",
     "sphinx~=3.5.0",
diff --git a/src/sparseml/pytorch/optim/modifier_quantization.py b/src/sparseml/pytorch/optim/modifier_quantization.py
index ed69286beb6..464df8bb92d 100644
--- a/src/sparseml/pytorch/optim/modifier_quantization.py
+++ b/src/sparseml/pytorch/optim/modifier_quantization.py
@@ -100,6 +100,8 @@ class QuantizationModifier(ScheduledModifier):
         transformer based models such as BERT where the quantized MatMul outputs
         are kept at 32 bits of precision and fake quantizing the outputs harm
         training recovery. Default is True
+    :param exclude_module_types: optional list of module class names
+        to not propagate quantization configs to. Default is None
     """
 
     def __init__(
@@ -114,6 +116,7 @@ def __init__(
         quantize_embeddings: bool = True,
         reduce_range: bool = False,
         quantize_linear_activations: bool = True,
+        exclude_module_types: Union[List[str], None] = None,
     ):
         if torch_quantization is None or torch_intrinsic is None:
             raise RuntimeError(
@@ -138,6 +141,7 @@ def __init__(
         self._quantize_embeddings = quantize_embeddings
         self._reduce_range = reduce_range
         self._quantize_linear_activations = quantize_linear_activations
+        self._exclude_module_types = exclude_module_types
 
         self._modules_to_quantize = None
         self._qat_enabled = False
@@ -278,6 +282,14 @@ def quantize_linear_activations(self) -> bool:
         """
         return self._quantize_linear_activations
 
+    @ModifierProp()
+    def exclude_module_types(self) -> Union[List[str], None]:
+        """
+        :return: optional list of module class names to not propagate
+            quantization configs to. Default is None
+        """
+        return self._exclude_module_types
+
     def initialize(
         self,
         module: Module,
@@ -423,10 +435,15 @@ def _enable_module_qat(self, module: Module):
             if not self._quantize_linear_activations:
                 remove_activation_qat_by_layer_name(quant_module, ["Linear"])
 
+        # remove qconfigs for module types in exclude_module_types
+        if self._exclude_module_types:
+            self._strip_excluded_module_qconfigs(module)
+
         # set modules with proper qconfigs to QAT mode
         torch_quantization.prepare_qat(module, inplace=True)
         if self._quantize_embeddings:
             prepare_embeddings_qat(module, reduce_range=self._reduce_range)
+
         self._qat_enabled = True
 
     def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
@@ -443,6 +460,16 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
             and not self._bn_stats_frozen
         )
 
+    def _strip_excluded_module_qconfigs(self, module: Module):
+        if not self._exclude_module_types:
+            return
+        excluded_classes = set(self._exclude_module_types)
+        for submodule in module.modules():
+            if submodule.__class__.__name__ in excluded_classes and hasattr(
+                submodule, "qconfig"
+            ):
+                submodule.qconfig = None
+
     def _validate_params(self):
         if (
             self._disable_quantization_observer_epoch is not None
diff --git a/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py b/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
index d5c548fc711..1ed660bdd3a 100644
--- a/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
+++ b/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
@@ -264,7 +264,11 @@ def _delete_repeated_qat_blocks(model: ModelProto):
             nodes_to_delete.append(dequant_node_1)
 
     for n in nodes_to_delete:
-        delete_quant_node(model, n)
+        delete_quant_node(model, n, keep_params=True)
+
+    # cleanup graph
+    graph.update()
+    graph.delete_unused_initializers()
 
 
 def _attribute_to_kwarg(attribute: onnx.AttributeProto):
@@ -1214,12 +1218,14 @@ def _quantize_qat_embedding(model: ModelProto):
             qdq_output = False
 
         if qdq_output:
+            # forward gather output to dequant input
+            output_dequant_node.input[0] = gather_node.output[0]
+            output_dequant_node.input[1] = input_quant_node.input[1]
+            output_dequant_node.input[2] = input_quant_node.input[2]
             # delete unnecessary quantize and dequantize ops
-            delete_quant_node(model, input_quant_node, keep_params=False)
+            delete_quant_node(model, input_quant_node, keep_params=True)
             delete_quant_node(model, input_dequant_node, keep_params=False)
             delete_quant_node(model, output_quant_node, keep_params=False)
-            # forward gather output to dequant input
-            output_dequant_node.input[0] = gather_node.output[0]
 
         else:
             # use input dequant to dequantize output
@@ -1265,7 +1271,10 @@ def _remove_duplicate_quantize_ops(model: ModelProto):
                 _replace_input_id_model(
                     model, remove_node.output[0], keep_node.output[0]
                 )
-                remove_node_and_params_from_graph(model, remove_node)
+                delete_quant_node(model, remove_node, keep_params=True)
+    # cleanup graph
+    graph.update()
+    graph.delete_unused_initializers()
 
 
 def _cleanup_unused_quants(model: ModelProto):
@@ -1296,15 +1305,18 @@
             continue
 
         # Forward QuantizeLinear input to DequantizeLinear output
-        for child in dequant_children:
-            _replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
+        _replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
 
         # Remove QuantizeLinear->DequantizeLinear block
         nodes_to_delete.append(quant_node)
         nodes_to_delete.append(dequant_node)
 
     for n in nodes_to_delete:
-        delete_quant_node(model, n)
+        delete_quant_node(model, n, keep_params=True)
+
+    # update graph
+    graph.update()
+    graph.delete_unused_initializers()
 
 
 def quantize_torch_qat_export(
diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index 2bc63ef6873..f42cd44a668 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -55,6 +55,8 @@
 """
 
 import argparse
+import collections
+import inspect
 import logging
 import math
 import os
@@ -180,6 +182,24 @@ def export_transformer_to_onnx(
     inputs = tokenizer(
         "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
     ).data  # Dict[Tensor]
+
+    # Rearrange inputs' keys to match those defined by model forward func, which
+    # seem to define how the order of inputs is determined in the exported model
+    forward_args_spec = inspect.getfullargspec(model.__class__.forward)
+    dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
+    inputs = collections.OrderedDict(
+        [
+            (f, inputs[f][0].reshape(1, -1))
+            for f in forward_args_spec.args
+            if f in inputs
+        ]
+    )
+    if dropped:
+        _LOGGER.warning(
+            "The following inputs were not present in the model forward function "
+            f"and therefore dropped from ONNX export: {dropped}"
+        )
+
     inputs_shapes = {
         key: (
             f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
@@ -187,6 +207,7 @@
         )
         for key, val in inputs.items()
     }
+
     _LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}")
 
     # run export
diff --git a/src/sparseml/version.py b/src/sparseml/version.py
index 78ecba7701e..8706a6f9544 100644
--- a/src/sparseml/version.py
+++ b/src/sparseml/version.py
@@ -19,7 +19,7 @@
 from datetime import date
 
 
-version_base = "0.10.0"
+version_base = "0.10.1"
 is_release = False  # change to True to set the generated version as a release version
 
 
diff --git a/tests/sparseml/pytorch/optim/test_modifier_quantization.py b/tests/sparseml/pytorch/optim/test_modifier_quantization.py
index 9e99fa464ac..d8cf99057db 100644
--- a/tests/sparseml/pytorch/optim/test_modifier_quantization.py
+++ b/tests/sparseml/pytorch/optim/test_modifier_quantization.py
@@ -54,6 +54,10 @@
         start_epoch=0.0,
         quantize_linear_activations=False,
     ),
+    lambda: QuantizationModifier(
+        start_epoch=0.0,
+        exclude_module_types=["Linear"],
+    ),
 ]
 
 
@@ -67,9 +71,13 @@ def _is_quantiable_module(module):
     return isinstance(module, (Conv2d, Linear))
 
 
-def _test_quantizable_module(
-    module, qat_expected, reduce_range, quantize_linear_activations
-):
+def _test_quantizable_module(module, qat_expected, modifier):
+    reduce_range = modifier.reduce_range
+    quantize_linear_activations = modifier.quantize_linear_activations
+
+    excluded_types = modifier.exclude_module_types or []
+    qat_expected = qat_expected and module.__class__.__name__ not in excluded_types
+
     if qat_expected:
         assert hasattr(module, "qconfig") and module.qconfig is not None
         assert hasattr(module, "weight_fake_quant") and (
@@ -97,12 +105,7 @@ def _test_qat_applied(modifier, model):
         submodules = [""]
         for module in model.modules():
             if _is_quantiable_module(module):
-                _test_quantizable_module(
-                    module,
-                    True,
-                    modifier.reduce_range,
-                    modifier.quantize_linear_activations,
-                )
+                _test_quantizable_module(module, True, modifier)
     else:
         assert not hasattr(model, "qconfig") or model.qconfig is None
         submodules = modifier.submodules
@@ -112,8 +115,7 @@ def _test_qat_applied(modifier, model):
             _test_quantizable_module(
                 module,
                 _is_valid_submodule(name, submodules),
-                modifier.reduce_range,
-                modifier.quantize_linear_activations,
+                modifier,
             )
 
 
@@ -207,6 +209,7 @@ def test_quantization_modifier_yaml():
     quantize_embeddings = False
     reduce_range = True
     quantize_linear_activations = False
+    exclude_module_types = ["LayerNorm", "Tanh"]
     yaml_str = f"""
     !QuantizationModifier
         start_epoch: {start_epoch}
@@ -217,6 +220,7 @@
         quantize_embeddings: {quantize_embeddings}
         reduce_range: {reduce_range}
         quantize_linear_activations: {quantize_linear_activations}
+        exclude_module_types: {exclude_module_types}
     """
     yaml_modifier = QuantizationModifier.load_obj(
         yaml_str
     )
@@ -233,6 +237,7 @@
         quantize_embeddings=quantize_embeddings,
         reduce_range=reduce_range,
         quantize_linear_activations=quantize_linear_activations,
+        exclude_module_types=exclude_module_types,
     )
 
     assert isinstance(yaml_modifier, QuantizationModifier)
@@ -276,3 +281,8 @@ def test_quantization_modifier_yaml():
         == serialized_modifier.quantize_linear_activations
         == obj_modifier.quantize_linear_activations
     )
+    assert (
+        yaml_modifier.exclude_module_types
+        == serialized_modifier.exclude_module_types
+        == obj_modifier.exclude_module_types
+    )
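
For reference, a minimal usage sketch of the new exclude_module_types option added above; the import path follows src/sparseml/pytorch/optim/modifier_quantization.py, and the excluded class names mirror the values exercised in the updated tests:

    # Sketch: build a QuantizationModifier that skips propagating quantization
    # configs to LayerNorm and Tanh submodules (values mirror the tests above)
    from sparseml.pytorch.optim.modifier_quantization import QuantizationModifier

    modifier = QuantizationModifier(
        start_epoch=0.0,
        exclude_module_types=["LayerNorm", "Tanh"],
    )

The same field can also be set in a recipe, as in the updated YAML test: exclude_module_types: ['LayerNorm', 'Tanh'].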