-
-
Notifications
You must be signed in to change notification settings - Fork 55.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24378 from fengyuentau:instance_norm
dnn onnx: add instance norm layer #24378

Resolves #24377
Relates #24092 (comment)

| Perf | multi-thread | single-thread |
| - | - | - |
| x: [2, 64, 180, 240] | 3.95ms | 11.12ms |

Todo:
- [x] speed up by multi-threading
- [x] add perf
- [x] add backend: OpenVINO
- [x] add backend: CUDA
- [x] add backend: OpenCL (no fp16)
- [ ] add backend: CANN (will be done via #24462)

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake

```
force_builders=Linux OpenCL,Win64 OpenCL,Custom
buildworker:Custom=linux-4
build_image:Custom=ubuntu:18.04
modules_filter:Custom=none
disable_ipp:Custom=ON
```
- Loading branch information
1 parent
832f738
commit ee0822d
Showing
10 changed files
with
454 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// This file is part of OpenCV project. | ||
// It is subject to the license terms in the LICENSE file found in the top-level directory | ||
// of this distribution and at http://opencv.org/license.html. | ||
|
||
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INSTANCE_NORM_HPP | ||
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INSTANCE_NORM_HPP | ||
|
||
#include "../../op_cuda.hpp" | ||
|
||
#include "../csl/stream.hpp" | ||
#include "../csl/span.hpp" | ||
#include "../csl/tensor.hpp" | ||
#include "../csl/workspace.hpp" | ||
|
||
#include "../kernels/fill_copy.hpp" | ||
#include "../kernels/mvn.hpp" | ||
|
||
#include <opencv2/core.hpp> | ||
|
||
#include <cstddef> | ||
#include <vector> | ||
#include <utility> | ||
|
||
namespace cv { namespace dnn { namespace cuda4dnn { | ||
|
||
template <class T> | ||
class InstanceNormOp final : public CUDABackendNode { | ||
public: | ||
using wrapper_type = GetCUDABackendWrapperType<T>; | ||
|
||
InstanceNormOp(csl::Stream stream_, float epsilon_, size_t loops) | ||
: stream(std::move(stream_)), epsilon(epsilon_) { | ||
csl::WorkspaceBuilder builder; | ||
builder.require<float>(loops); | ||
builder.require<float>(loops); | ||
scratch_mem_in_bytes = builder.required_workspace_size(); | ||
} | ||
|
||
void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs, | ||
const std::vector<cv::Ptr<BackendWrapper>>& outputs, | ||
csl::Workspace& workspace) override { | ||
auto input_wrapper = inputs[0].dynamicCast<wrapper_type>(); | ||
auto scale_wrapper = inputs[1].dynamicCast<wrapper_type>(); | ||
auto bias_wrapper = inputs[2].dynamicCast<wrapper_type>(); | ||
|
||
auto input = input_wrapper->getView(); | ||
auto scale = scale_wrapper->getView(); | ||
auto bias = bias_wrapper->getView(); | ||
|
||
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>(); | ||
auto output = output_wrapper->getSpan(); | ||
|
||
auto C = input.get_axis_size(1); | ||
auto loops = input.size_range(0, 2); | ||
auto norm_size = input.size_range(2, input.rank()); | ||
if (norm_size == 1) { | ||
kernels::fill<T>(stream, output, 0.f); | ||
return; | ||
} else { | ||
auto ws_allocator = csl::WorkspaceAllocator(workspace); | ||
|
||
auto mean = ws_allocator.get_span<float>(loops); | ||
kernels::fill<float>(stream, mean, 0.f); | ||
|
||
auto stdev = ws_allocator.get_span<float>(loops); | ||
kernels::fill<float>(stream, stdev, 0.f); | ||
|
||
kernels::reduce_mean_sqr_sum<T>(stream, mean, stdev, input, norm_size); | ||
kernels::compute_normalization_scale(stream, stdev, mean, stdev, norm_size, epsilon); | ||
kernels::normalize_mean_variance_channelwise<T>(stream, output, input, scale, bias, mean, stdev, norm_size, C); | ||
} | ||
} | ||
|
||
std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; } | ||
|
||
private: | ||
csl::Stream stream; | ||
|
||
float epsilon; | ||
|
||
std::size_t scratch_mem_in_bytes; | ||
}; | ||
|
||
}}} // cv::dnn::cuda4dnn | ||
|
||
#endif // OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INSTANCE_NORM_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.