From 9304997911c5ae49f73f91222dde14098f61e99b Mon Sep 17 00:00:00 2001
From: Mark Kurtz
Date: Wed, 2 Feb 2022 15:34:39 -0700
Subject: [PATCH 1/6] Update README.md for transformers to note the
 quantization conversion issue (#539)

* Update README.md

* Update integrations/huggingface-transformers/README.md

Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com>

Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com>
---
 integrations/huggingface-transformers/README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/integrations/huggingface-transformers/README.md b/integrations/huggingface-transformers/README.md
index 637d4a3a375..4485107b65a 100644
--- a/integrations/huggingface-transformers/README.md
+++ b/integrations/huggingface-transformers/README.md
@@ -116,3 +116,5 @@ python transformers/examples/pytorch/question-answering/run_qa.py \
 The DeepSparse Engine [accepts ONNX formats](https://docs.neuralmagic.com/sparseml/source/onnx_export.html) and is engineered to significantly speed up inference on CPUs for the sparsified models from this integration.
 
 Examples for loading, benchmarking, and deploying can be found in the [DeepSparse repository here](https://github.com/neuralmagic/deepsparse).
+
+**Note: there is currently a known issue where conversion of the BERT models from PyTorch into ONNX is not preserving the accuracy of the model for some tasks and datasets. If you encounter this issue, try rolling back to the 0.9.0 release. As a resolution is being actively investigated, this note will be removed when the issue has been remediated.**

From b3613539df7bbdf138fadd5fdd8bd186384a0bc6 Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Tue, 8 Feb 2022 13:52:48 -0500
Subject: [PATCH 2/6] Enforce order on input keys to export (#545)

* Enforce order on input keys to export

* Warn if input dropped from onnx export
---
 src/sparseml/transformers/export.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index 2bc63ef6873..f42cd44a668 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -55,6 +55,8 @@
 """
 
 import argparse
+import collections
+import inspect
 import logging
 import math
 import os
@@ -180,6 +182,24 @@ def export_transformer_to_onnx(
     inputs = tokenizer(
         "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
     ).data  # Dict[Tensor]
+
+    # Rearrange inputs' keys to match those defined by the model forward func, which
+    # seems to define how the order of inputs is determined in the exported model
+    forward_args_spec = inspect.getfullargspec(model.__class__.forward)
+    dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
+    inputs = collections.OrderedDict(
+        [
+            (f, inputs[f][0].reshape(1, -1))
+            for f in forward_args_spec.args
+            if f in inputs
+        ]
+    )
+    if dropped:
+        _LOGGER.warning(
+            "The following inputs were not present in the model forward function "
+            f"and therefore dropped from ONNX export: {dropped}"
+        )
+
     inputs_shapes = {
         key: (
             f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
@@ -187,6 +207,7 @@ def export_transformer_to_onnx(
         )
         for key, val in inputs.items()
     }
+
     _LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}")
 
     # run export
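
For readers who want the gist of PATCH 2/6 outside the diff, the sketch below restates the reordering idea on its own: ONNX binds exported inputs positionally, so the tokenizer's sample inputs are reordered to follow the model's `forward` signature, and any key the signature does not accept is dropped. The `model` and `tokenizer` objects here are illustrative placeholders for any Hugging Face model/tokenizer pair; this is not code from the patch itself.

```python
import collections
import inspect


def order_inputs_for_export(model, tokenizer):
    # Padded sample inputs, e.g. input_ids / token_type_ids / attention_mask
    inputs = tokenizer("", return_tensors="pt", padding="max_length").data

    # Keys accepted by the model's forward(), in signature order
    forward_args = inspect.getfullargspec(model.__class__.forward).args
    dropped = [name for name in inputs if name not in forward_args]

    ordered = collections.OrderedDict(
        (name, inputs[name][0].reshape(1, -1))
        for name in forward_args
        if name in inputs
    )
    if dropped:
        print(f"not in forward(), dropped from export: {dropped}")
    return ordered
```
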
From 8c5d39834f71fddbb4a1622bcb593caa66a371d5 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 8 Feb 2022 16:22:22 -0500
Subject: [PATCH 3/6] Restrict mistune version to fix docs build (#547)

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index f3d2ec580f4..db5f911734d 100644
--- a/setup.py
+++ b/setup.py
@@ -70,6 +70,7 @@
     "flake8==3.9.2",
     "isort==5.8.0",
     "m2r2~=0.2.7",
+    "mistune==0.8.4",
     "myst-parser~=0.14.0",
     "rinohtype~=0.4.2",
     "sphinx~=3.5.0",

From 037aff752505fb5dbd11840743dc702826316182 Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Wed, 9 Feb 2022 11:30:52 -0500
Subject: [PATCH 4/6] quantization fixes for transformers flows (#548)

* quantization fixes for transformers flows

* match on class name instead

* quality
---
 .../pytorch/optim/modifier_quantization.py    | 27 ++++++++++++++++
 .../utils/quantization/quantize_qat_export.py | 28 +++++++++++-----
 .../optim/test_modifier_quantization.py       | 32 ++++++++++++-------
 3 files changed, 68 insertions(+), 19 deletions(-)

diff --git a/src/sparseml/pytorch/optim/modifier_quantization.py b/src/sparseml/pytorch/optim/modifier_quantization.py
index ed69286beb6..464df8bb92d 100644
--- a/src/sparseml/pytorch/optim/modifier_quantization.py
+++ b/src/sparseml/pytorch/optim/modifier_quantization.py
@@ -100,6 +100,8 @@ class QuantizationModifier(ScheduledModifier):
         transformer based models such as BERT where the quantized MatMul outputs
         are kept at 32 bits of precision and fake quantizing the outputs harm
         training recovery. Default is True
+    :param exclude_module_types: optional list of module class names
+        to not propagate quantization configs to. Default is None
     """

     def __init__(
@@ -114,6 +116,7 @@ def __init__(
         quantize_embeddings: bool = True,
         reduce_range: bool = False,
         quantize_linear_activations: bool = True,
+        exclude_module_types: Union[List[str], None] = None,
     ):
         if torch_quantization is None or torch_intrinsic is None:
             raise RuntimeError(
@@ -138,6 +141,7 @@ def __init__(
         self._quantize_embeddings = quantize_embeddings
         self._reduce_range = reduce_range
         self._quantize_linear_activations = quantize_linear_activations
+        self._exclude_module_types = exclude_module_types
 
         self._modules_to_quantize = None
         self._qat_enabled = False
@@ -278,6 +282,14 @@ def quantize_linear_activations(self) -> bool:
         """
         return self._quantize_linear_activations
 
+    @ModifierProp()
+    def exclude_module_types(self) -> Union[List[str], None]:
+        """
+        :return: optional list of module class names to not propagate
+            quantization configs to. Default is None
+        """
+        return self._exclude_module_types
+
     def initialize(
         self,
         module: Module,
@@ -423,10 +435,15 @@ def _enable_module_qat(self, module: Module):
             if not self._quantize_linear_activations:
                 remove_activation_qat_by_layer_name(quant_module, ["Linear"])
 
+        # remove qconfigs for module types in exclude_module_types
+        if self._exclude_module_types:
+            self._strip_excluded_module_qconfigs(module)
+
         # set modules with proper qconfigs to QAT mode
         torch_quantization.prepare_qat(module, inplace=True)
         if self._quantize_embeddings:
             prepare_embeddings_qat(module, reduce_range=self._reduce_range)
+
         self._qat_enabled = True
 
     def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
@@ -443,6 +460,16 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
             and not self._bn_stats_frozen
         )
 
+    def _strip_excluded_module_qconfigs(self, module: Module):
+        if not self._exclude_module_types:
+            return
+        excluded_classes = set(self._exclude_module_types)
+        for submodule in module.modules():
+            if submodule.__class__.__name__ in excluded_classes and hasattr(
+                submodule, "qconfig"
+            ):
+                submodule.qconfig = None
+
     def _validate_params(self):
         if (
             self._disable_quantization_observer_epoch is not None
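
As context for how `exclude_module_types` takes effect, the sketch below shows the same mechanism that `_strip_excluded_module_qconfigs` relies on: clearing a submodule's `qconfig` before `prepare_qat` leaves that submodule out of QAT preparation. The tiny model and the `LayerNorm` choice are illustrative assumptions, not code from the patch.

```python
import torch
import torch.quantization as tq

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 4),
)
model.train()
model.qconfig = tq.get_default_qat_qconfig("fbgemm")

# Mimic exclude_module_types=["LayerNorm"]: drop the qconfig before prepare_qat
for submodule in model.modules():
    if submodule.__class__.__name__ == "LayerNorm":
        submodule.qconfig = None

tq.prepare_qat(model, inplace=True)
# The Linear layers are swapped for QAT variants with fake-quant observers;
# the excluded LayerNorm keeps its original float implementation.
```
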
diff --git a/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py b/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
index d5c548fc711..1ed660bdd3a 100644
--- a/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
+++ b/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
@@ -264,7 +264,11 @@ def _delete_repeated_qat_blocks(model: ModelProto):
             nodes_to_delete.append(dequant_node_1)
 
     for n in nodes_to_delete:
-        delete_quant_node(model, n)
+        delete_quant_node(model, n, keep_params=True)
+
+    # cleanup graph
+    graph.update()
+    graph.delete_unused_initializers()
 
 
 def _attribute_to_kwarg(attribute: onnx.AttributeProto):
@@ -1214,12 +1218,14 @@ def _quantize_qat_embedding(model: ModelProto):
             qdq_output = False
 
         if qdq_output:
+            # forward gather output to dequant input
+            output_dequant_node.input[0] = gather_node.output[0]
+            output_dequant_node.input[1] = input_quant_node.input[1]
+            output_dequant_node.input[2] = input_quant_node.input[2]
             # delete unnecessary quantize and dequantize ops
-            delete_quant_node(model, input_quant_node, keep_params=False)
+            delete_quant_node(model, input_quant_node, keep_params=True)
             delete_quant_node(model, input_dequant_node, keep_params=False)
             delete_quant_node(model, output_quant_node, keep_params=False)
-            # forward gather output to dequant input
-            output_dequant_node.input[0] = gather_node.output[0]
         else:
             # use input dequant to dequantize output
@@ -1265,7 +1271,10 @@ def _remove_duplicate_quantize_ops(model: ModelProto):
                 _replace_input_id_model(
                     model, remove_node.output[0], keep_node.output[0]
                 )
-                remove_node_and_params_from_graph(model, remove_node)
+                delete_quant_node(model, remove_node, keep_params=True)
+    # cleanup graph
+    graph.update()
+    graph.delete_unused_initializers()
 
 
 def _cleanup_unused_quants(model: ModelProto):
@@ -1296,15 +1305,18 @@ def _cleanup_unused_quants(model: ModelProto):
             continue
 
         # Forward QuantizeLinear input to DequantizeLinear output
-        for child in dequant_children:
-            _replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
+        _replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
 
         # Remove QuantizeLinear->DequantizeLinear block
        nodes_to_delete.append(quant_node)
        nodes_to_delete.append(dequant_node)
 
     for n in nodes_to_delete:
-        delete_quant_node(model, n)
+        delete_quant_node(model, n, keep_params=True)
+
+    # update graph
+    graph.update()
+    graph.delete_unused_initializers()
 
 
 def quantize_torch_qat_export(
diff --git a/tests/sparseml/pytorch/optim/test_modifier_quantization.py b/tests/sparseml/pytorch/optim/test_modifier_quantization.py
index 9e99fa464ac..d8cf99057db 100644
--- a/tests/sparseml/pytorch/optim/test_modifier_quantization.py
+++ b/tests/sparseml/pytorch/optim/test_modifier_quantization.py
@@ -54,6 +54,10 @@
         start_epoch=0.0,
         quantize_linear_activations=False,
     ),
+    lambda: QuantizationModifier(
+        start_epoch=0.0,
+        exclude_module_types=["Linear"],
+    ),
 ]
 
 
@@ -67,9 +71,13 @@ def _is_quantiable_module(module):
     return isinstance(module, (Conv2d, Linear))
 
 
-def _test_quantizable_module(
-    module, qat_expected, reduce_range, quantize_linear_activations
-):
+def _test_quantizable_module(module, qat_expected, modifier):
+    reduce_range = modifier.reduce_range
+    quantize_linear_activations = modifier.quantize_linear_activations
+
+    excluded_types = modifier.exclude_module_types or []
+    qat_expected = qat_expected and module.__class__.__name__ not in excluded_types
+
     if qat_expected:
         assert hasattr(module, "qconfig") and module.qconfig is not None
         assert hasattr(module, "weight_fake_quant") and (
@@ -97,12 +105,7 @@ def _test_qat_applied(modifier, model):
         submodules = [""]
         for module in model.modules():
             if _is_quantiable_module(module):
-                _test_quantizable_module(
-                    module,
-                    True,
-                    modifier.reduce_range,
-                    modifier.quantize_linear_activations,
-                )
+                _test_quantizable_module(module, True, modifier)
     else:
         assert not hasattr(model, "qconfig") or model.qconfig is None
         submodules = modifier.submodules
@@ -112,8 +115,7 @@
                 _test_quantizable_module(
                     module,
                     _is_valid_submodule(name, submodules),
-                    modifier.reduce_range,
-                    modifier.quantize_linear_activations,
+                    modifier,
                 )
 
 
@@ -207,6 +209,7 @@ def test_quantization_modifier_yaml():
    quantize_embeddings = False
    reduce_range = True
    quantize_linear_activations = False
+    exclude_module_types = ["LayerNorm", "Tanh"]
    yaml_str = f"""
    !QuantizationModifier
        start_epoch: {start_epoch}
@@ -217,6 +220,7 @@
         quantize_embeddings: {quantize_embeddings}
         reduce_range: {reduce_range}
         quantize_linear_activations: {quantize_linear_activations}
+        exclude_module_types: {exclude_module_types}
     """
     yaml_modifier = QuantizationModifier.load_obj(
         yaml_str
@@ -233,6 +237,7 @@
         quantize_embeddings=quantize_embeddings,
         reduce_range=reduce_range,
         quantize_linear_activations=quantize_linear_activations,
+        exclude_module_types=exclude_module_types,
     )
 
     assert isinstance(yaml_modifier, QuantizationModifier)
@@ -276,3 +281,8 @@ def test_quantization_modifier_yaml():
         == serialized_modifier.quantize_linear_activations
         == obj_modifier.quantize_linear_activations
     )
+    assert (
+        yaml_modifier.exclude_module_types
+        == serialized_modifier.exclude_module_types
+        == obj_modifier.exclude_module_types
+    )
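
The export-side changes above repeatedly switch to `delete_quant_node(..., keep_params=True)` followed by a single `graph.delete_unused_initializers()` sweep. A rough standalone equivalent of that final sweep is sketched below; it is an assumption about what the helper does, not sparseml's actual implementation, and the file paths are illustrative. The point of the pattern appears to be that initializers are only dropped once no remaining node references them, so tensors shared by several quantize/dequantize nodes are not removed prematurely.

```python
import onnx


def delete_unused_initializers(model: onnx.ModelProto) -> None:
    # Collect every tensor name still consumed by some node in the graph
    referenced = {name for node in model.graph.node for name in node.input}
    for init in list(model.graph.initializer):
        if init.name not in referenced:
            model.graph.initializer.remove(init)


model = onnx.load("model-qat.onnx")
delete_unused_initializers(model)
onnx.save(model, "model-qat-cleaned.onnx")
```
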
"0.10.1" is_release = False # change to True to set the generated version as a release version From f677fb24f25985cd4b1ad545a8448b8bdfe9e115 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 9 Feb 2022 11:59:04 -0500 Subject: [PATCH 6/6] Revert "Update README.md for transformers to note the quantization conversion issue (#539)" This reverts commit 9304997911c5ae49f73f91222dde14098f61e99b. --- integrations/huggingface-transformers/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/huggingface-transformers/README.md b/integrations/huggingface-transformers/README.md index 4485107b65a..637d4a3a375 100644 --- a/integrations/huggingface-transformers/README.md +++ b/integrations/huggingface-transformers/README.md @@ -116,5 +116,3 @@ python transformers/examples/pytorch/question-answering/run_qa.py \ The DeepSparse Engine [accepts ONNX formats](https://docs.neuralmagic.com/sparseml/source/onnx_export.html) and is engineered to significantly speed up inference on CPUs for the sparsified models from this integration. Examples for loading, benchmarking, and deploying can be found in the [DeepSparse repository here](https://github.com/neuralmagic/deepsparse). - -**Note: there is currently a known issue where conversion of the BERT models from PyTorch into ONNX is not preserving the accuracy of the model for some tasks and datasets. If you encounter this issue, try rolling back to the 0.9.0 release. As a resolution is being actively investigated, this note will be removed when the issue has been remediated.**