-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Added GatherElements to Nuphar #2016
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/codegen/mti/tensor/gather_elements.h" | ||
|
|
||
| #include "core/codegen/mti/mti_tvm_utils.h" | ||
| #include <topi/transform.h> | ||
|
|
||
| namespace onnxruntime { | ||
| namespace tvm_codegen { | ||
|
|
||
| tvm::Tensor GatherElements(const tvm::Tensor& t, | ||
| int64_t axis, | ||
| const tvm::Tensor& indices, | ||
| const std::string& name) { | ||
| tvm::Array<tvm::Expr> output_shape; | ||
| int64_t indices_rank = static_cast<int64_t>(indices->shape.size()); | ||
| // output shape is the same as indices | ||
| for (int64_t i = 0; i < indices_rank; ++i) | ||
| output_shape.push_back(indices->shape[i]); | ||
|
|
||
| tvm::Expr idx_upper_bound = t->shape[axis]; | ||
| auto l = [&](const tvm::Array<tvm::Var>& ovars) { | ||
| tvm::Array<tvm::Expr> ivars; | ||
| for (int64_t i = 0; i < indices_rank; i++) { | ||
| if (i == axis) { | ||
| tvm::Array<tvm::Expr> idx_vars; | ||
| for (int64_t j = 0; j < indices_rank; j++) | ||
| idx_vars.push_back(ovars[j]); | ||
| // make sure idx is clamped in the range of [-idx_upper_bound, idx_upper_bound - 1] | ||
| tvm::Expr real_idx = tvm_codegen::ClampIndex(indices(idx_vars), idx_upper_bound); | ||
| // tvm idx must be of Int(32) | ||
| ivars.push_back(tvm::cast(tvm::Int(32), real_idx)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe better to make it a common util for index clamping, as other ops may reuse it too (Gather?) #Resolved
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. #Resolved |
||
| } else { | ||
| ivars.push_back(ovars[i]); | ||
| } | ||
| } | ||
| return t(ivars); | ||
| }; | ||
|
|
||
| return tvm::compute(output_shape, l, name); | ||
| } | ||
|
|
||
| } // namespace tvm_codegen | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
| #include <string> | ||
| #include <tvm/tvm.h> | ||
|
|
||
| namespace onnxruntime { | ||
| namespace tvm_codegen { | ||
|
|
||
| tvm::Tensor GatherElements(const tvm::Tensor& t, | ||
| int64_t axis, | ||
| const tvm::Tensor& indices, | ||
| const std::string& name = "gather_elements"); | ||
|
|
||
| } // namespace tvm_codegen | ||
| } // namespace onnxruntime |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/codegen/passes/op_ir_creator/all_ops.h" | ||
|
|
||
| #include "core/codegen/mti/tensor/gather_elements.h" | ||
| #include "core/framework/op_kernel_info.h" | ||
| #include "core/providers/common.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace tvm_codegen { | ||
|
|
||
| // Evaluate of GatherElements OpIRCreator | ||
| Status GENERIC_OP_IR_CREATOR_CLASS(GatherElements)::Evaluate( | ||
| const tvm::Array<tvm::Tensor>& inputs, | ||
| const Node& node, | ||
| CodeGenContext&, | ||
| tvm::Array<tvm::Tensor>& outputs) { | ||
| ProtoHelperNodeContext ctx(node); | ||
| OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx); | ||
|
|
||
| int64_t axis; | ||
| ORT_ENFORCE(attrs.GetAttr<int64_t>("axis", &axis).IsOK()); | ||
| axis = HandleNegativeAxis(axis, gsl::narrow_cast<int64_t>(inputs[0]->shape.size())); | ||
|
|
||
| tvm::Tensor Y = GatherElements(inputs[0], axis, inputs[1], node.Name() + "_GatherElements"); | ||
| outputs.push_back(Y); | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| } // namespace tvm_codegen | ||
| } // namespace onnxruntime |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,7 +76,10 @@ void RunTypedTest() | |
| test5.AddOutput<T>("output", {2, 2}, | ||
| {1, 1, | ||
| 4, 4}); | ||
| test5.Run(OpTester::ExpectResult::kExpectFailure, "GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is 2"); | ||
| // skip nuphar, which will not throw error message but will ensure no out-of-bound access | ||
| test5.Run(OpTester::ExpectResult::kExpectFailure, | ||
| "GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is 2", | ||
| {kNupharExecutionProvider}); | ||
|
|
||
| // 3D input - axis 1 | ||
| OpTester test6("GatherElements", 11); | ||
|
|
@@ -158,7 +161,10 @@ void RunTypedTest<std::string>() { | |
| test4.AddOutput<std::string>("output", {2, 2}, | ||
| {"a", "a", | ||
| "d", "d"}); | ||
| test4.Run(OpTester::ExpectResult::kExpectFailure, "GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is -3"); | ||
| // skip nuphar, which will not throw error message but will ensure no out-of-bound access | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I believe this should be the default behavior in ONNX spec for this op. Lots of devices won't be able to throw on invalid indices. Please open an issue to ONNX.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
| test4.Run(OpTester::ExpectResult::kExpectFailure, | ||
| "GatherElements op: Value in indices must be within bounds [-2 , 1]. Actual value is -3", | ||
| {kNupharExecutionProvider}); | ||
|
|
||
| // 3D input - axis 1 | ||
| OpTester test5("GatherElements", 11); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this be negative? if not, maybe should use unsigned types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We always use int64_t for axis in other tensor apis. Seems it's better to follow the convention.