Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/core/codegen/common/op_macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ namespace onnxruntime {
ADD_OP_ITEM(Dropout) \
ADD_OP_ITEM(Flatten) \
ADD_OP_ITEM(Gather) \
ADD_OP_ITEM(GatherElements) \
ADD_OP_ITEM(Gemm) \
ADD_OP_ITEM(Identity) \
ADD_OP_ITEM(LogSoftmax) \
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/codegen/mti/mti_tvm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,13 @@ tvm::Array<tvm::Tensor> MakeInputsForExtern(const tvm::Array<tvm::Tensor>& input
return fixed_inputs;
}

// Builds an expression that turns a possibly-negative index into a valid
// non-negative one. idx is first limited to [-bound, bound - 1]; a negative
// value is then shifted up by bound, so the result lies in [0, bound - 1]
// (for bound >= 1).
tvm::Expr ClampIndex(const tvm::Expr& idx, const tvm::Expr& bound) {
  // Contribution when idx >= 0: idx capped at bound - 1.
  // For idx < 0 the inner min is negative, so this term collapses to 0.
  const tvm::Expr non_negative_term = tvm::max(tvm::min(idx, bound - 1), 0);

  // Contribution when idx < 0: idx floored at -bound, then offset by bound.
  // It is multiplied by (idx < 0) below, so it vanishes for idx >= 0.
  const tvm::Expr wrapped_negative_term = bound + tvm::max(idx, -bound);

  return non_negative_term + (idx < 0) * wrapped_negative_term;
}

} // namespace tvm_codegen
} // namespace onnxruntime
3 changes: 3 additions & 0 deletions onnxruntime/core/codegen/mti/mti_tvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ inline int64_t HandleNegativeAxis(int64_t axis, int64_t rank) {
return axis = axis < 0 ? (axis + rank) : axis;
}

// Returns an expression that maps idx to a valid non-negative index:
// idx is clamped to the range [-bound, bound - 1], and a negative value is
// then offset by bound, so the result lies in [0, bound - 1] (for bound >= 1).
tvm::Expr ClampIndex(const tvm::Expr& idx, const tvm::Expr& bound);

// Helper function to work around a TVM ExternOp issue when an input has symbolic dimensions
tvm::Array<tvm::Tensor> MakeInputsForExtern(const tvm::Array<tvm::Tensor>& inputs, const std::string& name = "make_inputs_for_extern");

Expand Down
45 changes: 45 additions & 0 deletions onnxruntime/core/codegen/mti/tensor/gather_elements.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/codegen/mti/tensor/gather_elements.h"

#include "core/codegen/mti/mti_tvm_utils.h"
#include <topi/transform.h>

namespace onnxruntime {
namespace tvm_codegen {

// Builds a tvm compute tensor called `name` that gathers elements of `t`
// along `axis` using `indices`.
//  - The output shape is copied from `indices`.
//  - For each output coordinate, the coordinate along `axis` is replaced by the
//    value read from `indices` at that same coordinate; every other coordinate
//    is passed through unchanged.
//  - The index value goes through ClampIndex with t->shape[axis] as the bound and
//    is cast to Int(32), so an out-of-range index is clamped instead of reported.
//  - `axis` is used as-is to subscript t->shape and is compared with the loop
//    counter, so it must already be non-negative here.
//  - NOTE(review): the coordinate list for `t` is built with indices' rank, which
//    presumably requires `t` and `indices` to have the same rank - confirm.
tvm::Tensor GatherElements(const tvm::Tensor& t,
int64_t axis,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axis [](start = 35, length = 4)

can this be negative? if not, maybe should use unsigned types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always use int64_t for axis in other tensor apis. Seems it's better to follow the convention.

const tvm::Tensor& indices,
const std::string& name) {
tvm::Array<tvm::Expr> output_shape;
int64_t indices_rank = static_cast<int64_t>(indices->shape.size());
// output shape is the same as indices
for (int64_t i = 0; i < indices_rank; ++i)
output_shape.push_back(indices->shape[i]);

tvm::Expr idx_upper_bound = t->shape[axis];
auto l = [&](const tvm::Array<tvm::Var>& ovars) {
tvm::Array<tvm::Expr> ivars;
for (int64_t i = 0; i < indices_rank; i++) {
if (i == axis) {
tvm::Array<tvm::Expr> idx_vars;
for (int64_t j = 0; j < indices_rank; j++)
idx_vars.push_back(ovars[j]);
// make sure idx is clamped in the range of [-idx_upper_bound, idx_upper_bound - 1]
tvm::Expr real_idx = tvm_codegen::ClampIndex(indices(idx_vars), idx_upper_bound);
// tvm idx must be of Int(32)
ivars.push_back(tvm::cast(tvm::Int(32), real_idx));
Copy link
Contributor

@ke1337 ke1337 Oct 4, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe better to make it a common util for index clamping, as other ops may reuse it too (Gather?) #Resolved

Copy link
Contributor Author

@yangchen-MS yangchen-MS Oct 4, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. #Resolved

} else {
ivars.push_back(ovars[i]);
}
}
return t(ivars);
};

return tvm::compute(output_shape, l, name);
}

} // namespace tvm_codegen
} // namespace onnxruntime
17 changes: 17 additions & 0 deletions onnxruntime/core/codegen/mti/tensor/gather_elements.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <string>
#include <tvm/tvm.h>

namespace onnxruntime {
namespace tvm_codegen {

// Gathers elements of `t` along `axis` at the positions given by `indices`.
//   t       - tensor to read from
//   axis    - dimension of `t` that `indices` selects into
//             (NOTE(review): expected to be already non-negative - confirm with callers)
//   indices - tensor of positions along `axis`
//   name    - name given to the resulting tensor; defaults to "gather_elements"
tvm::Tensor GatherElements(const tvm::Tensor& t,
int64_t axis,
const tvm::Tensor& indices,
const std::string& name = "gather_elements");

} // namespace tvm_codegen
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/codegen/passes/op_ir_creator/all_ops.h"

#include "core/codegen/mti/tensor/gather_elements.h"
#include "core/framework/op_kernel_info.h"
#include "core/providers/common.h"

namespace onnxruntime {
namespace tvm_codegen {

// Evaluate of GatherElements OpIRCreator
//
// Reads the node's "axis" attribute, validates it against the rank of the data
// input (inputs[0]), normalizes a negative axis, and appends the resulting
// GatherElements tensor (data = inputs[0], indices = inputs[1]) to `outputs`.
//
// Throws via ORT_ENFORCE when the "axis" attribute cannot be read or when it is
// outside [-rank, rank - 1]; otherwise returns Status::OK().
Status GENERIC_OP_IR_CREATOR_CLASS(GatherElements)::Evaluate(
    const tvm::Array<tvm::Tensor>& inputs,
    const Node& node,
    CodeGenContext&,
    tvm::Array<tvm::Tensor>& outputs) {
  ProtoHelperNodeContext ctx(node);
  OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);

  int64_t axis;
  ORT_ENFORCE(attrs.GetAttr<int64_t>("axis", &axis).IsOK());

  // GatherElements later subscripts the input shape with axis, so reject an
  // out-of-range value here instead of letting it become an invalid index.
  const int64_t input_rank = gsl::narrow_cast<int64_t>(inputs[0]->shape.size());
  ORT_ENFORCE(axis >= -input_rank && axis < input_rank,
              "GatherElements op: axis must be within bounds [", -input_rank, " , ", input_rank - 1,
              "]. Actual value is ", axis);
  axis = HandleNegativeAxis(axis, input_rank);

  tvm::Tensor Y = GatherElements(inputs[0], axis, inputs[1], node.Name() + "_GatherElements");
  outputs.push_back(Y);
  return Status::OK();
}

} // namespace tvm_codegen
} // namespace onnxruntime
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/nuphar/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,17 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType<int64_t>()}),
nuphar::NupharKernel);

// Kernel registration for GatherElements (kOnnxDomain, version 11) on the Nuphar
// execution provider, implemented by nuphar::NupharKernel.
// "T" accepts all fixed-size tensor types; "Tind" accepts int32 and int64 tensors.
ONNX_OPERATOR_KERNEL_EX(
GatherElements,
kOnnxDomain,
11,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
nuphar::NupharKernel);

ONNX_OPERATOR_KERNEL_EX(
MatMulInteger,
kOnnxDomain,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ LIST_NUPHAR_OPS()
// Forward declarations of Nuphar kernel classes. Each class name is produced by an
// ONNX_OPERATOR_*_KERNEL_CLASS_NAME macro from (provider, domain, version(s), op).
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 6, 8, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, GatherElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan);
Expand All @@ -413,6 +414,7 @@ static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 6, 8, Cast)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Cast)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1, Gather)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, GatherElements)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan)>());
Expand Down
10 changes: 8 additions & 2 deletions onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ void RunTypedTest()
test5.AddOutput<T>("output", {2, 2},
{1, 1,
4, 4});
test5.Run(OpTester::ExpectResult::kExpectFailure, "GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is 2");
// skip nuphar, which will not throw error message but will ensure no out-of-bound access
test5.Run(OpTester::ExpectResult::kExpectFailure,
"GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is 2",
{kNupharExecutionProvider});

// 3D input - axis 1
OpTester test6("GatherElements", 11);
Expand Down Expand Up @@ -158,7 +161,10 @@ void RunTypedTest<std::string>() {
test4.AddOutput<std::string>("output", {2, 2},
{"a", "a",
"d", "d"});
test4.Run(OpTester::ExpectResult::kExpectFailure, "GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is -3");
// skip nuphar, which will not throw error message but will ensure no out-of-bound access
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which will not throw error message but will ensure no out-of-bound access [](start = 18, length = 73)

I believe this should be the default behavior in ONNX spec for this op. Lots of devices won't be able to throw on invalid indices. Please open an issue to ONNX.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

test4.Run(OpTester::ExpectResult::kExpectFailure,
"GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is -3",
{kNupharExecutionProvider});

// 3D input - axis 1
OpTester test5("GatherElements", 11);
Expand Down