diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 5f4188c763d93..1dfb3957292dc 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -418,6 +418,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Un class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, TopK); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, double, TopK); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, TopK); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, TopK); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t_string_int64_t, OneHot); @@ -1263,6 +1264,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { TopK)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo::TopK(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_ TopkOpset11ConstructorCommon(op_kernel_info, axis_, largest_, sorted_); } +template <> +TopK<11, int32_t>::TopK(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { + TopkOpset11ConstructorCommon(op_kernel_info, axis_, largest_, sorted_); +} + template <> TopK<11, int64_t>::TopK(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { TopkOpset11ConstructorCommon(op_kernel_info, axis_, largest_, sorted_); @@ -507,6 +512,11 @@ Status TopK<11, double>::Compute(OpKernelContext* p_op_kernel_context) const { return ComputeImplOpset1011(p_op_kernel_context, axis_, largest_, sorted_); } +template <> +Status TopK<11, int32_t>::Compute(OpKernelContext* p_op_kernel_context) const { + return ComputeImplOpset1011(p_op_kernel_context, axis_, largest_, sorted_); +} + template <> Status TopK<11, int64_t>::Compute(OpKernelContext* p_op_kernel_context) const { return ComputeImplOpset1011(p_op_kernel_context, axis_, largest_, sorted_); @@ -539,5 +549,6 @@ REGISTER_TOPK_VERSIONED_TYPED_KERNEL(10, 10, double); REGISTER_TOPK_TYPED_KERNEL(11, float); REGISTER_TOPK_TYPED_KERNEL(11, double); REGISTER_TOPK_TYPED_KERNEL(11, int64_t); +REGISTER_TOPK_TYPED_KERNEL(11, int32_t); } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc index 843db128e0044..f1f6b34ca581b 100644 --- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc +++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc @@ -584,7 +584,7 @@ TEST(TopKOperator, Top3ExplicitAxisSmallestElements) { template static void top_1_explicit_axis_MultiD_input_smallest(int opset_version, int64_t sorted = 1) { - std::vector input_vals = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + std::vector input_vals = {1, 2, 3, 4, 5, 6, 7, 8}; std::vector input_dimensions = {2, 2, 2}; std::vector expected_vals = {1, 2, 5, 6}; std::vector expected_indices = {0, 0, 0, 0}; @@ -598,6 +598,10 @@ TEST(TopKOperator, Top1ExplicitAxisMultiDInputSmallestElements) { top_1_explicit_axis_MultiD_input_smallest(11, 0); //unsorted top_1_explicit_axis_MultiD_input_smallest(11); top_1_explicit_axis_MultiD_input_smallest(11, 0); //unsorted + top_1_explicit_axis_MultiD_input_smallest(11); + top_1_explicit_axis_MultiD_input_smallest(11, 0); //unsorted + top_1_explicit_axis_MultiD_input_smallest(11); + top_1_explicit_axis_MultiD_input_smallest(11, 0); //unsorted } // test path where SelectTopK is used (select using std::nth_element)