-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Duli/clip cuda #1677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Duli/clip cuda #1677
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/common.h" | ||
| #include "core/providers/cuda/math/clip.h" | ||
| #include "core/providers/cuda/math/clip_impl.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| // Registers Clip<T> as the CUDA execution provider kernel for the ONNX-domain | ||
| // "Clip" operator, version 6, with type constraint "T" bound to tensors of T. | ||
| #define REGISTER_KERNEL_TYPED(T) \ | ||
|   ONNX_OPERATOR_TYPED_KERNEL_EX( \ | ||
|       Clip, kOnnxDomain, 6, T, kCudaExecutionProvider, \ | ||
|       KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \ | ||
|       Clip<T>); | ||
|
|
||
| // Requests output 0 with the same shape as input 0, then hands both buffers and | ||
| // the stored min_/max_ bounds to ClipImpl, which launches the CUDA kernel. | ||
| template <typename T> | ||
| Status Clip<T>::ComputeInternal(OpKernelContext* ctx) const { | ||
|   const Tensor& input = *ctx->Input<Tensor>(0); | ||
|   const TensorShape shape{input.Shape()}; | ||
|   Tensor* output = ctx->Output(0, shape); | ||
|
|
||
|   const T* src = input.template Data<T>(); | ||
|   T* dst = output->template MutableData<T>(); | ||
|   const size_t element_count = shape.Size(); | ||
|
|
||
|   ClipImpl<T>(src, dst, min_, max_, element_count); | ||
|   return Status::OK(); | ||
| } | ||
|
|
||
| // For element type T: registers the kernel and explicitly instantiates | ||
| // Clip<T>::ComputeInternal so its definition is emitted in this translation unit. | ||
| #define SPECIALIZED_COMPUTE(T) \ | ||
|   REGISTER_KERNEL_TYPED(T)     \ | ||
|   template Status Clip<T>::ComputeInternal(OpKernelContext* ctx) const; | ||
|
|
||
| // Only the float specialization is registered here. | ||
| SPECIALIZED_COMPUTE(float) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why float only? Version 6 also has double and float16, right?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add float16 and double
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We read the minimum/maximum attributes with GetAttrOrDefault(), and it seems that ONNX does not yet support double and float16 among the attribute types listed in AttributeProto_AttributeType. Once ONNX supports double and float16, we can add them as well. |
||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
| #include "core/providers/cuda/cuda_common.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| // CUDA kernel class for the Clip operator. The clipping bounds are read once from | ||
| // the "min" and "max" node attributes at construction; a missing attribute falls | ||
| // back to -infinity / +infinity respectively. ComputeInternal passes the stored | ||
| // bounds to the device implementation. | ||
| template <typename T> | ||
| class Clip final : public CudaKernel { | ||
|  public: | ||
|   // Reads the bounds from `info` and fails (via ORT_ENFORCE) if min > max. | ||
|   Clip(const OpKernelInfo& info) : CudaKernel{info} { | ||
|     auto min_val = -std::numeric_limits<T>::infinity(); | ||
|     auto max_val = std::numeric_limits<T>::infinity(); | ||
|
|
||
|     info.GetAttrOrDefault("min", &min_, min_val); | ||
|     info.GetAttrOrDefault("max", &max_, max_val); | ||
|
|
||
|     // Validate the bounds actually stored on the kernel. Comparing the defaults | ||
|     // (min_val/max_val) always passes and would let an inverted range through. | ||
|     ORT_ENFORCE(min_ <= max_, "Clip: 'min' attribute must not be greater than 'max'."); | ||
|   } | ||
|
|
||
|   Status ComputeInternal(OpKernelContext* context) const override; | ||
|
|
||
|  private: | ||
|   // Lower and upper clipping bounds taken from the node attributes. | ||
|   T min_, max_; | ||
| }; | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
| #include "core/providers/cuda/math/clip_impl.h" | ||
| #include "core/providers/cuda/cu_inc/common.cuh" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
| // Element-wise clamp kernel: writes min where input < min, max where input > max, | ||
| // and the unchanged input value otherwise. | ||
| // NOTE(review): CALCULATE_ELEMENTWISE_INDEX_OR_EXIT comes from common.cuh (not shown); | ||
| // it is presumed to define `id` and return early for threads at or past N - confirm. | ||
| template <typename T> | ||
| __global__ void _Clip(const T* input, T* output, T min, T max, size_t N) { | ||
|   CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); | ||
|   const T value = input[id]; | ||
|   if (value < min) { | ||
|     output[id] = min; | ||
|   } else if (value > max) { | ||
|     output[id] = max; | ||
|   } else { | ||
|     output[id] = value; | ||
|   } | ||
| } | ||
|
|
||
| // Host-side launcher for _Clip. Reinterprets the buffers and bounds as the CUDA | ||
| // type mapped from T and launches enough blocks of GridDim::maxThreadsPerBlock | ||
| // threads to cover `count` elements. | ||
| //   input_data / output_data: buffers passed straight to the kernel | ||
| //   min / max: clipping bounds | ||
| //   count: number of elements to process; 0 is a no-op | ||
| template <typename T> | ||
| void ClipImpl(const T* input_data, T* output_data, T min, T max, size_t count) { | ||
|   typedef typename ToCudaType<T>::MappedType CudaT; | ||
|
|
||
|   // Nothing to do for an empty tensor; a zero-block grid is an invalid launch configuration. | ||
|   if (count == 0) { | ||
|     return; | ||
|   } | ||
|
|
||
|   // Integer ceil-division. Going through float loses precision above 2^24 elements | ||
|   // and can round the block count down, leaving the tail of a large tensor unwritten. | ||
|   const size_t threadsPerBlock = static_cast<size_t>(GridDim::maxThreadsPerBlock); | ||
|   const int blocksPerGrid = static_cast<int>((count + threadsPerBlock - 1) / threadsPerBlock); | ||
|   _Clip<CudaT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(reinterpret_cast<const CudaT*>(input_data), | ||
|                                                                    reinterpret_cast<CudaT*>(output_data), | ||
|                                                                    *reinterpret_cast<CudaT*>(&min), | ||
|                                                                    *reinterpret_cast<CudaT*>(&max), | ||
|                                                                    count); | ||
| } | ||
|
|
||
| template void ClipImpl<float>(const float* input_data, float* output_data, float min, float max, size_t count); | ||
| template void ClipImpl<double>(const double* input_data, double* output_data, double min, double max, size_t count); | ||
| template void ClipImpl<MLFloat16>(const MLFloat16* input_data, MLFloat16* output_data, MLFloat16 min, MLFloat16 max, size_t count); | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/providers/cuda/math/clip.h" | ||
| #include "core/providers/cuda/cuda_common.h" | ||
| #include "core/providers/cuda/shared_inc/cuda_utils.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
| // Clips `count` elements of input_data to [min, max] and writes them to output_data. | ||
| // Declaration only; the definition and explicit instantiations are in the CUDA source. | ||
| template <typename T> | ||
| void ClipImpl(const T* input_data, | ||
|               T* output_data, | ||
|               T min, | ||
|               T max, | ||
|               size_t count); | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime |
Uh oh!
There was an error while loading. Please reload this page.