From 24d45a30203be97c16e41354ba6b84647c421d58 Mon Sep 17 00:00:00 2001
From: RandySheriffH
Date: Wed, 29 Apr 2020 20:41:43 -0700
Subject: [PATCH] sync threads before calling next cub function

---
 onnxruntime/core/providers/cuda/math/topk_impl.cu | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cu b/onnxruntime/core/providers/cuda/math/topk_impl.cu
index 8f447644bae0c..851270c798fc7 100644
--- a/onnxruntime/core/providers/cuda/math/topk_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/topk_impl.cu
@@ -148,8 +148,17 @@ __global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray
 
 template <typename T>
 __device__ __inline__ bool Equal(const T& t0, const T& t1) {
+  return t0 == t1;
+}
+
+__device__ __inline__ bool Equal(const float& t0, const float& t1) {
   auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
-  return (double)t2 < 1.0e-5;
+  return t2 < std::numeric_limits<float>::epsilon();
+}
+
+__device__ __inline__ bool Equal(const double& t0, const double& t1) {
+  auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
+  return t2 < std::numeric_limits<double>::epsilon();
 }
 
 template <typename T>
@@ -220,6 +229,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray el
   }
   __syncthreads();
   positive = BlockReduce(temp_storage.reduce).Sum(positive);
+  __syncthreads();
   negative = BlockReduce(temp_storage.reduce).Sum(negative);
   if (0 == tid) {
     H[0] = positive;
@@ -286,6 +296,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray el
   __syncthreads();
   all_superior = H[0];
   BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
+  __syncthreads();
   BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
   __syncthreads();
   auto equal_quota = K - all_superior - equal;