From a94d51cf105789bbb8c80a97d7672882dc4fb9b9 Mon Sep 17 00:00:00 2001
From: abhikran-quic <63697863+abhikran-quic@users.noreply.github.com>
Date: Wed, 24 Nov 2021 22:19:09 +0530
Subject: [PATCH] [ONNX] Add MatMulInteger16 contrib op (#9186)

* [ONNX] Add MatMulInteger16 contrib op
* Fix formatting errors
* Remove a code comment and do not set default value of nd
* Move flatten_to_nd function outside matmul to be used across multiple functions
* Add function docstring and describe the tests
* Use max/min value of int16 as high/low while generating input vectors
* Converge MatMul and MatMulInteger16 ops into a single op using output dtype
* Fix indentation issues
* Formatting changes
* Fix CUDA batchmatmul strategy to allow mixed precision
* Add test_matmulinteger to unsupported_onnx_tests
---
 python/tvm/relay/frontend/onnx.py          | 164 +++++++++++----------
 python/tvm/relay/op/strategy/cuda.py       |   2 +-
 tests/python/frontend/onnx/test_forward.py |  42 ++++++
 3 files changed, 133 insertions(+), 75 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 7e524c112e65e..d55c0e5090312 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -216,6 +216,77 @@ def get_scalar(x, params, dtype="float32"):
     return _op.cast(x, dtype)
 
 
+def matmul_out_dtype(inputs, out_dtype):
+    """Common function to handle MatMul and MatMulInteger16"""
+    a_shape = shape_of(inputs[0])
+    a_rank = infer_shape(a_shape)[0]
+    b_shape = shape_of(inputs[1])
+    b_rank = infer_shape(b_shape)[0]
+    if a_rank > 2 or b_rank > 2:
+
+        def flatten_to_nd(x, x_shape, nd=3):
+            ndims = infer_shape(x_shape)[0]
+            if ndims == nd:
+                return x
+            newshape = _op.concatenate(
+                [
+                    _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
+                    _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+                ],
+                0,
+            )
+            out = _op.reshape(x, fold_constant(newshape))
+            return out
+
+        b_type = infer_type(inputs[1])
+        # Convert to dense if the second matrix is 2d and non-dynamic
+        if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
+            a = flatten_to_nd(inputs[0], a_shape, 2)
+            b = _op.transpose(inputs[1])
+            output = _op.nn.dense(a, b, out_dtype=out_dtype)
+        else:
+            # Convert a and b into 3 dimensional tensors.
+            a = flatten_to_nd(inputs[0], a_shape, 3)
+            b = flatten_to_nd(inputs[1], b_shape, 3)
+            # Perform a NN batch matmul.
+            output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
+        # Determine the output batch dimension.
+        if a_rank > b_rank:
+            out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+        elif a_rank < b_rank:
+            out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
+        # If it's unclear how broadcasting should be applied, the output
+        # shape is determined by choosing the maximum value from each input.
+        else:
+            out_batch = _op.concatenate(
+                [
+                    _op.maximum(
+                        _op.strided_slice(a_shape, [i], [i + 1]),
+                        _op.strided_slice(b_shape, [i], [i + 1]),
+                    )
+                    for i in range(a_rank - 2)
+                ],
+                0,
+            )
+        # Reshape output to original dimensions.
+        final_shape = _op.concatenate(
+            [
+                out_batch,
+                _op.strided_slice(
+                    a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
+                ),
+                _op.strided_slice(
+                    b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
+                ),
+            ],
+            0,
+        )
+        return _op.reshape(output, fold_constant(final_shape))
+    # Otherwise a simple dense op will get the job done.
+    input_1_t = _op.transpose(inputs[1], axes=(1, 0))
+    return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""
 
@@ -735,80 +806,24 @@ class MatMul(OnnxOpConverter):
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if a_rank > 2 or b_rank > 2:
-
-            def flatten_to_nd(x, x_shape, nd=3):
-                ndims = infer_shape(x_shape)[0]
-                if ndims == nd:
-                    return x
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
-
-            b_type = infer_type(inputs[1])
-            # Convert to dense if the second matrix is 2d and non-dynamic
-            if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
-                a = flatten_to_nd(inputs[0], a_shape, 2)
-                b = _op.transpose(inputs[1])
-                output = _op.nn.dense(a, b)
-            else:
-                # Convert a and b into 3 dimensional tensors.
-                a = flatten_to_nd(inputs[0], a_shape, 3)
-                b = flatten_to_nd(inputs[1], b_shape, 3)
-                if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
-                    # Transpose matrix dimensions of b.
-                    b = _op.transpose(b, [0, 2, 1])
-                    # Perform a NT batch matmul.
-                    output = _op.nn.batch_matmul(a, b)
-                else:
-                    # Perform a NN batch matmul.
-                    output = _op.nn.batch_matmul(a, b, transpose_b=False)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If its unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
-            # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
-            )
-            return _op.reshape(output, fold_constant(final_shape))
-        # Otherwise a simple dense op will get the job done.
-        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
-        return _op.nn.dense(inputs[0], input_1_t)
+        return matmul_out_dtype(inputs, out_dtype=infer_type(inputs[0]).checked_type.dtype)
+
+
+class MatMulInteger16(OnnxOpConverter):
+    """Operator converter for MatMulInteger16 from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        assert len(inputs) == 2, "MatMulInteger16 op takes 2 inputs, {} given".format(len(inputs))
+        a_dtype = infer_type(inputs[0]).checked_type.dtype
+        b_dtype = infer_type(inputs[1]).checked_type.dtype
+        # Check input data types
+        assert a_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for first input"
+        assert b_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for second input"
+        out_dtype = "int32"
+        if a_dtype == "uint16" and b_dtype == "uint16":
+            out_dtype = "uint32"
+        return matmul_out_dtype(inputs, out_dtype)
 
 
 class Mod(OnnxOpConverter):
@@ -4386,6 +4401,7 @@ def _get_convert_map(opset):
         "Softsign": Softsign.get_converter(opset),
         "Gemm": Gemm.get_converter(opset),
         "MatMul": MatMul.get_converter(opset),
+        "MatMulInteger16": MatMulInteger16.get_converter(opset),
         "Mod": Mod.get_converter(opset),
         "Xor": Renamer("logical_xor"),
         # defs/nn
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index eee5d9a685b3b..80f1fe1765c3a 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -839,7 +839,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         )
     else:
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.cuda.batch_matmul),
+            wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True),
             wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
             name="batch_matmul.cuda",
             plevel=10,
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 36952330a90a1..d1fb64f61e037 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1283,6 +1283,47 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
         )
 
 
+@tvm.testing.parametrize_targets
+def test_matmulinteger16(target, dev):
+    def verify_matmulinteger16(a_shape, b_shape, out_shape):
+        a_dtype = "int16"
+        b_dtype = "int16"
+        low = np.iinfo(np.int16).min
+        high = np.iinfo(np.int16).max
+
+        a_proto = TensorProto.INT16
+        b_proto = TensorProto.INT16
+        out_proto = TensorProto.INT32
+        a_array = np.random.randint(low, high, size=a_shape).astype(a_dtype)
+        b_array = np.random.randint(low, high, size=b_shape).astype(b_dtype)
+
+        mul_node = helper.make_node("MatMulInteger16", ["a", "b"], ["out"], domain="com.microsoft")
+
+        graph = helper.make_graph(
+            [mul_node],
+            "matmuli16_test",
+            inputs=[
+                helper.make_tensor_value_info("a", a_proto, list(a_shape)),
+                helper.make_tensor_value_info("b", b_proto, list(b_shape)),
+            ],
+            outputs=[helper.make_tensor_value_info("out", out_proto, list(out_shape))],
+        )
+
+        model = helper.make_model(graph, producer_name="matmuli16_test")
+        verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev)
+
+    # 2D computation to verify matmul op
+    verify_matmulinteger16((4, 3), (3, 4), (4, 4))
+    verify_matmulinteger16((5, 7), (7, 8), (5, 8))
+    # Verify 3D matmul using batch_matmul op
+    verify_matmulinteger16((2, 4, 3), (1, 3, 4), (2, 4, 4))
+    verify_matmulinteger16((1, 4, 3), (2, 3, 4), (2, 4, 4))
+    # Test implicit broadcasting
+    verify_matmulinteger16((2, 3, 5, 3), (2, 3, 3, 5), (2, 3, 5, 5))
+    verify_matmulinteger16((2, 7, 3), (3, 7), (2, 7, 7))
+    verify_matmulinteger16((2, 3, 4, 3), (3, 4), (2, 3, 4, 4))
+
+
 def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
     def verify_model(model, a_shape, b_shape):
         a_array = np.random.uniform(size=a_shape).astype("float32")
@@ -5999,6 +6040,7 @@ def repeat(N, D):
     test_onehot()
     test_gemm()
    test_matmul()
+    test_matmulinteger16()
     test_gather()
     test_gatherelements()
     test_gather_nd()
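
Note (not part of the patch): a minimal standalone sketch of how the new converter path can be exercised outside the test suite, assuming a TVM build that includes this change. The graph construction mirrors the test above; relay.frontend.from_onnx and relay.create_executor are the standard import and execution entry points, while the model/variable names here are illustrative only.

import numpy as np
from onnx import TensorProto, helper

import tvm
from tvm import relay

a_shape, b_shape, out_shape = (4, 3), (3, 4), (4, 4)

# MatMulInteger16 lives in the com.microsoft contrib domain.
node = helper.make_node("MatMulInteger16", ["a", "b"], ["out"], domain="com.microsoft")
graph = helper.make_graph(
    [node],
    "matmuli16_example",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.INT16, list(a_shape)),
        helper.make_tensor_value_info("b", TensorProto.INT16, list(b_shape)),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, list(out_shape))],
)
model = helper.make_model(graph, producer_name="matmuli16_example")

# Import into Relay; with two int16 inputs the converter emits nn.dense /
# nn.batch_matmul with out_dtype="int32" ("uint32" only if both inputs are uint16).
mod, params = relay.frontend.from_onnx(model, shape={"a": a_shape, "b": b_shape})

a = np.random.randint(-100, 100, size=a_shape).astype("int16")
b = np.random.randint(-100, 100, size=b_shape).astype("int16")
run = relay.create_executor("graph", mod=mod, device=tvm.cpu(), target="llvm").evaluate()
out = run(a, b)
# The int32 result should match a widened numpy matmul.
np.testing.assert_array_equal(out.numpy(), a.astype("int32") @ b.astype("int32"))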