diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py index 24a2bdff87b..64bc372e7cb 100644 --- a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py +++ b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py @@ -1296,6 +1296,14 @@ def _convert_quantizable_ops(model: ModelProto, convert_qlinearconv: bool): orig_qconv_weight_name_to_node_ids[weight_name].append(qconv_node.output[0]) if quantizable_node.op_type == "Gemm": + output_dequant = graph.get_node_single_child(output_quant) + if output_dequant and output_dequant.op_type in _QUANTIZE_OP_NAMES: + output_dequant_child = graph.get_node_single_child(output_dequant) + if output_dequant_child and output_dequant_child.op_type == "Gemm": + # output quant is not a QDQ block for the current Gemm Node but, + # the input QDQ block for a new Gemm block this Gemm should be + # skipped and processed by _convert_quantizable_gemm_no_activations + continue _convert_quantizable_gemm( model, quantizable_node,