Skip to content

Commit

Permalink
Add contrib Q/DQ ops to symbolic shape inference tool (#19340)
Browse files Browse the repository at this point in the history
### Description
Adds type/shape inferencing support for MSFT domain QuantizeLinear and
DequantizeLinear operators to symbolic_shape_infer.py


### Motivation and Context
Need a way to infer the types and shapes of Q/DQ ops in models that use
the MSFT domain versions (e.g., int16 quantization).
  • Loading branch information
adrianlizarraga committed Jan 31, 2024
1 parent 2b361c0 commit ca8d445
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 0 deletions.
27 changes: 27 additions & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"BiasGelu": self._infer_BiasGelu,
"BiasSplitGelu": self._infer_BiasSplitGelu,
"DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
"DequantizeLinear": self._infer_DequantizeLinear,
"EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
"FastGelu": self._infer_FastGelu,
"GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
Expand All @@ -212,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"PackedAttention": self._infer_PackedAttention,
"PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
"PythonOp": self._infer_PythonOp,
"QuantizeLinear": self._infer_QuantizeLinear,
"QuickGelu": self._infer_FastGelu,
"RelativePositionBias": self._infer_RelativePositionBias,
"RemovePadding": self._infer_RemovePadding,
Expand Down Expand Up @@ -457,6 +459,8 @@ def _onnx_infer_single_node(self, node):
"GemmFastGelu",
"LayerNormalization",
"LongformerAttention",
"DequantizeLinear",
"QuantizeLinear",
"RelativePositionBias",
"RemovePadding",
"RestorePadding",
Expand Down Expand Up @@ -979,6 +983,29 @@ def _infer_NhwcConv(self, node): # noqa: N802
)
)

def _infer_DequantizeLinear(self, node): # noqa: N802
# Get the output data type from the scale input (index 1, required).
output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type

# Get the output shape from the first input.
output_shape = self._get_shape(node, 0)

vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))

def _infer_QuantizeLinear(self, node): # noqa: N802
# Get the output data type from the zero-point input (index 2, optional).
# Otherwise, default to uint8
output_dtype = onnx.TensorProto.UINT8
if len(node.input) > 2 and node.input[2]:
output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type

# Get the output shape from the first input.
output_shape = self._get_shape(node, 0)

vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))

def _infer_Einsum(self, node): # noqa: N802
# ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
equation = get_attribute(node, "equation")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,208 @@ def test_div_precision(self):
self.assertEqual(len(output_dims), 1)
self.assertEqual(output_dims[0].dim_value, 512)

def test_quantize_linear(self):
"""
Test ONNX QuantizeLinear op.
Check that the output shape is propagated from the first input and that the output data
type comes from the zero-point input.
"""
initializers = [
helper.make_tensor(
"scale",
TensorProto.FLOAT,
[],
[1.0],
),
helper.make_tensor(
"zero_point",
TensorProto.INT8,
[],
[16],
),
]

nodes = [
helper.make_node(
"QuantizeLinear",
inputs=[
"input_f32",
"scale",
"zero_point",
],
outputs=["output_s8"],
),
]

inputs = [
helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]),
]

outputs = [
helper.make_tensor_value_info("output_s8", TensorProto.UNDEFINED, None),
]

graph = helper.make_graph(nodes, "QuantizeLinear_Test", inputs, outputs, initializers)
model = helper.make_model(graph)

inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

expected_shapes = [
helper.make_tensor_value_info("output_s8", TensorProto.INT8, ["b", 2, 3, 4]),
]
self._check_shapes(graph, inferred.graph, expected_shapes)

def test_quantize_linear_ms_domain(self):
"""
Test QuantizeLinear op ('com.microsoft' domain).
Check that the output shape is propagated from the first input and that the output data
type comes from the zero-point input.
"""
initializers = [
helper.make_tensor(
"scale",
TensorProto.FLOAT,
[],
[1.0],
),
helper.make_tensor(
"zero_point",
TensorProto.UINT16,
[],
[16],
),
]

nodes = [
helper.make_node(
"QuantizeLinear",
inputs=[
"input_f32",
"scale",
"zero_point",
],
outputs=["output_u16"],
domain="com.microsoft",
),
]

inputs = [
helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]),
]

outputs = [
helper.make_tensor_value_info("output_u16", TensorProto.UNDEFINED, None),
]

graph = helper.make_graph(nodes, "QuantizeLinear_MSDomain_Test", inputs, outputs, initializers)
model = helper.make_model(graph)

inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

expected_shapes = [
helper.make_tensor_value_info("output_u16", TensorProto.UINT16, ["b", 2, 3, 4]),
]
self._check_shapes(graph, inferred.graph, expected_shapes)

def test_quantize_linear_no_zp_input(self):
"""
Test QuantizeLinear op ('com.microsoft' domain).
Check that the output shape is propagated from the first input.
The zero-point input is missing, so the output data type should default to uint8.
"""
initializers = [
helper.make_tensor(
"scale",
TensorProto.FLOAT,
[],
[1.0],
),
]

nodes = [
helper.make_node(
"QuantizeLinear",
inputs=[
"input_f32",
"scale",
],
outputs=["output_u8"],
domain="com.microsoft",
),
]

inputs = [
helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]),
]

outputs = [
helper.make_tensor_value_info("output_u8", TensorProto.UNDEFINED, None),
]

graph = helper.make_graph(nodes, "QuantizeLinear_NoZP_Test", inputs, outputs, initializers)
model = helper.make_model(graph)

inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

# Check that the output shape is propagated from the first input and that the
# output data type comes from the zero-point input.
expected_shapes = [
helper.make_tensor_value_info("output_u8", TensorProto.UINT8, ["b", 2, 3, 4]),
]
self._check_shapes(graph, inferred.graph, expected_shapes)

def test_dequantize_linear_ms_domain(self):
"""
Test DequantizeLinear operator ('com.microsoft' domain).
Check that the output shape is propagated from the first input and that the output data
type comes from the scale input.
"""
initializers = [
helper.make_tensor(
"scale",
TensorProto.FLOAT,
[],
[1.0],
),
helper.make_tensor(
"zero_point",
TensorProto.UINT16,
[],
[16],
),
]

nodes = [
helper.make_node(
"DequantizeLinear",
inputs=[
"input_u16",
"scale",
"zero_point",
],
outputs=["output_f32"],
domain="com.microsoft",
),
]

inputs = [
helper.make_tensor_value_info("input_u16", TensorProto.UINT16, ["b", 2, 3, 4]),
]

outputs = [
helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None),
]

graph = helper.make_graph(nodes, "DequantizeLinear_MSDomain_Test", inputs, outputs, initializers)
model = helper.make_model(graph)

inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

expected_shapes = [
helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["b", 2, 3, 4]),
]
self._check_shapes(graph, inferred.graph, expected_shapes)


class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
Expand Down

0 comments on commit ca8d445

Please sign in to comment.