Skip to content

Commit

Permalink
Add support for one-hot encoding of 0-rank tensors (#1771)
Browse files Browse the repository at this point in the history
* Add support for one-hot encoding of 0-rank tensors

Signed-off-by: Dagnas <dagnas@sinequa.com>

* Skip the newly added test for tfjs

Signed-off-by: Dagnas <dagnas@sinequa.com>

Co-authored-by: Dagnas <dagnas@sinequa.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
  • Loading branch information
3 people committed Nov 22, 2021
1 parent a1a9343 commit 65aaa2c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2421,6 +2421,18 @@ def func(x):
graph = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")

@check_opset_min_version(9, "onehot")
@skip_tfjs("tfjs produces incorrect results")
def test_onehot_rank0(self):
    """Scalar (rank-0) indices must convert through tf.one_hot for both
    int dtypes and for both axis values TensorFlow accepts here."""
    depth = 5
    for dtype in (np.int32, np.int64):
        scalar_indices = np.array(3, dtype=dtype)
        for ax in (-1, 0):
            # func is traced immediately below, so capturing `ax` from the
            # enclosing loop is safe (no late-binding issue).
            def func(x):
                encoded = tf.one_hot(x, depth, axis=ax)
                return tf.identity(encoded, name=_TFOUTPUT)
            self._run_test_case(func, [_OUTPUT], {_INPUT: scalar_indices})

@skip_caffe2_backend("issue undefined dim 1")
@check_tf_max_version("1.15", "not supported in tf-2.0")
def test_flatten0(self):
Expand Down
33 changes: 33 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,19 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
# Concatenate [off, on] into the single 2-element "values" input that the
# ONNX OneHot operator expects.
off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]

indices = node.input[0]
# NOTE(review): get_rank presumably returns None when the rank is unknown at
# conversion time — in that case the 0-rank handling below is skipped; confirm.
indices_rank = ctx.get_rank(indices)

# Add a special support for 0-rank indices: expand the scalar to shape [1]
# before the one-hot encoding; the extra dimension is removed again (via a
# Squeeze) after the OneHot node has been built.
if indices_rank == 0:
    dims = ctx.make_const(name=utils.make_name('dims'), np_val=np.array([1], dtype=np.int64))
    indices = ctx.make_node("Expand", [indices, dims.name]).output[0]

    # TensorFlow accepts axis=0 for the one-hot encoding of a 0-rank tensor
    # and treats it the same as axis=-1, so rewrite the attribute to -1 here
    # so the rest of the conversion only has to handle the -1 case.
    if node.get_attr('axis').i == 0:
        node.set_attr('axis', -1)

# The RS6 target only supports int64 indices — insert a Cast when needed.
if ctx.is_target(constants.TARGET_RS6) \
        and ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
    indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
Expand All @@ -1367,6 +1380,26 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
# Propagate the requested output dtype/shape onto the replacement node.
ctx.set_dtype(new_node.output[0], output_dtype)
ctx.set_shape(new_node.output[0], ctx.get_shape(node.output[0]))

# Remove the dimension artificially added above to support 0-rank indices:
# insert a Squeeze(axes=[0]) on the node's output so downstream consumers
# see the original scalar-indexed result.
if indices_rank == 0:
    # NOTE(review): `nodes` is built but not used within this span —
    # presumably consumed later in the function; verify before removing.
    nodes = [node]
    name = utils.make_name(node.name)
    shape = ctx.get_shape(node.output[0])
    dtype = ctx.get_dtype(node.output[0])
    squeeze_node = GraphBuilder(ctx).make_squeeze(
        {
            "axes": [0],
            'data': node.output[0]
        },
        name=name,
        dtypes=[dtype],
        shapes=[shape],
        return_node=True)
    # Rewire consumers of the OneHot output through the new Squeeze.
    ctx.insert_node_on_output(squeeze_node)

    nodes.append(squeeze_node)
    # Re-infer shape/dtype for the OneHot node now that a Squeeze follows it.
    ctx.update_node_shape_dtype(node, override=True)
@classmethod
def version_9(cls, ctx, node, **kwargs):
    """Opset-9 entry point: delegate to the shared >=9 implementation."""
    handler = cls.any_version_after9
    handler(9, ctx, node, **kwargs)
Expand Down

0 comments on commit 65aaa2c

Please sign in to comment.