Skip to content

Commit

Permalink
Fix conversion of OneHot for dtypes unsupported by ORT
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft committed Aug 14, 2021
1 parent 2be4cf3 commit cf2b252
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
8 changes: 5 additions & 3 deletions tests/test_backend.py
Expand Up @@ -2271,13 +2271,15 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_target("rs6", "onehot")
@skip_tfjs("tfjs produces incorrect results")
def test_onehot0(self):
x_val = np.array([0, 1, 2], dtype=np.int32)
depth = 5
for axis in [-1, 0, 1]:
for dtype, axis in [(tf.float32, -1), (tf.int64, 0), (tf.float64, 1)]:
def func(x):
x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
val1 = tf.constant(5, dtype)
val2 = tf.constant(1, dtype)
x_ = tf.one_hot(x, depth, on_value=val1, axis=axis, off_value=val2, dtype=dtype)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

Expand Down
15 changes: 10 additions & 5 deletions tf2onnx/onnx_opset/tensor.py
Expand Up @@ -1313,8 +1313,11 @@ def version_1(cls, ctx, node, **kwargs):
if axis.i == 0:
# TODO: revisit for rank > 1
name = utils.make_name(node.name)
transpose_node = ctx.insert_new_node_on_output("Transpose", node.output[0], name)
ctx.copy_shape(node.output[0], transpose_node.output[0])
shape = ctx.get_shape(node.output[0])
transpose_node = ctx.make_node("Transpose", [node.output[0]], name=name, shapes=[shape])
ctx.insert_node_on_output(transpose_node, node.output[0])
if shape is not None:
ctx.set_shape(node.output[0], shape[::-1])

@classmethod
def any_version_after9(cls, opset, ctx, node, **kwargs):
Expand All @@ -1323,9 +1326,11 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
# in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
# onnxruntime only supports int64
output_dtype = ctx.get_dtype(node.input[2])
if ctx.is_target(constants.TARGET_RS6) \
and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
logger.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
supported_dtypes = [onnx_pb.TensorProto.FLOAT]
if ctx.is_target(constants.TARGET_RS6):
supported_dtypes = [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]
if output_dtype not in supported_dtypes:
logger.warning("unsupported dtype in target runtime, OneHot op can't be used directly")
cls.version_1(ctx, node, **kwargs)
return

Expand Down

0 comments on commit cf2b252

Please sign in to comment.