Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion onnxruntime/core/providers/cuda/math/topk_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,17 @@ __global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray<int64_t>

template <typename T>
__device__ __inline__ bool Equal(const T& t0, const T& t1) {
return t0 == t1;
}

__device__ __inline__ bool Equal(const float& t0, const float& t1) {
auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
return (double)t2 < 1.0e-5;
return t2 < std::numeric_limits<float>::epsilon();
}

__device__ __inline__ bool Equal(const double& t0, const double& t1) {
auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
return t2 < std::numeric_limits<double>::epsilon();
}

template<typename T>
Expand Down Expand Up @@ -220,6 +229,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
}
__syncthreads();
positive = BlockReduce(temp_storage.reduce).Sum(positive);
__syncthreads();
negative = BlockReduce(temp_storage.reduce).Sum(negative);
if (0 == tid) {
H[0] = positive;
Expand Down Expand Up @@ -286,6 +296,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
__syncthreads();
all_superior = H[0];
BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
__syncthreads();
BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
__syncthreads();
auto equal_quota = K - all_superior - equal;
Expand Down