Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Ceil);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Ceil);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Ceil);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Clip);
Comment thread
snnn marked this conversation as resolved.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Reciprocal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Reciprocal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Reciprocal);
Expand Down Expand Up @@ -545,6 +546,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Clip)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile)>,
Expand Down Expand Up @@ -820,7 +822,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, Compress)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, float, Upsample)>,
Expand Down
42 changes: 42 additions & 0 deletions onnxruntime/core/providers/cuda/math/clip.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/cuda/math/clip.h"
#include "core/providers/cuda/math/clip_impl.h"

namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Clip, \
kOnnxDomain, \
6, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Clip<T>);

template <typename T>
Status Clip<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor& X = *ctx->Input<Tensor>(0);
const TensorShape input_shape{X.Shape()};
Tensor* Y = ctx->Output(0, input_shape);

size_t count = input_shape.Size();

auto* y_data = Y->template MutableData<T>();
const auto* x_data = X.template Data<T>();
ClipImpl<T>(x_data, y_data, min_, max_, count);
return Status::OK();
}

#define SPECIALIZED_COMPUTE(T) \
REGISTER_KERNEL_TYPED(T) \
template Status Clip<T>::ComputeInternal(OpKernelContext* ctx) const;

SPECIALIZED_COMPUTE(float)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why float only? version 6 also have double and float16, right?
why not support opset 11 also?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add float16 and double

Copy link
Copy Markdown
Contributor

@chilo-ms chilo-ms Sep 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While using GetAttrOrDefault() to get attribute for minimum/maximum values, it seems that ONNX has not yet supported double and float16 type as listed in AttributeProto_AttributeType. Once ONNX supports double and float16, we can add them as well.


} // namespace cuda
} // namespace onnxruntime
31 changes: 31 additions & 0 deletions onnxruntime/core/providers/cuda/math/clip.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
class Clip final : public CudaKernel {
public:
Clip(const OpKernelInfo& info) : CudaKernel{info} {
auto min_val = -std::numeric_limits<T>::infinity();
auto max_val = std::numeric_limits<T>::infinity();

info.GetAttrOrDefault("min", &min_, min_val);
info.GetAttrOrDefault("max", &max_, max_val);

// Make sure the range of interval is sensible
ORT_ENFORCE(min_val <= max_val);
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
T min_, max_;
};

} // namespace cuda
} // namespace onnxruntime
33 changes: 33 additions & 0 deletions onnxruntime/core/providers/cuda/math/clip_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/cuda/math/clip_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"

namespace onnxruntime {
namespace cuda {
template <typename T>
__global__ void _Clip(const T* input, T* output, T min, T max, size_t N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
output[id] = (input[id] < min) ? min : ((input[id] > max) ? max : input[id]);
}

template <typename T>
void ClipImpl(const T* input_data, T* output_data, T min, T max, size_t count) {
typedef typename ToCudaType<T>::MappedType CudaT;

int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
_Clip<CudaT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(reinterpret_cast<const CudaT*>(input_data),
reinterpret_cast<CudaT*>(output_data),
*reinterpret_cast<CudaT*>(&min),
*reinterpret_cast<CudaT*>(&max),
count);
}

template void ClipImpl<float>(const float* input_data, float* output_data, float min, float max, size_t count);
template void ClipImpl<double>(const double* input_data, double* output_data, double min, double max, size_t count);
template void ClipImpl<MLFloat16>(const MLFloat16* input_data, MLFloat16* output_data, MLFloat16 min, MLFloat16 max, size_t count);

} // namespace cuda
} // namespace onnxruntime
16 changes: 16 additions & 0 deletions onnxruntime/core/providers/cuda/math/clip_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/cuda/math/clip.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace cuda {
template <typename T>
void ClipImpl(const T* input_data, T* output_data, T min, T max, size_t count);

} // namespace cuda
} // namespace onnxruntime