Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

opset 12 support #897

Merged
merged 13 commits into from Apr 28, 2020
23 changes: 23 additions & 0 deletions tests/test_backend.py
Expand Up @@ -3130,6 +3130,29 @@ def func(X, K):
k_val = np.array(raw_k).astype(np.int32)
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})

@check_opset_min_version(12)
def test_inverse(self):
x_val = np.random.random([5, 5]).astype(np.float32)
def func(x):
return tf.linalg.inv(x, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(12)
def test_less_or_equal(self):
x_val = np.random.random([4, 5]).astype(np.float32)
y_val = np.random.random([4, 5]).astype(np.float32)
def func(x, y):
return tf.math.less_equal(x, y, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})

@check_opset_min_version(12)
def test_squared_distance(self):
x_val = np.random.random([4, 5]).astype(np.float32)
y_val = np.random.random([4, 5]).astype(np.float32)
def func(x, y):
return tf.math.squared_difference(x, y, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})


if __name__ == '__main__':
unittest_main()
25 changes: 24 additions & 1 deletion tf2onnx/onnx_opset/math.py
Expand Up @@ -22,7 +22,7 @@

# pylint: disable=unused-argument,missing-docstring

@tf_op(["Add", "AddV2", "Div", "Mul", "Sub"])
@tf_op(["Add", "AddV2", "Div", "Mul", "Sub", "LessOrEqual"])
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
class BroadcastOp(common.BroadcastOp):
pass

Expand Down Expand Up @@ -544,3 +544,26 @@ def version_11(cls, ctx, node, **kwargs):
cast_back_node.set_attr("to", dtypes[0])
ctx.set_dtype(cast_back_node.output[0], dtypes[0])
ctx.copy_shape(node.name, cast_back_node.output[0])

@tf_op("MatrixInverse")
class Inverse:

@classmethod
guschmue marked this conversation as resolved.
Show resolved Hide resolved
def version_12(cls, ctx, node, **kwargs):
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
shapes = node.output_shapes
dtypes = node.output_dtypes
ctx.remove_node(node.name)
ctx.make_node("Inverse", inputs=node.input, outputs=node.output, name=node.name,
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
domain=constants.MICROSOFT_DOMAIN, shapes=shapes, dtypes=dtypes)
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

@tf_op("SquaredDistance")
class SquaredDistance:

@classmethod
def version_12(cls, ctx, node, **kwargs):
shapes = node.output_shapes
dtypes = node.output_dtypes
ctx.remove_node(node.name)
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
ctx.make_node("MeanSquaredDistance", inputs=node.input, outputs=node.output, name=node.name,
shapes=shapes, dtypes=dtypes, attr={"reduction": "none"})
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved