Skip to content

Commit

Permalink
fix slice 1 shape inference with negative starts/ends
Browse files Browse the repository at this point in the history
Signed-off-by: daquexian <daquexian566@gmail.com>
  • Loading branch information
daquexian committed Nov 2, 2023
1 parent 46c3ab3 commit 0c5711c
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions onnx/defs/tensor/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3723,9 +3723,7 @@ ONNX_OPERATOR_SET_SCHEMA(
}

auto is_negative = [](int64_t index) { return index < 0; };
if (std::any_of(starts.begin(), starts.end(), is_negative) ||
std::any_of(ends.begin(), ends.end(), is_negative) ||
std::any_of(axes.begin(), axes.end(), is_negative)) {
if (std::any_of(axes.begin(), axes.end(), is_negative)) {
// Negative axes were not explicitly discussed in the spec before opset-10.
// Hence, they are officially not part of the spec, but some models/runtimes may use them.
// So we perform simple rank inference in this case.
Expand All @@ -3742,13 +3740,20 @@ ONNX_OPERATOR_SET_SCHEMA(
if (j < axes.size() && static_cast<size_t>(axes[j]) == i) {
// There's a lot of potential behaviors. For now just
// handle some simple cases.
if (ctx.getInputType(0)->tensor_type().shape().dim((int)i).has_dim_value() && starts[j] >= 0 &&
ends[j] >= 0) {
auto newval =
std::min((int64_t)ctx.getInputType(0)->tensor_type().shape().dim((int)i).dim_value(), ends[j]) -
starts[j];
if (newval >= 0) {
newdim->set_dim_value(newval);
const auto& dim = ctx.getInputType(0)->tensor_type().shape().dim((int)i);
if (dim.has_dim_value()) {
auto dim_value = dim.dim_value();
if (starts[j] < 0) {
starts[j] += dim_value;
}
if (ends[j] < 0) {
ends[j] += dim_value;
}
if (starts[j] >= 0 && ends[j] >= 0) {
auto newval = std::min(dim_value, ends[j]) - starts[j];
if (newval >= 0) {
newdim->set_dim_value(newval);
}
}
}
++j;
Expand Down

0 comments on commit 0c5711c

Please sign in to comment.