[ONNX] Add MatMulInteger16 contrib op (apache#9186)
* [ONNX] Add MatMulInteger16 contrib op

* Fix formatting errors

* Remove a code comment and do not set default value of nd

* Move flatten_to_nd function outside matmul to be used across multiple functions

* Add function docstring and describe the tests

* Use max/min value of int16 as high/low while generating input vectors

* Converge MatMul and MatMulInteger16 ops into a single op using output dtype (see the dtype-selection sketch after this list)

* Fix indentation issues

* Formatting changes

* Fix CUDA batchmatmul strategy to allow mixed precision

* Add test_matmulinteger to unsupported_onnx_tests
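
For reference, a minimal sketch (an editor's illustration, not part of the commit) of the output-dtype rule that lets one converter serve both ops: MatMul keeps the input dtype, while MatMulInteger16 widens 16-bit operands to a 32-bit accumulator, staying unsigned only when both operands are unsigned. The helper name below is hypothetical; the behavior mirrors MatMulInteger16._impl_v10 in the diff.

def matmulinteger16_out_dtype(a_dtype, b_dtype):
    # Both operands must be 16-bit integers; the result widens to 32 bits.
    assert a_dtype in ("int16", "uint16") and b_dtype in ("int16", "uint16")
    # Only an all-unsigned product stays unsigned.
    return "uint32" if a_dtype == "uint16" and b_dtype == "uint16" else "int32"

assert matmulinteger16_out_dtype("int16", "int16") == "int32"
assert matmulinteger16_out_dtype("uint16", "uint16") == "uint32"
assert matmulinteger16_out_dtype("int16", "uint16") == "int32"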
abhikran-quic authored and dchauhan-arm committed Nov 29, 2021
1 parent 181d45e commit a94d51c
Showing 3 changed files with 133 additions and 75 deletions.
164 changes: 90 additions & 74 deletions python/tvm/relay/frontend/onnx.py
@@ -216,6 +216,77 @@ def get_scalar(x, params, dtype="float32"):
    return _op.cast(x, dtype)


def matmul_out_dtype(inputs, out_dtype):
    """Common function to handle MatMul and MatMulInteger16"""
    a_shape = shape_of(inputs[0])
    a_rank = infer_shape(a_shape)[0]
    b_shape = shape_of(inputs[1])
    b_rank = infer_shape(b_shape)[0]
    if a_rank > 2 or b_rank > 2:

        def flatten_to_nd(x, x_shape, nd=3):
            ndims = infer_shape(x_shape)[0]
            if ndims == nd:
                return x
            newshape = _op.concatenate(
                [
                    _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
                    _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
                ],
                0,
            )
            out = _op.reshape(x, fold_constant(newshape))
            return out

        b_type = infer_type(inputs[1])
        # Convert to dense if the second matrix is 2d and non-dynamic
        if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
            a = flatten_to_nd(inputs[0], a_shape, 2)
            b = _op.transpose(inputs[1])
            output = _op.nn.dense(a, b, out_dtype=out_dtype)
        else:
            # Convert a and b into 3 dimensional tensors.
            a = flatten_to_nd(inputs[0], a_shape, 3)
            b = flatten_to_nd(inputs[1], b_shape, 3)
            # Perform a NN batch matmul.
            output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
        # Determine the output batch dimension.
        if a_rank > b_rank:
            out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
        elif a_rank < b_rank:
            out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
        # If it's unclear how broadcasting should be applied, the output
        # shape is determined by choosing the maximum value from each input.
        else:
            out_batch = _op.concatenate(
                [
                    _op.maximum(
                        _op.strided_slice(a_shape, [i], [i + 1]),
                        _op.strided_slice(b_shape, [i], [i + 1]),
                    )
                    for i in range(a_rank - 2)
                ],
                0,
            )
        # Reshape output to original dimensions.
        final_shape = _op.concatenate(
            [
                out_batch,
                _op.strided_slice(
                    a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
                ),
                _op.strided_slice(
                    b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
                ),
            ],
            0,
        )
        return _op.reshape(output, fold_constant(final_shape))
    # Otherwise a simple dense op will get the job done.
    input_1_t = _op.transpose(inputs[1], axes=(1, 0))
    return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)


class OnnxOpConverter(object):
    """A helper class for holding onnx op converters."""

@@ -735,80 +806,24 @@ class MatMul(OnnxOpConverter):
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
        # Need to check input shape as batch matmul must be supported.
        a_shape = shape_of(inputs[0])
        a_rank = infer_shape(a_shape)[0]
        b_shape = shape_of(inputs[1])
        b_rank = infer_shape(b_shape)[0]
        # When performing a batch matmul, we need to properly handle N-dim shapes.
        if a_rank > 2 or b_rank > 2:

            def flatten_to_nd(x, x_shape, nd=3):
                ndims = infer_shape(x_shape)[0]
                if ndims == nd:
                    return x
                newshape = _op.concatenate(
                    [
                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
                        _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
                    ],
                    0,
                )
                out = _op.reshape(x, fold_constant(newshape))
                return out

            b_type = infer_type(inputs[1])
            # Convert to dense if the second matrix is 2d and non-dynamic
            if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
                a = flatten_to_nd(inputs[0], a_shape, 2)
                b = _op.transpose(inputs[1])
                output = _op.nn.dense(a, b)
            else:
                # Convert a and b into 3 dimensional tensors.
                a = flatten_to_nd(inputs[0], a_shape, 3)
                b = flatten_to_nd(inputs[1], b_shape, 3)
                if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
                    # Transpose matrix dimensions of b.
                    b = _op.transpose(b, [0, 2, 1])
                    # Perform a NT batch matmul.
                    output = _op.nn.batch_matmul(a, b)
                else:
                    # Perform a NN batch matmul.
                    output = _op.nn.batch_matmul(a, b, transpose_b=False)
            # Determine the output batch dimension.
            if a_rank > b_rank:
                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
            elif a_rank < b_rank:
                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
            # If it's unclear how broadcasting should be applied, the output
            # shape is determined by choosing the maximum value from each input.
            else:
                out_batch = _op.concatenate(
                    [
                        _op.maximum(
                            _op.strided_slice(a_shape, [i], [i + 1]),
                            _op.strided_slice(b_shape, [i], [i + 1]),
                        )
                        for i in range(a_rank - 2)
                    ],
                    0,
                )
            # Reshape output to original dimensions.
            final_shape = _op.concatenate(
                [
                    out_batch,
                    _op.strided_slice(
                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
                    ),
                    _op.strided_slice(
                        b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
                    ),
                ],
                0,
            )
            return _op.reshape(output, fold_constant(final_shape))
        # Otherwise a simple dense op will get the job done.
        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
        return _op.nn.dense(inputs[0], input_1_t)
        return matmul_out_dtype(inputs, out_dtype=infer_type(inputs[0]).checked_type.dtype)


class MatMulInteger16(OnnxOpConverter):
    """Operator converter for MatMulInteger16 from Microsoft onnxruntime contrib opset."""

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        assert len(inputs) == 2, "MatMulInteger16 op take 2 inputs, {} given".format(len(inputs))
        a_dtype = infer_type(inputs[0]).checked_type.dtype
        b_dtype = infer_type(inputs[1]).checked_type.dtype
        # Check input data types
        assert a_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for first input"
        assert b_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for second input"
        out_dtype = "int32"
        if a_dtype == "uint16" and b_dtype == "uint16":
            out_dtype = "uint32"
        return matmul_out_dtype(inputs, out_dtype)


class Mod(OnnxOpConverter):
@@ -4386,6 +4401,7 @@ def _get_convert_map(opset):
"Softsign": Softsign.get_converter(opset),
"Gemm": Gemm.get_converter(opset),
"MatMul": MatMul.get_converter(opset),
"MatMulInteger16": MatMulInteger16.get_converter(opset),
"Mod": Mod.get_converter(opset),
"Xor": Renamer("logical_xor"),
# defs/nn
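
As a quick cross-check of the batch-shape logic in matmul_out_dtype above, the following plain-Python sketch (an editor's illustration with a hypothetical helper name, not part of the diff) mirrors how the converter derives the output shape: the higher-rank input supplies the batch dimensions, equal-rank inputs broadcast each batch dimension by taking the maximum extent, and the last two dimensions come from the rows of the first input and the columns of the second.

def matmul_out_shape(a_shape, b_shape):
    a_rank, b_rank = len(a_shape), len(b_shape)
    if a_rank > b_rank:
        batch = a_shape[: a_rank - 2]
    elif a_rank < b_rank:
        batch = b_shape[: b_rank - 2]
    else:
        # Broadcast per batch dimension by taking the maximum extent.
        batch = tuple(max(x, y) for x, y in zip(a_shape[:-2], b_shape[:-2]))
    # Keep M from the first input and N from the second; the K dims contract away.
    return batch + (a_shape[-2], b_shape[-1])

# Shapes taken from the new test cases below.
assert matmul_out_shape((2, 4, 3), (1, 3, 4)) == (2, 4, 4)
assert matmul_out_shape((2, 3, 4, 3), (3, 4)) == (2, 3, 4, 4)
assert matmul_out_shape((2, 7, 3), (3, 7)) == (2, 7, 7)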
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -839,7 +839,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
        )
    else:
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.cuda.batch_matmul),
            wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True),
            wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
            name="batch_matmul.cuda",
            plevel=10,
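
The one-line strategy change above passes the requested output dtype through to the CUDA batch_matmul compute, so a mixed-precision workload such as the int16 x int16 -> int32 product emitted by the ONNX frontend can be scheduled. A rough Relay-level sketch of that pattern (editor's illustration; shapes are arbitrary and the default transpose_b=True layout is used, so the second operand is (batch, N, K)):

from tvm import relay

a = relay.var("a", shape=(2, 4, 3), dtype="int16")
b = relay.var("b", shape=(2, 4, 3), dtype="int16")  # (batch, N, K) for transpose_b=True
out = relay.nn.batch_matmul(a, b, out_dtype="int32")  # accumulate into int32
print(relay.Function([a, b], out))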
42 changes: 42 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -1283,6 +1283,47 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
    )


@tvm.testing.parametrize_targets
def test_matmulinteger16(target, dev):
    def verify_matmulinteger16(a_shape, b_shape, out_shape):
        a_dtype = "int16"
        b_dtype = "int16"
        low = np.iinfo(np.int16).min
        high = np.iinfo(np.int16).max

        a_proto = TensorProto.INT16
        b_proto = TensorProto.INT16
        out_proto = TensorProto.INT32
        a_array = np.random.randint(low, high, size=a_shape).astype(a_dtype)
        b_array = np.random.randint(low, high, size=b_shape).astype(b_dtype)

        mul_node = helper.make_node("MatMulInteger16", ["a", "b"], ["out"], domain="com.microsoft")

        graph = helper.make_graph(
            [mul_node],
            "matmuli16_test",
            inputs=[
                helper.make_tensor_value_info("a", a_proto, list(a_shape)),
                helper.make_tensor_value_info("b", b_proto, list(b_shape)),
            ],
            outputs=[helper.make_tensor_value_info("out", out_proto, list(out_shape))],
        )

        model = helper.make_model(graph, producer_name="matmuli16_test")
        verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev)

    # 2D computation to verify matmul op
    verify_matmulinteger16((4, 3), (3, 4), (4, 4))
    verify_matmulinteger16((5, 7), (7, 8), (5, 8))
    # Verify 3D matmul using batch_matmul op
    verify_matmulinteger16((2, 4, 3), (1, 3, 4), (2, 4, 4))
    verify_matmulinteger16((1, 4, 3), (2, 3, 4), (2, 4, 4))
    # Test implicit broadcasting
    verify_matmulinteger16((2, 3, 5, 3), (2, 3, 3, 5), (2, 3, 5, 5))
    verify_matmulinteger16((2, 7, 3), (3, 7), (2, 7, 7))
    verify_matmulinteger16((2, 3, 4, 3), (3, 4), (2, 3, 4, 4))


def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
    def verify_model(model, a_shape, b_shape):
        a_array = np.random.uniform(size=a_shape).astype("float32")
@@ -5999,6 +6040,7 @@ def repeat(N, D):
    test_onehot()
    test_gemm()
    test_matmul()
    test_matmulinteger16()
    test_gather()
    test_gatherelements()
    test_gather_nd()
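
For completeness, a rough end-to-end sketch (editor's illustration, not part of the commit) of importing a standalone MatMulInteger16 model through the new converter. The model-building calls mirror the test above; relay.frontend.from_onnx is assumed to come from a TVM build that includes this change, and the shape dict keys match the graph input names.

from onnx import TensorProto, helper
from tvm import relay

node = helper.make_node("MatMulInteger16", ["a", "b"], ["out"], domain="com.microsoft")
graph = helper.make_graph(
    [node],
    "matmuli16_example",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.INT16, [4, 3]),
        helper.make_tensor_value_info("b", TensorProto.INT16, [3, 4]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, [4, 4])],
)
model = helper.make_model(graph, producer_name="matmuli16_example")

# Import into Relay; the int16 x int16 matmul is lowered with out_dtype="int32".
mod, params = relay.frontend.from_onnx(model, shape={"a": (4, 3), "b": (3, 4)})
print(mod)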
