-
-
Notifications
You must be signed in to change notification settings - Fork 55.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24092 from Aser-Abdelfatah:GSoC_Support_GatherEle…
…ments_ONNX GSoC Add ONNX Support for GatherElements #24092 Merge with: opencv/opencv_extra#1082 Adds support to the ONNX operator GatherElements [operator docs](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements) Added tests to opencv_extra at pull request opencv/opencv_extra#1082 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
- Loading branch information
1 parent
014e848
commit 240b245
Showing
8 changed files
with
303 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
// This file is part of OpenCV project. | ||
// It is subject to the license terms in the LICENSE file found in the top-level directory | ||
// of this distribution and at http://opencv.org/license.html. | ||
|
||
#include "../precomp.hpp" | ||
#include <opencv2/dnn/shape_utils.hpp> | ||
|
||
namespace cv { namespace dnn { | ||
|
||
static inline int calculateOffset(int outer_dim, const MatShape &shape_indices, int axis_skip, const MatStep &step_data) { | ||
int offset = 0; | ||
for (int axis = static_cast<int>(shape_indices.size()) - 2; axis >= 0; axis--) { | ||
int dim = shape_indices[axis]; | ||
if (axis != axis_skip) { | ||
offset += (outer_dim % dim) * step_data[axis]; | ||
} | ||
outer_dim /= dim; | ||
} | ||
return offset; | ||
} | ||
|
||
class GatherElementsLayerImpl CV_FINAL : public GatherElementsLayer | ||
{ | ||
public: | ||
GatherElementsLayerImpl(const LayerParams& params) | ||
{ | ||
setParamsFrom(params); | ||
axis = params.get<int>("axis", 0); | ||
} | ||
|
||
virtual bool supportBackend(int backendId) CV_OVERRIDE | ||
{ | ||
return backendId == DNN_BACKEND_OPENCV; | ||
} | ||
|
||
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs, | ||
const int requiredOutputs, | ||
std::vector<MatShape> &outputs, | ||
std::vector<MatShape> &internals) const CV_OVERRIDE | ||
{ | ||
CV_CheckEQ(inputs.size(), 2ull, "GatherElements: requires two inputs"); | ||
|
||
const auto &data = inputs[0]; | ||
const auto &indices = inputs[1]; | ||
CV_CheckEQ(data.size(), indices.size(), "GatherElements: data and indices should have the same dimension"); | ||
|
||
int normalized_axis = normalize_axis(axis, static_cast<int>(data.size())); | ||
CV_CheckGE(normalized_axis, 0, "GatherElements: axis out of range"); | ||
CV_CheckLT(normalized_axis, static_cast<int>(data.size()), "GatherElements: axis out of range"); | ||
for (size_t i = 0; i < data.size(); i++) { | ||
if (i != normalized_axis) { | ||
CV_CheckEQ(data[i], indices[i], "GatherElements: shape mismatched"); | ||
} | ||
} | ||
|
||
outputs.assign(1, inputs[1]); // shape of output is same as indices | ||
return false; | ||
} | ||
|
||
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE { | ||
std::vector<Mat> inputs; | ||
inputs_arr.getMatVector(inputs); | ||
|
||
const auto &data = inputs[0]; | ||
axis = normalize_axis(axis, data.dims); | ||
} | ||
|
||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE | ||
{ | ||
CV_TRACE_FUNCTION(); | ||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); | ||
|
||
std::vector<Mat> inputs, outputs; | ||
inputs_arr.getMatVector(inputs); | ||
outputs_arr.getMatVector(outputs); | ||
|
||
const Mat& data = inputs[0]; | ||
const Mat& indices = inputs[1]; | ||
Mat& out = outputs[0]; | ||
|
||
typeDispatch(outputs[0].type(), data, indices, out); | ||
} | ||
|
||
template <typename T> | ||
void forward_impl(const Mat& data_, const Mat& indices_, Mat& out_) | ||
{ | ||
const auto *ptr_data = data_.ptr<const T>(); | ||
const auto *ptr_indices = indices_.ptr<const T>(); | ||
auto *ptr_out = out_.ptr<T>(); | ||
|
||
const auto shape_data = shape(data_); | ||
const auto &step_data = data_.step; | ||
const auto shape_indices = shape(indices_); | ||
|
||
int inner_most_dim = shape_indices.back(); | ||
int axis_dim = shape_data[axis]; | ||
size_t axis_step = static_cast<size_t>(step_data[axis] / sizeof(T)); | ||
|
||
bool innermost_axis = axis == static_cast<int>(shape_data.size() - 1); | ||
|
||
auto fn = [&](const Range &r) { | ||
for (int i = r.start; i < r.end; i++) { | ||
auto *data = ptr_data + static_cast<size_t>(calculateOffset(i, shape_indices, axis, step_data) / sizeof(T)); | ||
auto *indices = ptr_indices + i * inner_most_dim; | ||
auto *out = ptr_out + i * inner_most_dim; | ||
|
||
if (innermost_axis) { | ||
for (int j = 0; j < inner_most_dim; j++) { | ||
int index = static_cast<int>((indices[j] + axis_dim)) % axis_dim; // TODO: Check out-of-range index | ||
out[j] = data[index]; | ||
} | ||
} else { | ||
for (int j = 0; j < inner_most_dim; j++) { | ||
int index = static_cast<int>(indices[j] + axis_dim) % axis_dim; // TODO: Check out-of-range index | ||
out[j] = data[index * axis_step + j]; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
int outer_dims = total(shape_indices, 0, shape_indices.size() - 1); | ||
double nstripes = static_cast<size_t>(outer_dims * inner_most_dim * (1 / 1024.0)); | ||
parallel_for_(Range(0, outer_dims), fn, nstripes); | ||
} | ||
|
||
template<typename... Args> | ||
inline void typeDispatch(const int type, Args&&... args) | ||
{ | ||
switch (type) | ||
{ | ||
case CV_8U: | ||
forward_impl<uint8_t>(std::forward<Args>(args)...); | ||
break; | ||
case CV_32S: | ||
forward_impl<int32_t>(std::forward<Args>(args)...); | ||
break; | ||
case CV_32F: | ||
forward_impl<float>(std::forward<Args>(args)...); | ||
break; | ||
default: | ||
CV_Error(cv::Error::BadDepth, "DNN/GatherElements: Unsupported type."); | ||
}; | ||
} | ||
|
||
private: | ||
int axis; | ||
}; | ||
|
||
Ptr<GatherElementsLayer> GatherElementsLayer::create(const LayerParams& params) | ||
{ | ||
return makePtr<GatherElementsLayerImpl>(params); | ||
} | ||
|
||
}} // namespace cv::dnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters