diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8b139c2d5514f..315237a7fe394 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1071,9 +1071,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, float, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, double, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, MLFloat16, Round); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); @@ -1190,14 +1190,14 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, E class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Min); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, bool, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, float, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, double, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, MLFloat16, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Greater); @@ -1568,6 +1568,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, bool, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Equal); // Opset 20 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); @@ -1676,6 +1684,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Round); // Opset 23. class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); @@ -2279,9 +2290,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2393,14 +2404,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2772,6 +2783,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 20 BuildKernelCreateInfo, @@ -2880,6 +2899,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index e4faa50d7acbc..babbb4b3ba672 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -579,8 +579,10 @@ Status LessOrEqual::ComputeInternal(OpKernelContext* context) const { return this->CompareMethod(context, &ImplT2_LessOrEqual); } -BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 13) -BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 13, bool) +BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(Equal, 13, 18) +BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 13, 18, bool) +BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 19) +BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 19, bool) BINARY_OP_REGISTER_VERSIONED_UZILHFD(Equal, 11, 12) BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 11, 12, bool) BINARY_OP_REGISTER_VERSIONED_OIL(Equal, 7, 10) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 86a1b0f5b6102..a54b96da6c174 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -249,7 +249,8 @@ UNARY_OP_HFDX(Erf, 13) UNARY_OP_BWUZCSILHFDX(Sign, 13) UNARY_LOGICALOP_NOT_TYPED(1, bool) -UNARY_OP_HFD(Round, 11) +UNARY_OP_VERSIONED_HFD(Round, 11, 21) +UNARY_OP_HFD(Round, 22) UNARY_OP_HFD(Cos, 7) UNARY_OP_HFD(Sin, 7) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 283f20a4be9b0..11a4b373c53f1 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -3674,6 +3674,127 @@ TEST(MathOpTest, Equal_string) { test.Run(); } +#ifdef USE_CUDA +// Opset 19 tests for numeric types (CUDA EP) +TEST(MathOpTest, Equal_19_bool) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {false, true, false, true}); + test.AddInput("B", dims, {false, false, true, true}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_int32) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_int64) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_float) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1.0f, 0.0f, -1.0f, -1.0f}); + test.AddInput("B", dims, {1.0f, 1.0f, 2.0f, -1.0f}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_double) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1.0, 0.0, -1.0, -1.0}); + test.AddInput("B", dims, {1.0, 1.0, 2.0, -1.0}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_float16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(-1.0f), MLFloat16(-1.0f)}); + test.AddInput("B", dims, {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(-1.0f)}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_broadcastAB) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + test.AddInput("A", {4, 2}, {1, 0, -1, -1, 1, 1, -1, 0}); + test.AddInput("B", {2}, {1, 1}); + test.AddOutput("C", {4, 2}, {true, false, false, false, true, true, false, false}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + #if defined(USE_DNNL) TEST(MathOpTest, Equal_bfloat16) { #ifdef USE_DNNL diff --git a/onnxruntime/test/providers/cpu/math/round_test.cc b/onnxruntime/test/providers/cpu/math/round_test.cc index 5df14ac079a63..48f96fe4f8494 100644 --- a/onnxruntime/test/providers/cpu/math/round_test.cc +++ b/onnxruntime/test/providers/cpu/math/round_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" #include "core/framework/data_types.h" #include "core/util/math.h" @@ -30,5 +31,53 @@ TEST(RoundTest, SimpleTestFloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +#ifdef USE_CUDA +// Opset 22 tests +TEST(RoundTest, Round22_Float) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {0.9f, 2.5f, 2.3f, 1.5f, -4.5f}); + test.AddOutput("y", {5}, {1.0f, 2.0f, 2.0f, 2.0f, -4.0f}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoundTest, Round22_Double) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {0.9, 2.5, 2.3, 1.5, -4.5}); + test.AddOutput("y", {5}, {1.0, 2.0, 2.0, 2.0, -4.0}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoundTest, Round22_Float16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {MLFloat16(0.9f), MLFloat16(2.5f), MLFloat16(2.3f), MLFloat16(1.5f), MLFloat16(-4.5f)}); + test.AddOutput("y", {5}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(-4.0f)}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime