Handle OneHot's depth value during shape inference (#5963)
### Description

- Even if the `depth` value of `OneHot` is specified as an initializer, the value is currently ignored during shape inference; this change uses it to infer the output's depth dimension.
- From my understanding this is the correct shape inference, but let me know if this is wrong.
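
For illustration, a minimal sketch of the effect from Python (assumes the standard `onnx.helper` / `onnx.shape_inference` API; the graph mirrors the `test_onehot_without_axis_2` test added below, and the graph name is arbitrary):

```python
import onnx
from onnx import TensorProto, helper

# OneHot whose `depth` input is backed by a constant initializer (value 256).
indices = helper.make_tensor_value_info("indices", TensorProto.INT64, (2, 2))
depth = helper.make_tensor_value_info("depth", TensorProto.INT64, ())
values = helper.make_tensor_value_info("values", TensorProto.FLOAT, (2,))
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)

graph = helper.make_graph(
    [helper.make_node("OneHot", ["indices", "depth", "values"], ["Y"])],
    "onehot_depth_example",
    [indices, depth, values],
    [y],
    initializer=[helper.make_tensor("depth", TensorProto.INT64, (), (256,))],
)
model = helper.make_model(graph)

inferred = onnx.shape_inference.infer_shapes(model)
# Before this change the depth dimension was left unknown; with it,
# Y is inferred as a FLOAT tensor of shape (2, 2, 256).
print(inferred.graph.output[0].type.tensor_type.shape)
```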

---------

Signed-off-by: maekawatoshiki <maekawatoshiki1017@gmail.com>
Co-authored-by: Ke Zhang <linkerzhang@yeah.net>
maekawatoshiki and linkerzhang committed Mar 27, 2024
1 parent 5d5a8c4 commit df36ccc
Showing 3 changed files with 52 additions and 0 deletions.
13 changes: 13 additions & 0 deletions onnx/defs/tensor/defs.cc
@@ -5,6 +5,7 @@
#include <algorithm>
#include <cmath>
#include <numeric>
#include <optional>

#include "onnx/defs/data_propagators.h"
#include "onnx/defs/function.h"
@@ -2895,8 +2896,18 @@ ONNX_OPERATOR_SET_SCHEMA(
// and 1 element vector for now. In future when version update for
// this op is done we should only allow scalar or change the spec to
// allow both.
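// If the value of `depth` is known statically (e.g. it is supplied as an
// initializer), record it so the output's depth dimension can be set below.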
std::optional<int64_t> depth_value;
if (hasInputShape(ctx, 1)) {
  auto& depth_shape = getInputShape(ctx, 1);
  if (const TensorProto* depth_data = ctx.getInputData(1)) {
    if (depth_data->data_type() == TensorProto::INT64) {
      depth_value = ParseData<int64_t>(depth_data)[0];
    } else if (depth_data->data_type() == TensorProto::INT32) {
      depth_value = ParseData<int32_t>(depth_data)[0];
    } else if (depth_data->data_type() == TensorProto::FLOAT) {
      depth_value = static_cast<int64_t>(ParseData<float>(depth_data)[0]);
    }
  }
  if (depth_shape.dim_size() != 0 && depth_shape.dim_size() != 1) {
    fail_type_inference("Input 'depth' must be a scalar or rank 1 tensor.");
  }
@@ -2947,6 +2958,8 @@ ONNX_OPERATOR_SET_SCHEMA(
      } else if (indices_shape.dim(i - 1).has_dim_param()) {
        dim->set_dim_param(indices_shape.dim(i - 1).dim_param());
      }
    } else if (depth_value) {
      dim->set_dim_value(*depth_value);
    }
  }
}
13 changes: 13 additions & 0 deletions onnx/defs/tensor/old.cc
@@ -5,6 +5,7 @@
#include <algorithm>
#include <cmath>
#include <numeric>
#include <optional>

#include "onnx/defs/data_propagators.h"
#include "onnx/defs/function.h"
@@ -5152,8 +5153,18 @@ ONNX_OPERATOR_SET_SCHEMA(
// and 1 element vector for now. In future when version update for
// this op is done we should only allow scalar or change the spec to
// allow both.
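// If the value of `depth` is known statically (e.g. it is supplied as an
// initializer), record it so the output's depth dimension can be set below.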
std::optional<int64_t> depth_value;
if (hasInputShape(ctx, 1)) {
  auto& depth_shape = getInputShape(ctx, 1);
  if (const TensorProto* depth_data = ctx.getInputData(1)) {
    if (depth_data->data_type() == TensorProto::INT64) {
      depth_value = ParseData<int64_t>(depth_data)[0];
    } else if (depth_data->data_type() == TensorProto::INT32) {
      depth_value = ParseData<int32_t>(depth_data)[0];
    } else if (depth_data->data_type() == TensorProto::FLOAT) {
      depth_value = static_cast<int64_t>(ParseData<float>(depth_data)[0]);
    }
  }
  if (depth_shape.dim_size() != 0 && depth_shape.dim_size() != 1) {
    fail_type_inference("Input 'depth' must be a scalar or rank 1 tensor.");
  }
@@ -5204,6 +5215,8 @@ ONNX_OPERATOR_SET_SCHEMA(
      } else if (indices_shape.dim(i - 1).has_dim_param()) {
        dim->set_dim_param(indices_shape.dim(i - 1).dim_param());
      }
    } else if (depth_value) {
      dim->set_dim_value(*depth_value);
    }
  }
}
26 changes: 26 additions & 0 deletions onnx/test/shape_inference_test.py
@@ -4594,6 +4594,32 @@ def test_onehot_with_axis(self) -> None:
        )
        self._assert_inferred(graph, [make_tensor_value_info("Y", TensorProto.FLOAT, (2, None, 3, 5))])  # type: ignore

    def test_onehot_without_axis_2(self) -> None:
        graph = self._make_graph(
            [
                ("indices", TensorProto.INT64, (2, 2)),
                ("depth", TensorProto.INT64, ()),
                ("values", TensorProto.FLOAT, (2,)),
            ],
            [make_node("OneHot", ["indices", "depth", "values"], "Y")],
            [],
            initializer=[make_tensor("depth", TensorProto.INT64, (), (256,))],
        )
        self._assert_inferred(graph, [make_tensor_value_info("Y", TensorProto.FLOAT, (2, 2, 256))])  # type: ignore

    def test_onehot_with_axis_2(self) -> None:
        graph = self._make_graph(
            [
                ("indices", TensorProto.INT64, (2, 3, 5)),
                ("depth", TensorProto.INT64, (1,)),
                ("values", TensorProto.FLOAT, (2,)),
            ],
            [make_node("OneHot", ["indices", "depth", "values"], "Y", axis=1)],
            [],
            initializer=[make_tensor("depth", TensorProto.INT64, (1,), (256,))],
        )
        self._assert_inferred(graph, [make_tensor_value_info("Y", TensorProto.FLOAT, (2, 256, 3, 5))])  # type: ignore

    def test_loop(self) -> None:
        # can't use self._make_graph for the subgraph as it adds more inputs for the Reshape operations it inserts.
        # this breaks the subgraph inferencing as it expects the number of inputs passed from Loop to match