diff --git a/tests/test_backend.py b/tests/test_backend.py
index 4deea6014..c0bd92a1a 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -2271,13 +2271,15 @@ def func(x):
                 return tf.identity(x_, name=_TFOUTPUT)
             self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
-    @check_target("rs6", "onehot")
+    @skip_tfjs("tfjs produces incorrect results")
     def test_onehot0(self):
         x_val = np.array([0, 1, 2], dtype=np.int32)
         depth = 5
-        for axis in [-1, 0, 1]:
+        for dtype, axis in [(tf.float32, -1), (tf.int64, 0), (tf.float64, 1)]:
             def func(x):
-                x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
+                val1 = tf.constant(5, dtype)
+                val2 = tf.constant(1, dtype)
+                x_ = tf.one_hot(x, depth, on_value=val1, axis=axis, off_value=val2, dtype=dtype)
                 return tf.identity(x_, name=_TFOUTPUT)
             self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index 625db5cd3..ee04592ca 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -1313,8 +1313,11 @@ def version_1(cls, ctx, node, **kwargs):
         if axis.i == 0:
             # TODO: revisit for rank > 1
             name = utils.make_name(node.name)
-            transpose_node = ctx.insert_new_node_on_output("Transpose", node.output[0], name)
-            ctx.copy_shape(node.output[0], transpose_node.output[0])
+            shape = ctx.get_shape(node.output[0])
+            transpose_node = ctx.make_node("Transpose", [node.output[0]], name=name, shapes=[shape])
+            ctx.insert_node_on_output(transpose_node, node.output[0])
+            if shape is not None:
+                ctx.set_shape(node.output[0], shape[::-1])
 
     @classmethod
     def any_version_after9(cls, opset, ctx, node, **kwargs):
@@ -1323,9 +1326,11 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
         # in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
         # onnxruntime only supports int64
         output_dtype = ctx.get_dtype(node.input[2])
-        if ctx.is_target(constants.TARGET_RS6) \
-                and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
-            logger.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
+        supported_dtypes = [onnx_pb.TensorProto.FLOAT]
+        if ctx.is_target(constants.TARGET_RS6):
+            supported_dtypes = [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]
+        if output_dtype not in supported_dtypes:
+            logger.warning("unsupported dtype in target runtime, OneHot op can't be used directly")
             cls.version_1(ctx, node, **kwargs)
             return
 