diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
index 210e57df76b..7c2046a640f 100644
--- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
+++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
@@ -673,6 +673,12 @@ def _enable_module_qat(self, module: Module):
         self._qat_enabled = True
         self._calibrate_if_possible(module)
 
+        # mark export mode for module Conv layers
+        module.export_with_qlinearconv = self._quantize_conv_activations
+        if hasattr(module, "module"):
+            # for DP/DDP unwrapping
+            module.module.export_with_qlinearconv = self._quantize_conv_activations
+
     def _calibrate_if_possible(self, module):
        if self.num_calibration_steps == 0 and self._calibration_dataloader:
             warnings.warn(
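
The hunk above tags the module (and its DP/DDP-wrapped inner module) so export can later pick between the QLinearConv and ConvInteger conversion paths. A minimal sketch of that handshake, assuming a plain torch module; the DataParallel wrapper and values are illustrative only:

    # illustrative only: how the attribute set during QAT reaches export time
    import torch

    net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
    wrapped = torch.nn.DataParallel(net)  # DP/DDP expose the user model as .module

    # what _enable_module_qat does when self._quantize_conv_activations is False
    wrapped.export_with_qlinearconv = False
    if hasattr(wrapped, "module"):
        wrapped.module.export_with_qlinearconv = False

    # the exporter change (last hunk below) reads the same attribute off the module
    use_qlinearconv = getattr(net, "export_with_qlinearconv", False)
    assert use_qlinearconv is False  # -> ConvInteger path in quantize_torch_qat_export
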
diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
index 1ed660bdd3a..24a2bdff87b 100644
--- a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
+++ b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
@@ -63,7 +63,14 @@
 
 _QUANTIZE_OP_NAMES = ["QuantizeLinear", "DequantizeLinear"]
 
-_QLINEAR_OP_NAMES = ["QLinearConv", "QLinearMatMul", "QLinearAdd"]
+KEEP_QUANT_INPUT_OPS = [
+    "Add",
+    "ConvInteger",
+    "MatMulInteger",
+    "QLinearConv",
+    "QLinearMatMul",
+    "QLinearAdd",
+]
 
 
 def get_quantization_params(
@@ -160,7 +166,10 @@ def _fold_conv_bn_bias(model: ModelProto, conv_node: NodeProto, bn_node: NodePro
         folded_bias = folded_bias.astype(numpy.float32)
 
     bias_name = conv_node.name + ".bias"
-    conv_node.input.append(bias_name)
+    if len(conv_node.input) > 2:
+        conv_node.input[2] = bias_name
+    else:
+        conv_node.input.append(bias_name)
     update_model_param(model, bias_name, folded_bias)
 
     # forward conv output to bn children
@@ -656,9 +665,9 @@ def _convert_quantizable_matmul(model: ModelProto):
         )
 
 
-def _add_quantized_matmul_add_ops(
+def _add_quantized_conv_matmul_add_ops(
     model: ModelProto,
-    matmul_node: NodeProto,
+    node: NodeProto,
     input_quantize_node: NodeProto,
     weight_quantize_node: NodeProto,
     input_quantize_params: QuantizationParams,
@@ -670,10 +679,11 @@
     output_quantize_node: Optional[NodeProto] = None,
     output_dequantize_node: Optional[NodeProto] = None,
 ):
-    # helper function for conversion of qat parameterized gemms/matmuls to
-    # matmul integer add blocks. Adds new quantized ops to graph, does not
+    # helper function for conversion of qat parameterized gemms, matmuls,
+    # or convs to conv/matmul integer add blocks.
+    # Adds new quantized ops to graph, does not
     # perform any checks or deletions (should be called by the operator main
-    # conversion function
+    # conversion function)
 
     # quantize weight
     quantized_weight = _quantize_array(
@@ -683,31 +693,46 @@
     )
     if transpose_weight:
         quantized_weight = quantized_weight.transpose()
-    quantized_weight_name = "{}.weight_quantized".format(matmul_node.name)
+    quantized_weight_name = "{}.weight_quantized".format(node.name)
     quantized_weight_initializer = numpy_helper.from_array(
         quantized_weight, name=quantized_weight_name
     )
     model.graph.initializer.append(quantized_weight_initializer)
 
-    # MatMulInteger
-    # get matmulinteger inputs and outputs
-    matmul_integer_inputs = [
-        input_quantize_node.input[0],  # A matrix (replaces previous dequant node)
-        quantized_weight_name,  # B matrix (quantized weight)
-        input_quantize_node.input[2],  # a_zero_point
-        weight_quantize_node.input[2],  # b_zero_point
+    # MatMulInteger/ConvInteger
+    # get inputs and outputs
+    integer_op_inputs = [
+        input_quantize_node.input[0],  # input matrix (replaces previous dequant node)
+        quantized_weight_name,  # quantized weight
+        input_quantize_node.input[2],  # input zero point
+        weight_quantize_node.input[2],  # weight zero point
     ]
-    matmul_integer_output = "{}_quant".format(matmul_node.output[0])
-    matmul_integer_name = "{}_quant".format(matmul_node.name)
-
-    # create qmatmul node and add it to graph
-    matmul_integer_node = onnx.helper.make_node(
-        "MatMulInteger",
-        matmul_integer_inputs,
-        [matmul_integer_output],
-        matmul_integer_name,
-    )
-    model.graph.node.append(matmul_integer_node)
+    integer_op_output = "{}_quant".format(node.output[0])
+    integer_op_name = "{}_quant".format(node.name)
+
+    # create MatMulInteger/ConvInteger node and add it to graph
+    if node.op_type == "Conv":
+        # get conv attributes as kwargs
+        conv_kwargs = {}
+        for attribute in node.attribute:
+            conv_kwargs.update(_attribute_to_kwarg(attribute))
+
+        # create ConvInteger node and add it to graph
+        integer_op_node = onnx.helper.make_node(
+            "ConvInteger",
+            integer_op_inputs,
+            [integer_op_output],
+            integer_op_name,
+            **conv_kwargs,
+        )
+    else:
+        integer_op_node = onnx.helper.make_node(
+            "MatMulInteger",
+            integer_op_inputs,
+            [integer_op_output],
+            integer_op_name,
+        )
+    model.graph.node.append(integer_op_node)
 
     # Add bias + zero point correction
     # quantize bias
@@ -717,6 +742,9 @@
     quantized_bias = _quantize_array(
         bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32
     )
+    if node.op_type == "Conv" and len(quantized_bias.shape) == 1:
+        # reshape for bias add broadcasting
+        quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1)
 
     quantized_bias_name = "{}.bias_quantized".format(bias_add_name)
     quantized_bias_initializer = numpy_helper.from_array(
@@ -739,11 +767,11 @@
 
     # get INT32 Add inputs and outputs
     quant_add_inputs = [
-        matmul_integer_output,  # MatMul integer outputs (INT32)
+        integer_op_output,  # MatMul/Conv integer outputs (INT32)
         quantized_bias_name,  # Quantized bias (INT32)
     ]
 
-    quant_add_name = "{}_bias_add_quant".format(matmul_node.name)
+    quant_add_name = "{}_bias_add_quant".format(node.name)
     quant_add_output = (
         output_quantize_node.output[0]
         if output_quantize_node
@@ -875,9 +903,9 @@ def _convert_quantizable_gemm_no_activations(model: ModelProto):
         _LOGGER.debug(f"Matched quantizable Gemm weight and bias: {gemm_node.name}")
 
         # Conversion
-        _add_quantized_matmul_add_ops(
+        _add_quantized_conv_matmul_add_ops(
             model=model,
-            matmul_node=gemm_node,
+            node=gemm_node,
             input_quantize_node=input_quantize_node,
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
@@ -1021,9 +1049,9 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         _LOGGER.debug(f"Matched quantizable MatMul weight and bias: {matmul_node.name}")
 
         # Conversion
-        _add_quantized_matmul_add_ops(
+        _add_quantized_conv_matmul_add_ops(
             model=model,
-            matmul_node=matmul_node,
+            node=matmul_node,
             input_quantize_node=input_quantize_node,
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
@@ -1071,6 +1099,130 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         graph.delete_unused_initializers()
 
 
+def _convert_quantizable_conv_integer(model: ModelProto):
+    """
+    A pass for converting a Conv op with kernel whose activations
+    are not necessarily quantized into a ConvInteger followed by
+    a bias add and cast to FP32
+
+    | Starting with:
+    |          INPUT         QuantizeLinear (with constant kernel)
+    |            |               |
+    |     QuantizeLinear     DequantizeLinear
+    |            |               |
+    |     DequantizeLinear       |
+    |            |               |
+    |                  Conv (with bias)
+    |                     |
+    |                  OUTPUT
+    | We end up converting to:
+    |       INPUT
+    |         |
+    |     QuantizeLinear
+    |         |
+    |     ConvInteger (with constant uint8 kernel)
+    |         |
+    |     Add (constant bias + zero point correction)
+    |         |
+    |     Cast (INT32 -> FP32)
+    |         |
+    |     Mul (Rescale from bias scale)
+    |         |
+    |       OUTPUT
+    """
+
+    conversion_count = 0
+    conv_nodes = [n for n in model.graph.node if n.op_type in ["Conv"]]
+    orig_conv_weight_name_to_node_ids = defaultdict(list)
+    for conv_node in conv_nodes:
+        if len(conv_node.input) != 3:
+            # this function currently only converts Conv nodes with bias param
+            # (i.e. from folded batch norm value)
+            continue
+
+        graph = ONNXGraph(model)
+
+        #############
+        # Matching
+        #############
+        weight_dequantize_node = graph.get_node_single_parent(conv_node, 1)
+        if (
+            not weight_dequantize_node
+            or weight_dequantize_node.op_type != "DequantizeLinear"
+        ):
+            continue
+        weight_quantize_node = graph.get_node_single_parent(weight_dequantize_node, 0)
+        if not weight_quantize_node or weight_quantize_node.op_type != "QuantizeLinear":
+            continue
+
+        input_quantize_node = graph.get_node_single_parent(conv_node, 0)
+        if (
+            not input_quantize_node
+            or input_quantize_node.op_type not in _QUANTIZE_OP_NAMES
+        ):
+            continue
+
+        input_quantize_params = get_quantization_params(
+            model, input_quantize_node, include_target=False
+        )
+        weight_quantize_params = get_quantization_params(
+            model, weight_quantize_node, include_target=True
+        )
+        if weight_quantize_params.target is None:
+            # weight initializer not included
+            continue
+        if input_quantize_node.op_type != "DequantizeLinear":
+            continue
+
+        bias_initializer = graph.get_init_by_name(conv_node.input[2])
+        if bias_initializer is None:
+            _LOGGER.debug(f"Unable to find bias initializer: {conv_node.input[2]}")
+            continue
+
+        _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}")
+
+        # Conversion
+        _add_quantized_conv_matmul_add_ops(
+            model=model,
+            node=conv_node,
+            input_quantize_node=input_quantize_node,
+            weight_quantize_node=weight_quantize_node,
+            input_quantize_params=input_quantize_params,
+            weight_quantize_params=weight_quantize_params,
+            bias_initializer=bias_initializer,
+            bias_add_name="{}_bias_add".format(conv_node.name),
+            target_output=conv_node.output[0],
+            transpose_weight=False,
+        )
+        orig_conv_weight_name_to_node_ids[input_quantize_node.input[0]].append(
+            "{}_quant".format(conv_node.output[0])
+        )
+
+        # Cleanup
+        # delete folded quantization ops
+        delete_quant_node(model, weight_dequantize_node, keep_params=False)
+        delete_quant_node(model, weight_quantize_node, keep_params=True)
+
+        # only delete input node if the conv is the only child
+        current_graph = ONNXGraph(model)
+        if len(current_graph.get_node_children(input_quantize_node)) == 1:
+            delete_quant_node(model, input_quantize_node, keep_params=True)
+
+        # delete original Conv node
+        remove_node_and_params_from_graph(model, conv_node, keep_params=None)
+
+        conversion_count += 1
+
+    if conv_nodes:
+        _LOGGER.info(
+            f"Converted {conversion_count} quantizable Conv ops with weight and bias "
+            "to ConvInteger and Add"
+        )
+        _reduce_qconv_shared_weights(model, orig_conv_weight_name_to_node_ids)
+        graph = ONNXGraph(model)
+        graph.delete_unused_initializers()
+
+
 def _reduce_qconv_shared_weights(
     model: ModelProto, orig_qconv_weight_name_to_node_ids: Dict[str, List[NodeProto]]
 ):
"{}_quant".format(conv_node.output[0]) + ) + + # Cleanup + # delete folded quantization ops + delete_quant_node(model, weight_dequantize_node, keep_params=False) + delete_quant_node(model, weight_quantize_node, keep_params=True) + + # only delete input node if the conv is the only child + current_graph = ONNXGraph(model) + if len(current_graph.get_node_children(input_quantize_node)) == 1: + delete_quant_node(model, input_quantize_node, keep_params=True) + + # delete original Conv node + remove_node_and_params_from_graph(model, conv_node, keep_params=None) + + conversion_count += 1 + + if conv_nodes: + _LOGGER.info( + f"Converted {conversion_count} quantizable Conv ops with weight and bias " + "to ConvInteger and Add" + ) + _reduce_qconv_shared_weights(model, orig_conv_weight_name_to_node_ids) + graph = ONNXGraph(model) + graph.delete_unused_initializers() + + def _reduce_qconv_shared_weights( model: ModelProto, orig_qconv_weight_name_to_node_ids: Dict[str, List[NodeProto]] ): @@ -1080,10 +1232,17 @@ def _reduce_qconv_shared_weights( continue qconv_nodes = [graph.get_node_by_output_id(id) for id in node_ids] - if any(node.op_type != "QLinearConv" for node in qconv_nodes): + if any( + node.op_type not in ["QLinearConv", "ConvInteger"] for node in qconv_nodes + ): continue - weights = [graph.get_init_by_name(node.input[3]) for node in qconv_nodes] + weights = [ + graph.get_init_by_name( + node.input[3 if node.op_type == "QLinearConv" else 1] + ) + for node in qconv_nodes + ] if any(weight is None for weight in weights): continue @@ -1095,14 +1254,15 @@ def _reduce_qconv_shared_weights( weights[0], name=f"{weight_name}_quantized" ) for node in qconv_nodes: - node.input[3] = shared_weight.name + target_dim = 3 if node.op_type == "QLinearConv" else 1 + node.input[target_dim] = shared_weight.name model.graph.initializer.append(shared_weight) graph.update() graph.delete_unused_initializers() -def _convert_quantizable_ops(model: ModelProto): +def _convert_quantizable_ops(model: ModelProto, convert_qlinearconv: bool): quantizable_nodes = [n for n in model.graph.node if n.op_type in ["Conv", "Gemm"]] orig_qconv_weight_name_to_node_ids = defaultdict(list) for quantizable_node in quantizable_nodes: @@ -1123,7 +1283,7 @@ def _convert_quantizable_ops(model: ModelProto): if not output_quant or output_quant.op_type not in _QUANTIZE_OP_NAMES: continue - if quantizable_node.op_type == "Conv": + if convert_qlinearconv and quantizable_node.op_type == "Conv": weight_name = weight_quant.input[0] qconv_node = _convert_quantizable_conv( model, @@ -1299,7 +1459,10 @@ def _cleanup_unused_quants(model: ModelProto): ) dequant_children = graph.get_node_children(dequant_node) for child in dequant_children: - if isinstance(child, onnx.NodeProto) and child.op_type in _QLINEAR_OP_NAMES: + # check if any dequant children depend on having QDQ inputs + if isinstance(child, onnx.NodeProto) and ( + child.op_type in KEEP_QUANT_INPUT_OPS + ): removable = False if not removable: continue @@ -1323,11 +1486,16 @@ def quantize_torch_qat_export( model: Union[ModelProto, str], output_file_path: Union[str, None] = None, inplace: bool = True, + use_qlinearconv: bool = False, ) -> ModelProto: """ :param model: The model to convert, or a file path to it :param output_file_path: File path to save the converted model to :param inplace: If true, does conversion of model in place. Default is true + :param use_qlinearconv: Set True to use legacy QLinearConv format instead + of ConvInteger. 
diff --git a/src/sparseml/pytorch/utils/exporter.py b/src/sparseml/pytorch/utils/exporter.py
index b00ba190cda..a987287aba7 100644
--- a/src/sparseml/pytorch/utils/exporter.py
+++ b/src/sparseml/pytorch/utils/exporter.py
@@ -498,7 +498,15 @@ def export_onnx(
                 quantize_torch_qat_export,
             )
 
-            quantize_torch_qat_export(model=file_path, output_file_path=file_path)
+            use_qlinearconv = hasattr(module, "export_with_qlinearconv") and (
+                module.export_with_qlinearconv
+            )
+
+            quantize_torch_qat_export(
+                model=file_path,
+                output_file_path=file_path,
+                use_qlinearconv=use_qlinearconv,
+            )
 
         if skip_input_quantize:
             try:
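
With both pieces in place, QAT exports default to the ConvInteger path and the legacy QLinearConv graph stays available behind the new keyword. A hypothetical call sequence (file names are placeholders):

    # hypothetical usage of the updated export helper; paths are placeholders
    from sparseml.pytorch.sparsification.quantization.quantize_qat_export import (
        quantize_torch_qat_export,
    )

    # default: Conv QAT blocks become ConvInteger + Add, even when conv activations
    # were not quantized by the recipe
    quantize_torch_qat_export(model="model_qat.onnx", output_file_path="model.onnx")

    # legacy pre-0.12 behavior: emit QLinearConv (requires quantized conv activations)
    quantize_torch_qat_export(
        model="model_qat.onnx",
        output_file_path="model.onnx",
        use_qlinearconv=True,
    )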