From e996a9359fbc24db4925f7bdb8b5529c3f97f283 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Mon, 12 Dec 2022 19:07:49 +0800 Subject: [PATCH] Export Qlinear to QDQ (#224) Signed-off-by: mengniwa --- .../adaptor/ox_utils/operators/__init__.py | 4 +- .../adaptor/ox_utils/operators/activation.py | 40 +- .../adaptor/ox_utils/operators/argmax.py | 12 +- .../adaptor/ox_utils/operators/attention.py | 47 +- .../adaptor/ox_utils/operators/binary_op.py | 47 +- .../adaptor/ox_utils/operators/concat.py | 41 +- .../adaptor/ox_utils/operators/conv.py | 70 +- .../adaptor/ox_utils/operators/direct_q8.py | 7 +- .../ox_utils/operators/embed_layernorm.py | 38 +- .../adaptor/ox_utils/operators/gather.py | 9 +- .../adaptor/ox_utils/operators/gavgpool.py | 37 +- .../adaptor/ox_utils/operators/gemm.py | 73 +- .../adaptor/ox_utils/operators/matmul.py | 49 +- .../adaptor/ox_utils/operators/maxpool.py | 9 +- .../adaptor/ox_utils/operators/ops.py | 78 ++- .../adaptor/ox_utils/operators/pad.py | 9 +- .../adaptor/ox_utils/operators/pooling.py | 38 +- .../adaptor/ox_utils/operators/resize.py | 6 +- .../adaptor/ox_utils/operators/split.py | 50 +- neural_compressor/config.py | 3 + .../experimental/export/__init__.py | 1 + .../experimental/export/qlinear2qdq.py | 85 +++ neural_compressor/model/onnx_model.py | 22 + test/export/test_onnx_qlieanr_to_qdq.py | 650 ++++++++++++++++++ 24 files changed, 1390 insertions(+), 35 deletions(-) create mode 100644 neural_compressor/experimental/export/qlinear2qdq.py create mode 100644 test/export/test_onnx_qlieanr_to_qdq.py diff --git a/neural_compressor/adaptor/ox_utils/operators/__init__.py b/neural_compressor/adaptor/ox_utils/operators/__init__.py index da48d428ac4..7b17ff45b5b 100644 --- a/neural_compressor/adaptor/ox_utils/operators/__init__.py +++ b/neural_compressor/adaptor/ox_utils/operators/__init__.py @@ -18,7 +18,7 @@ from os.path import dirname, basename, isfile, join import glob -from .ops import OPERATORS +from .ops import OPERATORS, QOPERATORS modules = glob.glob(join(dirname(__file__), "*.py")) @@ -26,4 +26,4 @@ if isfile(f) and not f.startswith('__') and not f.endswith('__init__.py'): __import__(basename(f)[:-3], globals(), locals(), level=1) -__all__ = ["OPERATORS"] \ No newline at end of file +__all__ = ["OPERATORS", "QOPERATORS"] diff --git a/neural_compressor/adaptor/ox_utils/operators/activation.py b/neural_compressor/adaptor/ox_utils/operators/activation.py index 5339f6834ad..cf677e61881 100644 --- a/neural_compressor/adaptor/ox_utils/operators/activation.py +++ b/neural_compressor/adaptor/ox_utils/operators/activation.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain @op_registry(op_types="LeakyRelu, Sigmoid") @@ -87,4 +87,40 @@ def quantize(self): self.quantizer.dequantize_tensor(node, node.input[0]) else: self.quantizer.model.replace_input_of_all_nodes(node.output[0], node.input[0]) - self.quantizer.remove_nodes.append(node) \ No newline at end of file + self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QLinearLeakyRelu, QLinearSigmoid") +class QActivationOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inits = [] + # input dq + in_dq = 
onnx.helper.make_node( + 'DequantizeLinear', + node.input[:3], + [node.name + '_in_dequant'], + node.name + '_in_dequant') + inputs = [node.name + '_in_dequant'] + add_nodes.append(in_dq) + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[3], node.input[4]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + activation_node = onnx.helper.make_node( + node.op_type.split('QLinear')[-1], inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(activation_node) + return True, add_nodes, inits \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/argmax.py b/neural_compressor/adaptor/ox_utils/operators/argmax.py index 9344498698e..65daf5b5523 100644 --- a/neural_compressor/adaptor/ox_utils/operators/argmax.py +++ b/neural_compressor/adaptor/ox_utils/operators/argmax.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - - -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry @op_registry(op_types="ArgMax") class ArgMaxOperator(Operator): @@ -35,5 +33,9 @@ def convert(self, convert_format): origin_name = node.input[0].split('_argmax_node')[0] if origin_name in self.quantizer.quantized_value_map: - node.input[0] = self.quantizer.quantized_value_map[origin_name].q_name - node.name = node.name + '_quant' \ No newline at end of file + node.name = node.name + '_quant' + +@qop_registry(op_types="ArgMax") +class QArgMaxOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) diff --git a/neural_compressor/adaptor/ox_utils/operators/attention.py b/neural_compressor/adaptor/ox_utils/operators/attention.py index 9bd33ae4c26..26030e9284a 100644 --- a/neural_compressor/adaptor/ox_utils/operators/attention.py +++ b/neural_compressor/adaptor/ox_utils/operators/attention.py @@ -17,8 +17,8 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator -from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, qop_registry, QOperator +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain, find_by_name @op_registry(op_types="Attention") class AttentionOperator(Operator): @@ -74,3 +74,46 @@ def convert(self, convert_format): self.quantizer.new_nodes.append(qattention_node) self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QAttention") +class QAttentionOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inputs = [] + inits = [] + if find_by_name(node.input[3], self.initializers) is None: + return False, add_nodes, inits + # input dq + in_dq1 = onnx.helper.make_node( + 'DequantizeLinear', + [node.input[0], node.input[3], node.input[6]], + [node.name + '_in_dequant1'], + node.name + '_in_dequant1') + + in_dq2 = onnx.helper.make_node( + 'DequantizeLinear', + [node.input[1], node.input[4], node.input[7]], + [node.name + '_in_dequant2'], + 
node.name + '_in_dequant2') + inputs = [node.name + '_in_dequant1', + node.name + '_in_dequant2', + node.input[2], + node.input[5]] + + add_nodes.extend([in_dq1, in_dq2]) + + outputs = node.output + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + kwargs["domain"] = ms_domain + + binary_node = onnx.helper.make_node( + 'Attention', inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(binary_node) + return True, add_nodes, inits \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/binary_op.py b/neural_compressor/adaptor/ox_utils/operators/binary_op.py index 3848cd6ee9b..72c92da3dcf 100644 --- a/neural_compressor/adaptor/ox_utils/operators/binary_op.py +++ b/neural_compressor/adaptor/ox_utils/operators/binary_op.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain @op_registry(op_types="Add, Mul") @@ -77,4 +77,47 @@ def convert(self, convert_format): self.quantizer.new_nodes += [qlinear_binary_math_node] self.quantizer.remove_nodes.extend(parents) self.quantizer.remove_nodes.append(child) - self.quantizer.remove_nodes.append(node) \ No newline at end of file + self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QLinearAdd, QLinearMul") +class QBinaryOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inits = [] + # input dq + in_dq1 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[:3], + [node.name + '_in_dequant1'], + node.name + '_in_dequant1') + + in_dq2 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[3:6], + [node.name + '_in_dequant2'], + node.name + '_in_dequant2') + inputs = [node.name + '_in_dequant1', node.name + '_in_dequant2'] + + add_nodes.extend([in_dq1, in_dq2]) + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[6], node.input[7]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + binary_node = onnx.helper.make_node( + node.op_type.split('QLinear')[-1], inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(binary_node) + return True, add_nodes, inits diff --git a/neural_compressor/adaptor/ox_utils/operators/concat.py b/neural_compressor/adaptor/ox_utils/operators/concat.py index 763ac8e6541..eb85155421c 100644 --- a/neural_compressor/adaptor/ox_utils/operators/concat.py +++ b/neural_compressor/adaptor/ox_utils/operators/concat.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain @op_registry(op_types="Concat") @@ -96,3 +96,42 @@ def cast(self): # pragma: no cover if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: return self.quantizer.dtype_cast(self.node, self.dtype) + 
+@qop_registry(op_types="QLinearConcat") +class QConcatOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inputs = [] + inits = [] + # input dq + for i in range(int((len(node.input) - 2) / 3 - 1)): + in_dq = onnx.helper.make_node( + 'DequantizeLinear', + node.input[2 + i*3 : 2 + (i+1)*3], + [node.name + '_in_dequant_' + str(i)], + node.name + '_in_dequant_' + str(i)) + inputs.append(node.name + '_in_dequant_' + str(i)) + add_nodes.append(in_dq) + + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[0], node.input[1]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + concat_node = onnx.helper.make_node( + 'Concat', inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(concat_node) + return True, add_nodes, inits diff --git a/neural_compressor/adaptor/ox_utils/operators/conv.py b/neural_compressor/adaptor/ox_utils/operators/conv.py index 90b849bd9e6..7f95d548b2a 100644 --- a/neural_compressor/adaptor/ox_utils/operators/conv.py +++ b/neural_compressor/adaptor/ox_utils/operators/conv.py @@ -19,7 +19,7 @@ import onnx from onnx import onnx_pb as onnx_proto -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import find_by_name, attribute_to_kwarg @op_registry(op_types="Conv, FusedConv") @@ -156,6 +156,7 @@ def convert(self, convert_format): if attribute.name == 'activation_params': # pragma: no cover continue kwargs.update(attribute_to_kwarg(attribute)) + qlinear_conv_node = onnx.helper.make_node("QLinearConv", qlinear_conv_inputs, [qlinear_conv_output], node.name, **kwargs) @@ -164,4 +165,71 @@ def convert(self, convert_format): self.quantizer.remove_nodes.append(child) self.quantizer.remove_nodes.append(node) +@qop_registry(op_types="QLinearConv") +class QConvOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + def convert(self): + node = self.node + add_nodes = [] + inits = [] + # input dq + in_dq1 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[:3], + [node.name + '_in_dequant1'], + node.name + '_in_dequant1') + + in_dq2 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[3:6], + [node.name + '_in_dequant2'], + node.name + '_in_dequant2') + + add_nodes.extend([in_dq1, in_dq2]) + inputs = [node.name + '_in_dequant1', node.name + '_in_dequant2'] + if len(node.input) == 9: + import numpy as np + input_scale = onnx.numpy_helper.to_array( + find_by_name(node.input[1], self.initializers)) + weight_scale = onnx.numpy_helper.to_array( + find_by_name(node.input[4], self.initializers)) + bias_scale = input_scale * weight_scale + + # update scale initializer + bias_scale_data = np.asarray(bias_scale, dtype=np.float32).reshape(-1) + bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, + node.input[8] + '_scale') + inits.extend([bias_scale_initializer]) + + # update zero initializer + bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1) + bias_zp_initializer = onnx.numpy_helper.from_array( + bias_zp_data, 
node.input[8] + '_zero_point') + inits.extend([bias_zp_initializer]) + in_dq3 = onnx.helper.make_node( + 'DequantizeLinear', + [node.input[8], bias_scale_initializer.name, bias_zp_initializer.name], + [node.name + '_in_dequant3'], + node.name + '_in_dequant3') + inputs.append(in_dq3.name) + add_nodes.append(in_dq3) + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[6], node.input[7]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + binary_node = onnx.helper.make_node( + node.op_type.split('QLinear')[-1], inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(binary_node) + return True, add_nodes, inits diff --git a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py index 00522c178a1..08a6e5a326b 100644 --- a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py +++ b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py @@ -16,7 +16,7 @@ # limitations under the License. # -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, qop_registry, QOperator @op_registry(op_types="Reshape, Transpose, Squeeze, Unsqueeze") class Direct8BitOperator(Operator): @@ -83,3 +83,8 @@ def cast(self): if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: return self.quantizer.dtype_cast(self.node, self.dtype) + +@qop_registry(op_types="Reshape, Transpose, Squeeze, Unsqueeze") +class QDirectOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py b/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py index 256298b7142..91310f9e15d 100644 --- a/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py +++ b/neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain @op_registry(op_types="EmbedLayerNormalization") @@ -69,4 +69,38 @@ def convert(self, convert_format): inputs, node.output, node.name, **kwargs) self.quantizer.new_nodes.append(qembed_layer_norm_node) - self.quantizer.remove_nodes.extend(parents) \ No newline at end of file + self.quantizer.remove_nodes.extend(parents) + self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QEmbedLayerNormalization") +class QEmbedLayerNormalizationOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inits = [] + inputs = [node.input[0], node.input[1]] + # input dq + for i in range(5): + in_dq = onnx.helper.make_node( + 'DequantizeLinear', + [node.input[2+i], node.input[-10+i], node.input[-5+i]], + [node.name + '_in_dequant_' + str(i)], + node.name + '_in_dequant_' + str(i)) + inputs.append(node.name + '_in_dequant_' + str(i)) + add_nodes.append(in_dq) + if 
len(node.input) > 17: + inputs.append(node.input[7]) + + outputs = node.output + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + binary_node = onnx.helper.make_node( + 'EmbedLayerNormalization', inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(binary_node) + return True, add_nodes, inits \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/gather.py b/neural_compressor/adaptor/ox_utils/operators/gather.py index 93f98823047..7c3c6285b45 100644 --- a/neural_compressor/adaptor/ox_utils/operators/gather.py +++ b/neural_compressor/adaptor/ox_utils/operators/gather.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg @op_registry(op_types="Gather") @@ -89,4 +89,9 @@ def convert(self, convert_format): for n in self.quantizer.model.get_children(child): self.quantizer.model.replace_node_input(n, child.output[0], gather_new_output) - self.quantizer.remove_nodes.extend([node, parents[0]]) \ No newline at end of file + self.quantizer.remove_nodes.extend([node, parents[0]]) + +@qop_registry(op_types="Gather") +class QGatherOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) diff --git a/neural_compressor/adaptor/ox_utils/operators/gavgpool.py b/neural_compressor/adaptor/ox_utils/operators/gavgpool.py index b4bafcafeae..eec48e6af19 100644 --- a/neural_compressor/adaptor/ox_utils/operators/gavgpool.py +++ b/neural_compressor/adaptor/ox_utils/operators/gavgpool.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain @op_registry(op_types="GlobalAveragePool") @@ -58,4 +58,37 @@ def convert(self, convert_format): self.quantizer.new_nodes += [qnode] self.quantizer.remove_nodes.append(child) self.quantizer.remove_nodes.append(parent) - self.quantizer.remove_nodes.append(node) \ No newline at end of file + self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QLinearGlobalAveragePool") +class QGlobalAveragePoolOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inits = [] + # input dq + in_dq = onnx.helper.make_node( + 'DequantizeLinear', + node.input[:3], + [node.name + '_in_dequant'], + node.name + '_in_dequant') + inputs = [node.name + '_in_dequant'] + add_nodes.append(in_dq) + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[3], node.input[4]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + activation_node = onnx.helper.make_node( + 'GlobalAveragePool', inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(activation_node) + return True, add_nodes, inits diff --git a/neural_compressor/adaptor/ox_utils/operators/gemm.py b/neural_compressor/adaptor/ox_utils/operators/gemm.py index 65aca2e8a7d..49f8eeaa6c7 100644 --- 
a/neural_compressor/adaptor/ox_utils/operators/gemm.py +++ b/neural_compressor/adaptor/ox_utils/operators/gemm.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import find_by_name, ms_domain, \ attribute_to_kwarg, is_B_transposed @@ -91,4 +91,73 @@ def convert(self, convert_format): self.quantizer.new_nodes.append(qgemm_node) self.quantizer.remove_nodes.extend(parents) self.quantizer.remove_nodes.append(child) - self.quantizer.remove_nodes.append(node) \ No newline at end of file + self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QGemm") +class QGemmOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + import numpy as np + node = self.node + add_nodes = [] + inits = [] + + input_scale = onnx.numpy_helper.to_array( + find_by_name(node.input[1], self.initializers)) + weight_scale = onnx.numpy_helper.to_array( + find_by_name(node.input[4], self.initializers)) + bias_scale = input_scale * weight_scale + + # input dq + in_dq1 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[:3], + [node.name + '_in_dequant1'], + node.name + '_in_dequant1') + + + in_dq2 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[3:6], + [node.name + '_in_dequant2'], + node.name + '_in_dequant2') + + # update scale initializer + bias_scale_data = np.asarray(bias_scale, dtype=np.float32).reshape(-1) + bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, + node.input[6] + '_scale') + inits.extend([bias_scale_initializer]) + + # update zero initializer + bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1) + bias_zp_initializer = onnx.numpy_helper.from_array( + bias_zp_data, node.input[6] + '_zero_point') + inits.extend([bias_zp_initializer]) + in_dq3 = onnx.helper.make_node( + 'DequantizeLinear', + [node.input[8], bias_scale_initializer.name, bias_zp_initializer.name], + [node.name + '_in_dequant3']) + + inputs = [in_dq1.name, in_dq2.name, in_dq3.name] + add_nodes.extend([in_dq1, in_dq2, in_dq3]) + + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[6], node.input[7]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + gemm_node = onnx.helper.make_node( + 'Gemm', inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(gemm_node) + return True, add_nodes, inits \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/matmul.py b/neural_compressor/adaptor/ox_utils/operators/matmul.py index 988e157e323..fbf6558bb02 100644 --- a/neural_compressor/adaptor/ox_utils/operators/matmul.py +++ b/neural_compressor/adaptor/ox_utils/operators/matmul.py @@ -17,8 +17,8 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator -from neural_compressor.adaptor.ox_utils.util import find_by_name +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry +from neural_compressor.adaptor.ox_utils.util import find_by_name, attribute_to_kwarg from onnx import onnx_pb as onnx_proto 
@op_registry(op_types="MatMul") @@ -122,4 +122,47 @@ def convert(self, convert_format): self.quantizer.new_nodes.append(qlinear_matmul_node) self.quantizer.remove_nodes.extend(parents) self.quantizer.remove_nodes.append(child) - self.quantizer.remove_nodes.append(node) \ No newline at end of file + self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QLinearMatMul") +class QMatMulOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inits = [] + # input dq + in_dq1 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[:3], + [node.name + '_in_dequant1'], + node.name + '_in_dequant1') + + in_dq2 = onnx.helper.make_node( + 'DequantizeLinear', + node.input[3:6], + [node.name + '_in_dequant2'], + node.name + '_in_dequant2') + inputs = [node.name + '_in_dequant1', node.name + '_in_dequant2'] + + add_nodes.extend([in_dq1, in_dq2]) + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[6], node.input[7]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + matmul_node = onnx.helper.make_node( + 'MatMul', inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(matmul_node) + return True, add_nodes, inits diff --git a/neural_compressor/adaptor/ox_utils/operators/maxpool.py b/neural_compressor/adaptor/ox_utils/operators/maxpool.py index f93befc9a4f..3180a6a49f1 100644 --- a/neural_compressor/adaptor/ox_utils/operators/maxpool.py +++ b/neural_compressor/adaptor/ox_utils/operators/maxpool.py @@ -16,7 +16,7 @@ # limitations under the License. # -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry @op_registry(op_types="MaxPool") class MaxPoolOperator(Operator): @@ -67,4 +67,9 @@ def convert(self, convert_format): self.quantizer.model.replace_node_input(n, child.output[0], node.output[0]) - self.quantizer.remove_nodes.append(parent) \ No newline at end of file + self.quantizer.remove_nodes.append(parent) + +@qop_registry(op_types="MaxPool") +class QMaxPoolOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/ops.py b/neural_compressor/adaptor/ox_utils/operators/ops.py index 33d4ecf7c5d..ad6237b2d41 100644 --- a/neural_compressor/adaptor/ox_utils/operators/ops.py +++ b/neural_compressor/adaptor/ox_utils/operators/ops.py @@ -15,8 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from neural_compressor.utils.utility import LazyImport +from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg +onnx = LazyImport('onnx') OPERATORS = {} +QOPERATORS= {} def op_registry(op_types): '''The class decorator used to register all Operator subclasses. @@ -34,6 +38,27 @@ def decorator_op(cls): return cls return decorator_op +def qop_registry(op_types): + '''The class decorator used to register all qOperator subclasses. + + Args: + cls (class): The class of register. 
+ ''' + def decorator_op(cls): + assert cls.__name__.endswith( + 'Operator'), "The name of subclass of QOperator should end with \'Operator\' substring." + if cls.__name__[:-len('Operator')] in QOPERATORS: # pragma: no cover + raise ValueError('Cannot have two operators with the same name.') + for single_op_type in [op_type.strip() for op_type in op_types.split(',')]: + if single_op_type.startswith('QLinear') or \ + single_op_type in ['QGemm', 'QAttention', 'QEmbedLayerNormalization', 'ArgMax', + 'Reshape', 'Transpose', 'Squeeze', 'Unsqueeze', 'Gather', + 'MaxPool', 'Pad', 'Resize', 'Split']: + QOPERATORS[single_op_type] = cls + return cls + return decorator_op + + class Operator(object): def __init__(self, onnx_quantizer, onnx_node): self.quantizer = onnx_quantizer @@ -81,4 +106,55 @@ def convert(self, convert_format): return def cast(self): # pragma: no cover - self.quantizer.dtype_cast(self.node, self.dtype) \ No newline at end of file + self.quantizer.dtype_cast(self.node, self.dtype) + +class QOperator(object): + def __init__(self, onnx_node, children, initializers): + self.node = onnx_node + self.children = children + self.initializers = initializers + self.qop_list = ['QGemm', 'QAttention', 'QEmbedLayerNormalization', + 'QLinearLeakyRelu', 'QLinearSigmoid', 'QLinearAdd','QLinearMul', + 'QLinearConcat', 'QLinearConv', 'QLinearGlobalAveragePool', + 'QLinearMatMul', 'QLinearAveragePool'] + + def convert(self): + node = self.node + add_nodes = [] + inputs = [] + inits = [] + if all([child.op_type not in self.qop_list or \ + child.op_type != 'DequantizeLinear' for child in self.children]): + return False, add_nodes, inits + + # input dq + for child in self.children: + if child.op_type == 'DequantizeLinear': + in_dq = onnx.helper.make_node( + 'DequantizeLinear', + [node.input[0], child.input[1], child.input[2]], + [node.name + '_in_dequant'], + node.name + '_in_dequant') + inputs.append(node.name + '_in_dequant') + add_nodes.append(in_dq) + break + + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', in_dq.input[1], in_dq.input[2]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + inputs.append(node.input[1:]) + new_node = onnx.helper.make_node( + node.op_type, inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(new_node) + return True, add_nodes, inits \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/pad.py b/neural_compressor/adaptor/ox_utils/operators/pad.py index 0f0acfcbec7..00bb38a3bbd 100644 --- a/neural_compressor/adaptor/ox_utils/operators/pad.py +++ b/neural_compressor/adaptor/ox_utils/operators/pad.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, quantize_nparray @op_registry(op_types="Pad") @@ -93,4 +93,9 @@ def convert(self, convert_format): # Create an entry for output quantized value node.input[0] = parent.input[0] node.output[0] = child.output[0] - self.quantizer.remove_nodes.extend([parent, child]) \ No newline at end of file + self.quantizer.remove_nodes.extend([parent, child]) + +@qop_registry(op_types="Pad") +class QPadOperator(QOperator): + def __init__(self, 
onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/pooling.py b/neural_compressor/adaptor/ox_utils/operators/pooling.py index bba746129e6..a794dec7018 100644 --- a/neural_compressor/adaptor/ox_utils/operators/pooling.py +++ b/neural_compressor/adaptor/ox_utils/operators/pooling.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain @op_registry(op_types="AveragePool") @@ -80,3 +80,39 @@ def convert(self, convert_format): self.quantizer.new_nodes.append(qnode) self.quantizer.remove_nodes.append(node) + +@qop_registry(op_types="QLinearAveragePool") +class QPoolOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inits = [] + # input dq + in_dq = onnx.helper.make_node( + 'DequantizeLinear', + node.input[:3], + [node.name + '_in_dequant'], + node.name + '_in_dequant') + inputs = [node.name + '_in_dequant'] + add_nodes.append(in_dq) + # output q + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out', node.input[3], node.input[4]], + node.output, + node.name + '_out_quant') + outputs = [node.name + '_out'] + add_nodes.append(out_q) + + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + activation_node = onnx.helper.make_node( + 'AveragePool', inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(activation_node) + return True, add_nodes, inits diff --git a/neural_compressor/adaptor/ox_utils/operators/resize.py b/neural_compressor/adaptor/ox_utils/operators/resize.py index d5f906f8372..7d266c7a5a5 100644 --- a/neural_compressor/adaptor/ox_utils/operators/resize.py +++ b/neural_compressor/adaptor/ox_utils/operators/resize.py @@ -16,7 +16,7 @@ # limitations under the License. 
# -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry @op_registry(op_types="Resize") class ResizeOperator(Operator): @@ -70,3 +70,7 @@ def convert(self, convert_format): child.output[0], node.output[0] + '_quantized') node.output[0] = node.output[0] + '_quantized' +@qop_registry(op_types="Resize") +class QResizeOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) \ No newline at end of file diff --git a/neural_compressor/adaptor/ox_utils/operators/split.py b/neural_compressor/adaptor/ox_utils/operators/split.py index a5ec5532711..d022fd3d4c1 100644 --- a/neural_compressor/adaptor/ox_utils/operators/split.py +++ b/neural_compressor/adaptor/ox_utils/operators/split.py @@ -17,7 +17,7 @@ # import onnx -from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator +from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, QOperator, qop_registry from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg @@ -81,3 +81,51 @@ def cast(self): # pragma: no cover if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: return self.quantizer.dtype_cast(self.node, self.dtype) + +@qop_registry(op_types="Split") +class QSplitOperator(QOperator): + def __init__(self, onnx_node, children, initializers): + super().__init__(onnx_node, children, initializers) + + def convert(self): + node = self.node + add_nodes = [] + inputs = [] + inits = [] + + if all([child.op_type not in self.qop_list or \ + child.op_type != 'DequantizeLinear' for child in self.children]): + return False, add_nodes, inits + + # input dq + for child in self.children: + if child.op_type == 'DequantizeLinear': + in_dq = onnx.helper.make_node( + 'DequantizeLinear', + [node.input[0], child.input[1], child.input[2]], + [node.name + '_in_dequant'], + node.name + '_in_dequant') + inputs.append(node.name + '_in_dequant') + add_nodes.append(in_dq) + break + + outputs = [] + for i, out in enumerate(node.output): + out_q = onnx.helper.make_node( + 'QuantizeLinear', + [node.name + '_out_' + str(i), in_dq.input[1], in_dq.input[2]], + [node.output[i]], + node.name + '_out_quant_' + str(i)) + outputs.append([node.name + '_out_quant_' + str(i)]) + add_nodes.append(out_q) + + outputs = node.output + kwargs = {} + for attribute in node.attribute: # pragma: no cover + kwargs.update(attribute_to_kwarg(attribute)) + + gather_node = onnx.helper.make_node( + node.op_type, inputs, + outputs, node.name + '_convert', **kwargs) + add_nodes.append(gather_node) + return True, add_nodes, inits \ No newline at end of file diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 295890a61ca..dd79b45db56 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -884,6 +884,9 @@ def dynamic_axes(self): def dynamic_axes(self, dynamic_axes): self._dynamic_axes = dynamic_axes +class ONNXQlinear2QDQConfig: + def __init__(self): + pass class Torch2ONNXConfig(ExportConfig): def __init__( diff --git a/neural_compressor/experimental/export/__init__.py b/neural_compressor/experimental/export/__init__.py index 2ccf049bf64..529ea48ed35 100644 --- a/neural_compressor/experimental/export/__init__.py +++ b/neural_compressor/experimental/export/__init__.py @@ -19,3 +19,4 @@ """Intel Neural Compressor Export.""" from .torch2onnx import 
torch_to_fp32_onnx, torch_to_int8_onnx +from .qlinear2qdq import onnx_qlinear_to_qdq diff --git a/neural_compressor/experimental/export/qlinear2qdq.py b/neural_compressor/experimental/export/qlinear2qdq.py new file mode 100644 index 00000000000..10c0b74d7ef --- /dev/null +++ b/neural_compressor/experimental/export/qlinear2qdq.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions to export onnx model from QLinearops to QDQ.""" + +from neural_compressor.utils import logger +from neural_compressor.adaptor.ox_utils.util import find_by_name +from neural_compressor.utils.utility import LazyImport + +numpy_helper = LazyImport('onnx.numpy_helper') + +def check_model(model): + """Check optype for input model. + + Args: + model (ModelProto): onnx model. + """ + has_integerop = False + has_qlinearop = False + for node in model.graph.node: + if node.op_type.endswith('Integer'): + has_integerop = True + elif node.op_type.startswith('QLinear'): + has_qlinearop = True + elif node.op_type in ['QAttention', 'QGemm', 'QEmbedLayerNormalization']: + has_qlinearop = True + elif node.op_type in ['Gather']: + input_data = find_by_name(node.input[0], model.graph.initializer) + if input_data is not None and \ + numpy_helper.to_array(input_data).dtype in ['int8', 'uint8']: + has_qlinearop = True + if has_integerop: + logger.info("This model has Integer ops, these ops will be skipped.") + if has_qlinearop: + return True + else: + logger.info("This model has no QLinear ops, save the original model.") + return False + +def onnx_qlinear_to_qdq( + model, + input_name_to_nodes, +): + """Export ONNX QLinearops model into QDQ model. + + Args: + model (ModelProto): int8 onnx model. + input_name_to_nodes (dict): the mapping of tensor name and its destination nodes. 
+ """ + from neural_compressor.adaptor.ox_utils.operators import QOPERATORS + add_nodes = [] + remove_nodes = [] + inits = [] + if check_model(model): + for node in model.graph.node: + if node.op_type in QOPERATORS: + if node.output[0] not in input_name_to_nodes: + continue + children = [] + for out in node.output: + children.extend(input_name_to_nodes[node.output[0]]) + converter = QOPERATORS[node.op_type]( + node, + children, + model.graph.initializer) + done, add_node, init = converter.convert() + if done: + add_nodes.extend(add_node) + inits.extend(init) + remove_nodes.append(node) + return add_nodes, remove_nodes, inits diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index 90fcda508c6..a090412b171 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -154,6 +154,10 @@ def add_initializer(self, tensor): if ortq.find_by_name(tensor.name, self._model.graph.initializer) is None: self._model.graph.initializer.extend([tensor]) + def add_initializers(self, tensors): + for tensor in tensors: + self.add_initializer(tensor) + def get_initializer(self, name): for tensor in self._model.graph.initializer: if tensor.name == name: @@ -423,3 +427,21 @@ def get_nodes_chain(self, start_node, stop_node, result_chain=[]): start_node.append(parent.name) return result_chain + + def export(self, save_path, conf): + from neural_compressor.experimental.export import onnx_qlinear_to_qdq + from neural_compressor.config import ONNXQlinear2QDQConfig + if isinstance(conf, ONNXQlinear2QDQConfig): + add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, + self._input_name_to_nodes) + self.add_nodes(add_nodes) + self.remove_nodes(remove_nodes) + self.add_initializers(inits) + self.update() + self.remove_unused_constant() + self.topological_sort() + self.save(save_path) + else: + logger.warning("Unsupported config for export, " + "only ONNXQlinear2QDQConfig is supported!") + exit(0) diff --git a/test/export/test_onnx_qlieanr_to_qdq.py b/test/export/test_onnx_qlieanr_to_qdq.py new file mode 100644 index 00000000000..63018f12b4a --- /dev/null +++ b/test/export/test_onnx_qlieanr_to_qdq.py @@ -0,0 +1,650 @@ +import os +import shutil +import unittest +import copy +import onnx +import numpy as np +from onnx import helper, TensorProto, numpy_helper, onnx_pb +from neural_compressor.adaptor.ox_utils.quantizer import Quantizer +from neural_compressor.adaptor.ox_utils.util import QuantizedInitializer, QuantizedValue, QuantizationMode +import onnxruntime as ort +from neural_compressor import options +from neural_compressor.config import ONNXQlinear2QDQConfig +from neural_compressor.experimental.common import Model + +def build_model(): + initializers = [] + input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, 3, 15, 15]) + output = helper.make_tensor_value_info('reshape_output', TensorProto.FLOAT, [88, 11]) + + add_node = onnx.helper.make_node('Add', ['input', 'add_init'], ['add_out'], name='add') + + conv1_weight_initializer = numpy_helper.from_array( + np.random.randint(-1, 2, [3, 3, 3, 3]).astype(np.float32), name='conv1_weight') + conv1_node = helper.make_node('Conv', ['add_out', 'conv1_weight'], ['conv1_output'], name='conv1') + + conv2_weight_initializer = numpy_helper.from_array( + np.random.randint(-1, 2, [5, 3, 3, 3]).astype(np.float32), name='conv2_weight') + conv2_node = helper.make_node('Conv', ['add_out', 'conv2_weight'], ['conv2_output'], name='conv2') + + # 1, 8, 13, 13 + concat_node = 
helper.make_node('Concat', ['conv1_output', 'conv2_output'], [ + 'concat_output'], name='Concat', axis=1) + # 1, 8, 11, 11 + avg_args = {'kernel_shape': [3, 3]} + avgpool_node = helper.make_node('AveragePool', ['concat_output'], ['avg_output'], name='AveragePool', **avg_args) + reshape_node = onnx.helper.make_node('Reshape', ['avg_output', 'shape'], ['reshape_output'], name='Reshape') + + initializers = [conv1_weight_initializer, conv2_weight_initializer] + initializers.append(onnx.numpy_helper.from_array(np.array([88, 11], dtype=np.int64), name='shape')) + initializers.append(onnx.numpy_helper.from_array(np.zeros((1, 3, 15, 15), dtype=np.float32), name='add_init')) + graph = helper.make_graph([conv1_node, conv2_node, concat_node, avgpool_node, reshape_node, add_node], + 'test', [input], [output], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + return model + +class TestAdaptorONNXRT(unittest.TestCase): + + qlinear_backend = QuantizationMode.QLinearOps + qdq_backend = 'qdqops' + integer_backend = QuantizationMode.IntegerOps + static_q_config = {"weight":{'dtype': 3, + 'algorithm': 'minmax', + 'scheme':'sym', + 'granularity': 'per_tensor'}, + 'activation':{'dtype': 2, + 'algorithm': 'minmax', + 'scheme':'asym', + 'granularity':'per_tensor', + 'quant_mode': 'static'} + } + dynamic_q_config = {"weight":{'dtype': 3, + 'algorithm': 'minmax', + 'scheme':'sym', + 'granularity': 'per_tensor'}, + 'activation':{'dtype': 2, + 'algorithm': 'minmax', + 'scheme':'asym', + 'granularity':'per_tensor', + 'quant_mode': 'dynamic'}} + config = ONNXQlinear2QDQConfig() + + @classmethod + def setUpClass(cls): + os.makedirs('./onnxrt_test') + + @classmethod + def tearDownClass(cls): + shutil.rmtree("./onnxrt_test", ignore_errors=True) + os.remove("test.onnx") + + def qlinear_test(self, model, q_config, quantize_params, quantizable_op_types): + quantizer = Quantizer(copy.deepcopy(model), + q_config, + self.qlinear_backend, + True, + quantize_params, + quantizable_op_types) + model = quantizer.quantize_model() + return Model(model) + + def dynamic_test(self, model, q_config, quantize_params, quantizable_op_types): + quantizer = Quantizer(copy.deepcopy(model), + q_config, + self.integer_backend, + False, + quantize_params, + quantizable_op_types) + quantizer.quantize_model() + return Model(model) + + def test_argmax(self): + input_name = "input" + output_name = "output" + input_shape = [1, 256, 128, 128] + output_shape = [1, 32, 128] + initializers = [] + + # make Conv node + conv_weight_name = "conv_weight" + conv_weight_arr = np.random.randint(-1, 2, [32, 256, 1, 1]).astype(np.float32) + conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name=conv_weight_name) + conv_output_name = "conv_output" + conv_inputs = [input_name, conv_weight_name] + conv_outputs = [conv_output_name] + conv_name = "conv_node" + conv_node = onnx.helper.make_node( + "Conv", + conv_inputs, + conv_outputs, + dilations=[1, 1], + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + name=conv_name, + ) + + # make ArgMax node + argmax_inputs = [conv_output_name] + argmax_outputs = [output_name] + argmax_name = "argmax_node" + argmax_node = onnx.helper.make_node( + "ArgMax", + argmax_inputs, + argmax_outputs, + axis=3, + keepdims=0, + name=argmax_name, + ) + + initializers = [conv_weight_initializer] + + # make graph + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) + output_tensor = 
helper.make_tensor_value_info(output_name, TensorProto.INT64, output_shape) + graph_name = "ArgMax_Quant_Test" + graph = helper.make_graph( + [conv_node, argmax_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 # use stable onnx ir version + q_config = {'conv_node': self.static_q_config, + 'argmax_node': self.static_q_config} + quantize_params = {'input': [np.uint8(0), np.float32(10.)], + 'conv_weight': [np.uint8(0), np.float32(10.)], + 'conv_output': [np.uint8(0), np.float32(10.)], + 'output': [np.uint8(0), np.float32(10.)], + } + q_model = self.qlinear_test(model, q_config, quantize_params, ['Conv', 'ArgMax']) + q_model.export('./test.onnx', self.config) + + def test_gemm(self): + input_name = "input" + output_name = "output" + initializers = [] + weight_shape = [100, 10] + weight_name = "linear1.weight" + bias_shape = [100] + bias_name = "linear1.bias" + node_name = "gemm" + + weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + + bias_data = np.random.normal(0, 0.1, bias_shape).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(bias_data, name=bias_name)) + + gemm1_node = onnx.helper.make_node( + "Gemm", + [input_name, weight_name, bias_name], + [output_name], + alpha=1.0, + beta=1.0, + transB=1, + name=node_name + ) + + gemm1_output_name = "gemm1_output" + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, 10]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, 100]) + graph_name = "gemm_test" + graph = helper.make_graph( + [gemm1_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 # use stable onnx ir version + q_config = {'gemm': self.static_q_config} + quantize_params = {'input': [np.uint8(0), np.float32(10.)], + 'linear1.weight': [np.uint8(0), np.float32(10.)], + 'linear1.bias': [np.uint8(0), np.float32(10.)], + 'output': [np.uint8(0), np.float32(10.)], + } + q_model = self.qlinear_test(model, q_config, quantize_params, ['Gemm']) + q_model.export('./test.onnx', self.config) + + bias_tensor = helper.make_tensor_value_info(bias_name, TensorProto.FLOAT, [100]) + gemm2_node = onnx.helper.make_node( + "Gemm", + [input_name, weight_name, bias_name], + [output_name], + alpha=1.0, + beta=1.0, + transB=1, + name=node_name + ) + initializers = [] + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + graph_name = "gemm_test" + graph = helper.make_graph( + [gemm2_node], + graph_name, + [input_tensor, bias_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 + q_model = self.qlinear_test(model, q_config, quantize_params, ['Gemm']) + q_model.export('./test.onnx', self.config) + + def test_embed(self): + input_ids_shape = [1, 4] + input_ids_tensor = helper.make_tensor_value_info('input_ids', TensorProto.INT32, input_ids_shape) + + segment_ids_shape = [1, 4] + segment_ids_tensor = helper.make_tensor_value_info('segment_ids', TensorProto.INT32, segment_ids_shape) + + mask_shape = [1, 4] + mask_tensor = helper.make_tensor_value_info('mask', TensorProto.INT32, input_ids_shape) + + # 
EmbedLayerNormalization Node Constants and Weights: + word_embed_shape = [32, 4] + word_embed_weights = np.random.random_sample(word_embed_shape).astype(dtype='float32') + word_embed_initializer = onnx.numpy_helper.from_array(word_embed_weights, name='word_embed') + + pos_embed_shape = [16, 4] + pos_embed_weights = np.random.random_sample(pos_embed_shape).astype(dtype='float32') + pos_embed_initializer = onnx.numpy_helper.from_array(pos_embed_weights, name='pos_embed') + + seg_embed_shape = [2, 4] + seg_embed_weights = np.random.random_sample(seg_embed_shape).astype(dtype='float32') + seg_embed_initializer = onnx.numpy_helper.from_array(seg_embed_weights, name='seg_embed') + + gamma_shape = [4] + gamma = np.random.random_sample(gamma_shape).astype(dtype='float32') + gamma_initializer = onnx.numpy_helper.from_array(gamma, name='gamma') + + beta_shape = [4] + beta = np.random.random_sample(beta_shape).astype(dtype='float32') + beta_initializer = onnx.numpy_helper.from_array(beta, name='beta') + + # EmbedLayerNormalization Outputs: + layernorm_out_shape = [1, 4, 4] + layernorm_out_tensor = helper.make_tensor_value_info('layernorm_out', TensorProto.FLOAT, layernorm_out_shape) + + mask_index_out_shape = [1] + mask_index_out_tensor = helper.make_tensor_value_info('mask_index_out', TensorProto.INT32, mask_index_out_shape) + + # EmbedLayerNormalization Node: + embed_layer_norm_inputs = [ + 'input_ids', 'segment_ids', 'word_embed', 'pos_embed', 'seg_embed', 'gamma', 'beta', 'mask' + ] + embed_layer_norm_outputs = ['layernorm_out', 'mask_index_out'] + embed_layer_norm_node = helper.make_node('EmbedLayerNormalization', + embed_layer_norm_inputs, + embed_layer_norm_outputs, + domain='com.microsoft', + name='Embed') + + # Construct the Graph and Model: + nodes = [embed_layer_norm_node] + graph_name = 'embed_layernorm_graph' + inputs = [input_ids_tensor, segment_ids_tensor, mask_tensor] + outputs = [layernorm_out_tensor, mask_index_out_tensor] + initializers = [ + word_embed_initializer, pos_embed_initializer, seg_embed_initializer, gamma_initializer, beta_initializer + ] + + graph = helper.make_graph(nodes, graph_name, inputs, outputs, initializer=initializers) + model = helper.make_model(graph, + opset_imports=[helper.make_opsetid("com.microsoft", 1), helper.make_opsetid("ai.onnx", 12)]) + model.ir_version = 7 # use stable onnx ir version + + q_config = {'Embed': self.static_q_config} + quantize_params = {'word_embed': [np.uint8(10.), np.float32(0)], + 'pos_embed': [np.uint8(10.), np.float32(0)], + 'seg_embed': [np.uint8(10.), np.float32(0)], + 'gamma': [np.uint8(10.), np.float32(0)], + 'beta': [np.uint8(10.), np.float32(0)], + 'layernorm_out': [np.uint8(10.), np.float32(0)], + 'mask_index_out': [np.uint8(10.), np.float32(0)], + 'input_ids': [np.uint8(10.), np.float32(0)], + } + q_model = self.qlinear_test(model, q_config, quantize_params, ['EmbedLayerNormalization']) + q_model.export('./test.onnx', self.config) + + def test_concat_reshape_pooling(self): + model = build_model() + options.onnxrt.qdq_setting.DedicatedQDQPair = True + + q_config = {'Reshape':self.static_q_config, 'conv1':self.static_q_config, 'conv2':self.static_q_config, \ + 'Concat':self.static_q_config, 'AveragePool':self.static_q_config, 'add':self.static_q_config} + quantize_params = {'input': [np.uint8(10.), np.float32(0)], + 'conv1_weight': [np.uint8(10.), np.float32(0)], + 'conv1_output': [np.uint8(10.), np.float32(0)], + 'conv2_weight': [np.uint8(10.), np.float32(0)], + 'conv2_output': [np.uint8(10.), np.float32(0)], + 
'concat_output': [np.uint8(10.), np.float32(0)], + 'avg_output': [np.uint8(10.), np.float32(0)], + 'add_out': [np.uint8(10.), np.float32(0)], + 'add_init': [np.uint8(10.), np.float32(0)], + 'shape': [np.uint8(10.), np.float32(0)], + 'reshape_output': [np.uint8(10.), np.float32(0)]} + quantizable_op_types = ['Reshape', 'Conv', 'Concat', 'AveragePool', 'Add'] + q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types) + q_model.export('./test.onnx', self.config) + options.onnxrt.qdq_setting.DedicatedQDQPair = False + + q_config = {'Reshape':self.static_q_config, 'conv1':'fp32', 'conv2':self.static_q_config, \ + 'Concat':self.static_q_config, 'AveragePool':self.static_q_config} + q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types) + q_model.export('./test.onnx', self.config) + + q_config = {'Reshape':self.static_q_config, 'conv1':'fp32', 'conv2':'fp32', \ + 'Concat':self.static_q_config, 'AveragePool':self.static_q_config} + q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types) + q_model.export('./test.onnx', self.config) + + q_config = {'Reshape':self.static_q_config, 'conv1':self.static_q_config, 'conv2':self.static_q_config, \ + 'Concat':self.static_q_config, 'AveragePool':'fp32'} + q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types) + q_model.export('./test.onnx', self.config) + + quantize_params = {'input': [np.uint8(10.), np.float32(0)], + 'conv1_weight': [np.uint8(10.), np.float32(0)], + 'conv1_output': [np.uint8(10.), np.float32(0)], + 'conv2_weight': [np.uint8(10.), np.float32(0)], + 'conv2_output': [np.uint8(10.), np.float32(0)], + 'concat_output': [np.uint8(10.), np.float32(0)], + 'avg_output': [np.uint8(10.), np.float32(0)], + 'shape': [np.uint8(10.), np.float32(0)], + 'add_out': [np.uint8(10.), np.float32(0)], + 'add_init': [np.uint8(10.), np.float32(0)], + 'reshape_output': [np.uint8(10.), np.float32(0)]} + q_config = {'Reshape':self.static_q_config, 'conv1':self.static_q_config, 'conv2':self.static_q_config, \ + 'Concat':self.static_q_config, 'AveragePool':self.static_q_config} + q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types) + q_model.export('./test.onnx', self.config) + + def test_conv(self): + for op in ['Conv']: + A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5, 1]) + B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 3, 3, 1]) + C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 5, 5, 1]) + D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 1, 5, 1]) + conv_node = onnx.helper.make_node(op, ['A', 'B', 'C'], ['D'], + name=op, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1]) + graph = helper.make_graph([conv_node], 'test_graph_1', [A, B, C], [D]) + model = helper.make_model(graph) + q_config = {op: self.static_q_config}, + quantize_params = {"A": [np.uint8(10.), np.float32(0)], + "B": [np.uint8(10.), np.float32(0)], + "C": [np.uint8(10.), np.float32(0)], + "D": [np.uint8(10.), np.float32(0)]} + quantizable_op_types = [op] + q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types) + q_model.export('./test.onnx', self.config) + + def test_matmul(self): + A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5]) + B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1, 5, 1]) + C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 1, 5, 1]) + matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul') 
+        graph = helper.make_graph([matmul_node], 'test_graph_1', [A, B], [C])
+        model = helper.make_model(graph)
+        q_config = {"Matmul": self.static_q_config}
+        quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                           "B": [np.uint8(10.), np.float32(0)],
+                           "C": [np.uint8(10.), np.float32(0)]}
+        quantizable_op_types = ["Matmul"]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+        q_config = {"Matmul": self.dynamic_q_config}
+        q_model = self.dynamic_test(model, q_config, None, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+        quantize_params = {"A": [np.float32(10.)],
+                           "B": [np.float32(10.)],
+                           "C": [np.float32(10.)]}
+
+        q_config = {"Matmul": {"weight":{'dtype': 3,
+                                         'algorithm': 'minmax',
+                                         'scheme':'sym',
+                                         'granularity': 'per_tensor'},
+                               'activation':{'dtype': 2,
+                                             'algorithm': 'minmax',
+                                             'scheme':'asym',
+                                             'granularity':'per_tensor',
+                                             'quant_mode': 'dynamic'}}}
+        quantize_params = {}
+        q_model = self.dynamic_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+
+    def test_attention(self):
+        A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 128, 768])
+        B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [768, 2304])
+        C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [2304])
+        D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 128, 768])
+        mask = helper.make_tensor_value_info('mask', TensorProto.INT32, [1, 128])
+
+        node = onnx.helper.make_node('Attention', ['A', 'B', 'C', 'mask'], ['D'], name='Attention', num_heads=1)
+        graph = helper.make_graph([node], 'test_graph_1', [A, B, C, mask], [D])
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        q_config = {"Attention": self.static_q_config}
+        quantize_params = {"A": [np.uint8(0), np.float32(0.5)],
+                           "B": [np.uint8(0), np.float32(0.5)],
+                           "C": [np.uint8(0), np.float32(0.5)],
+                           "D": [np.uint8(0), np.float32(0.5)]}
+        quantizable_op_types = ["Attention"]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+        q_config = {"Attention": self.dynamic_q_config}
+
+    def test_gather(self):
+        a_value = np.random.randn(100, 4).astype(np.float32)
+        A_init = helper.make_tensor('A', TensorProto.FLOAT, [100, 4],
+                                    a_value.reshape(400).tolist())
+        b_value = np.random.randint(2, size=(1, 10)).astype(np.int32)
+        B_init = helper.make_tensor('B', TensorProto.INT32, [1, 10],
+                                    b_value.reshape(10).tolist())
+        A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [100, 4])
+        B = helper.make_tensor_value_info('B', TensorProto.INT32, [1, 10])
+        C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 10, 4])
+        node = onnx.helper.make_node('Gather', ['A', 'B'], ['C'], name='Gather')
+        graph = helper.make_graph([node], 'test_graph_1', [A, B], [C], [A_init, B_init])
+        model = helper.make_model(graph)
+        q_config = {'Gather': {"weight":{'dtype': 2,
+                                         'algorithm': 'minmax',
+                                         'scheme':'asym',
+                                         'granularity': 'per_tensor'},
+                               'activation':{'dtype': 2,
+                                             'algorithm': 'minmax',
+                                             'scheme':'asym',
+                                             'granularity':'per_tensor',
+                                             'quant_mode': 'static'}
+                    }}
+        quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                           "C": [np.uint8(10.), np.float32(0)]}
+        quantizable_op_types = ["Gather"]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+        q_config = {'Gather': {"weight":{'dtype': 3,
+                                         'algorithm': 'minmax',
+                                         'scheme':'sym',
+                                         'granularity': 'per_tensor'},
+                               'activation':{'dtype': 2,
+                                             'algorithm': 'minmax',
+                                             'scheme':'asym',
+                                             'granularity':'per_tensor',
+                                             'quant_mode': 'dynamic'}
+                    }}
+        q_model = self.dynamic_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+        graph = helper.make_graph([node], 'test_graph_1', [A, B], [C])
+        model = helper.make_model(graph)
+        q_config = {'Gather': {"weight":{'dtype': 3,
+                                         'algorithm': 'minmax',
+                                         'scheme':'sym',
+                                         'granularity': 'per_tensor'},
+                               'activation':{'dtype': 2,
+                                             'algorithm': 'minmax',
+                                             'scheme':'asym',
+                                             'granularity':'per_tensor',
+                                             'quant_mode': 'dynamic'}
+                    }}
+        quantize_params = {}
+        q_model = self.dynamic_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+
+    def test_binary(self):
+        for op in ['Mul', 'Add']:
+            A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 10])
+            B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1])
+            C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 10])
+            node = onnx.helper.make_node(op, ['A', 'B'], ['C'], name=op)
+            graph = helper.make_graph([node], 'test_graph_1', [A, B], [C])
+            model = helper.make_model(graph)
+            q_config = {op: self.static_q_config}
+            quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                               "B": [np.uint8(10.), np.float32(0)],
+                               "C": [np.uint8(10.), np.float32(0)]}
+            quantizable_op_types = [op]
+            q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+            q_model.export('./test.onnx', self.config)
+            q_model = self.qlinear_test(model, q_config, {}, quantizable_op_types)
+            q_model.export('./test.onnx', self.config)
+
+    def test_activation(self):
+        config = {"weight":{'dtype': 2,
+                            'algorithm': 'minmax',
+                            'scheme':'asym',
+                            'granularity': 'per_tensor'},
+                  'activation':{'dtype': 2,
+                                'algorithm': 'minmax',
+                                'scheme':'asym',
+                                'granularity':'per_tensor',
+                                'quant_mode': 'static'}
+                  }
+
+        for op in ["Relu", "LeakyRelu", "Sigmoid"]:
+            B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 10])
+            A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 10])
+            node = onnx.helper.make_node(op, ['A'], ['B'], name=op)
+            graph = helper.make_graph([node], 'test_graph_1', [A], [B])
+            model = helper.make_model(graph)
+            q_config = {op: config}
+            quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                               "B": [np.uint8(10.), np.float32(0)]}
+            quantizable_op_types = [op]
+            q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+            q_model.export('./test.onnx', self.config)
+
+            a_value = np.random.randn(1, 10).astype(np.float32)
+            A_init = helper.make_tensor('A', TensorProto.FLOAT, [1, 10],
+                                        a_value.reshape(10).tolist())
+            graph = helper.make_graph([node], 'test_graph_1', [A], [B], [A_init])
+            model = helper.make_model(graph)
+            q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+            q_model.export('./test.onnx', self.config)
+
+    def test_pooling(self):
+        op = "MaxPool"
+        B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 5, 5, 1])
+        A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5, 1])
+        node = onnx.helper.make_node(op, ['A'], ['B'],
+                                     name=op,
+                                     kernel_shape=[3, 3],
+                                     pads=[1, 1, 1, 1])
+        graph = helper.make_graph([node], 'test_graph_1', [A], [B])
+        q_config = {op: self.static_q_config}
+        quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                           "B": [np.uint8(10.), np.float32(0)]}
+        quantizable_op_types = [op]
+        for opset_version in [12, 13]:
+            opset = onnx.OperatorSetIdProto()
+            opset.version = opset_version
+            model = helper.make_model(graph, opset_imports=[opset])
+            q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+            q_model.export('./test.onnx', self.config)
+
+        A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5])
+        B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1, 3, 3])
+        D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 1, 5, 5])
+        conv_node = onnx.helper.make_node('Conv', ['A', 'B'], ['C'],
+                                          name='Conv',
+                                          kernel_shape=[3, 3],
+                                          pads=[1, 1, 1, 1])
+        pool_node = onnx.helper.make_node(op, ['C'], ['D'], name=op, kernel_shape=[1, 1])
+        graph = helper.make_graph([conv_node, pool_node], 'test_graph_1', [A, B], [D])
+        model = helper.make_model(graph)
+
+        q_config = {"Conv": self.static_q_config, op: self.static_q_config}
+        quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                           "B": [np.uint8(10.), np.float32(0)],
+                           "C": [np.uint8(10.), np.float32(0)],
+                           "D": [np.uint8(10.), np.float32(0)]}
+        quantizable_op_types = ["Conv", op]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+
+        op = "GlobalAveragePool"
+        B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 5, 1, 1])
+        A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5, 1])
+        node = onnx.helper.make_node(op, ['A'], ['B'],
+                                     name=op)
+        graph = helper.make_graph([node], 'test_graph_1', [A], [B])
+        q_config = {op: self.static_q_config}
+        quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                           "B": [np.uint8(10.), np.float32(0)]}
+        quantizable_op_types = [op]
+        for opset_version in [12, 13]:
+            opset = onnx.OperatorSetIdProto()
+            opset.version = opset_version
+            model = helper.make_model(graph, opset_imports=[opset])
+            q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+            q_model.export('./test.onnx', self.config)
+
+        A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5])
+        B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1, 3, 3])
+        D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 1, 1, 1])
+        conv_node = onnx.helper.make_node('Conv', ['A', 'B'], ['C'],
+                                          name='Conv',
+                                          kernel_shape=[3, 3],
+                                          pads=[1, 1, 1, 1])
+        pool_node = onnx.helper.make_node(op, ['C'], ['D'], name=op)
+        graph = helper.make_graph([conv_node, pool_node], 'test_graph_1', [A, B], [D])
+        model = helper.make_model(graph)
+
+        q_config = {"Conv": self.static_q_config, op: self.static_q_config}
+        quantize_params = {"A": [np.uint8(10.), np.float32(0)],
+                           "B": [np.uint8(10.), np.float32(0)],
+                           "C": [np.uint8(10.), np.float32(0)],
+                           "D": [np.uint8(10.), np.float32(0)]}
+        quantizable_op_types = ["Conv", op]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        q_model.export('./test.onnx', self.config)
+
+
+    def test_exclude_node(self):
+        A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5, 1])
+        B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [3, 3, 1, 1])
+        D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 3, 5, 1])
+        conv_node = onnx.helper.make_node('Conv', ['A', 'B'], ['C'],
+                                          name='Conv',
+                                          kernel_shape=[3, 3],
+                                          pads=[1, 1, 1, 1])
+        pool_node = onnx.helper.make_node("MaxPool", ['C'], ['D'], name="MaxPool", kernel_shape=[1, 1])
+        graph = helper.make_graph([conv_node, pool_node], 'test_graph_1', [A, B], [D])
+        model = helper.make_model(graph)
+
+        q_config = {"Conv": self.static_q_config, "MaxPool": "fp32"}
{"Conv": self.static_q_config, "MaxPool": "fp32"} + quantize_params = {"A": [np.uint8(10.), np.float32(0)], + "B": [np.uint8(10.), np.float32(0)], + "C": [np.uint8(10.), np.float32(0)], + "D": [np.uint8(10.), np.float32(0)]} + quantizable_op_types = ["Conv", "MaxPool"] + self.config.exclude_output_quantization = ['Conv'] + q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types) + q_model.export('./test.onnx', self.config) + +if __name__ == "__main__": + unittest.main()