From 9783a9baf90d8e7c61598ec2896d81edc88d4dec Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Fri, 16 Jun 2023 12:35:25 -0700
Subject: [PATCH] Add an optional argument "hasnan" for qselect and partialsort

avx512_qselect, avx512_qselect_fp16, avx512_partial_qsort and
avx512_partial_qsort_fp16 take a new trailing argument `hasnan`
(default: false), so existing callers keep compiling.

Floating-point variants: when `hasnan` is true, NaNs are first moved to
the end of the array by move_nans_to_end_of_array() and the selection
runs only on the NaN-free prefix [0, indx_last_elem]. When `hasnan` is
false the array is used as-is; the replace_nan_with_inf() /
replace_inf_with_nan() round trip is no longer performed.

Integer variants accept the argument and ignore it.

Details worth knowing when reading the diff:
 * is_a_nan for 16-bit floats tests (bits & 0x7fff) > 0x7c00, i.e.
   exponent all ones AND a non-zero mantissa, so +/-Inf (0x7c00/0xfc00)
   is not treated as NaN and stays inside the selection range.
 * The recursion budget is 2 * log2(number of elements), which is
   log2(indx_last_elem + 1). This matches the previous log2(arrsize)
   and avoids log2(0) when exactly one non-NaN element remains.
 * LIKELY/UNLIKELY normalise their argument with !! before handing it
   to __builtin_expect.
---
 src/avx512-16bit-qsort.hpp     | 22 +++++++++++------
 src/avx512-32bit-qsort.hpp     | 16 +++++++------
 src/avx512-64bit-qsort.hpp     | 16 +++++++------
 src/avx512-common-qsort.h      | 44 +++++++++++++++++++++++++++++-----
 src/avx512fp16-16bit-qsort.hpp | 20 ++++++++++++----
 5 files changed, 86 insertions(+), 32 deletions(-)

diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp
index 606f8706..1efcf1e9 100644
--- a/src/avx512-16bit-qsort.hpp
+++ b/src/avx512-16bit-qsort.hpp
@@ -406,7 +406,13 @@ replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
 }
 
 template <>
-void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
+bool is_a_nan<uint16_t>(uint16_t elem)
+{
+    return (elem & 0x7fff) > 0x7c00; // fp16 NaN only; +/-Inf is excluded
+}
+
+template <>
+void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     if (arrsize > 1) {
         qselect_16bit_<zmm_vector<int16_t>, int16_t>(
@@ -415,7 +421,7 @@ void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
 }
 
 template <>
-void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
+void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     if (arrsize > 1) {
         qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
@@ -423,13 +429,15 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
     }
 }
 
-void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
+void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+    int64_t indx_last_elem = arrsize - 1;
+    if (UNLIKELY(hasnan)) {
+        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+    }
+    if (indx_last_elem >= k) {
         qselect_16bit_<zmm_vector<float16>, uint16_t>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
+            arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem + 1));
     }
 }
 
diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp
index c4061ddf..bfd4a151 100644
--- a/src/avx512-32bit-qsort.hpp
+++ b/src/avx512-32bit-qsort.hpp
@@ -715,7 +715,7 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
 }
 
 template <>
-void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize)
+void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     if (arrsize > 1) {
         qselect_32bit_<zmm_vector<int32_t>, int32_t>(
@@ -724,7 +724,7 @@ void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize)
 }
 
 template <>
-void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize)
+void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     if (arrsize > 1) {
         qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
@@ -733,13 +733,15 @@ void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize)
 }
 
 template <>
-void avx512_qselect(float *arr, int64_t k, int64_t arrsize)
+void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+    int64_t indx_last_elem = arrsize - 1;
+    if (UNLIKELY(hasnan)) {
+        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+    }
+    if (indx_last_elem >= k) {
         qselect_32bit_<zmm_vector<float>, float>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
+            arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem + 1));
     }
 }
 
diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp
index 1cbcd388..aa5d7958 100644
--- a/src/avx512-64bit-qsort.hpp
+++ b/src/avx512-64bit-qsort.hpp
@@ -784,7 +784,7 @@ static void qselect_64bit_(type_t *arr,
 }
 
 template <>
-void avx512_qselect(int64_t *arr, int64_t k, int64_t arrsize)
+void avx512_qselect(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     if (arrsize > 1) {
         qselect_64bit_<zmm_vector<int64_t>, int64_t>(
@@ -793,7 +793,7 @@ void avx512_qselect(int64_t *arr, int64_t k, int64_t arrsize)
 }
 
 template <>
-void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize)
+void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     if (arrsize > 1) {
         qselect_64bit_<zmm_vector<uint64_t>, uint64_t>(
@@ -802,13 +802,15 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize)
 }
 
 template <>
-void avx512_qselect(double *arr, int64_t k, int64_t arrsize)
+void avx512_qselect(double *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+    int64_t indx_last_elem = arrsize - 1;
+    if (UNLIKELY(hasnan)) {
+        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+    }
+    if (indx_last_elem >= k) {
         qselect_64bit_<zmm_vector<double>, double>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
+            arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem + 1));
     }
 }
 
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index 959352e6..841b4a83 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -85,6 +85,9 @@
 #define X86_SIMD_SORT_FINLINE static
 #endif
 
+#define LIKELY(x) __builtin_expect(!!(x), 1)
+#define UNLIKELY(x) __builtin_expect(!!(x), 0)
+
 template <typename type>
 struct zmm_vector;
 
@@ -97,18 +100,18 @@ void avx512_qsort(T *arr, int64_t arrsize);
 void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
 
 template <typename T>
-void avx512_qselect(T *arr, int64_t k, int64_t arrsize);
-void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize);
+void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
+void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
 
 template <typename T>
-inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize)
+inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
 {
-    avx512_qselect<T>(arr, k - 1, arrsize);
+    avx512_qselect<T>(arr, k - 1, arrsize, hasnan);
     avx512_qsort<T>(arr, k - 1);
 }
-inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
+inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
 {
-    avx512_qselect_fp16(arr, k - 1, arrsize);
+    avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
     avx512_qsort_fp16(arr, k - 1);
 }
 
@@ -116,6 +119,35 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
 template <typename T>
 void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
 
+template <typename T>
+bool is_a_nan(T elem)
+{
+    return std::isnan(elem);
+}
+
+/*
+ * Move all NaNs to the end of the array (order of other elements is not kept).
+ * Returns the index of the last non-NaN element, or -1 if every element is NaN.
+ */
+template <typename T>
+int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
+{
+    int64_t jj = arrsize - 1;
+    int64_t ii = 0;
+    int64_t count = 0;
+    while (ii <= jj) {
+        if (is_a_nan(arr[ii])) {
+            std::swap(arr[ii], arr[jj]);
+            jj -= 1;
+            count++;
+        }
+        else {
+            ii += 1;
+        }
+    }
+    return arrsize-count-1;
+}
+
 template <typename T>
 bool comparison_func(const T &a, const T &b)
 {
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp
index 8a9a49ed..8e87ac29 100644
--- a/src/avx512fp16-16bit-qsort.hpp
+++ b/src/avx512fp16-16bit-qsort.hpp
@@ -145,13 +145,23 @@ replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
 }
 
 template <>
-void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize)
+bool is_a_nan<_Float16>(_Float16 elem)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+    Fp16Bits temp;
+    temp.f_ = elem;
+    return (temp.i_ & 0x7fff) > 0x7c00; // fp16 NaN only; +/-Inf is excluded
+}
+
+template <>
+void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
+{
+    int64_t indx_last_elem = arrsize - 1;
+    if (UNLIKELY(hasnan)) {
+        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+    }
+    if (indx_last_elem >= k) {
         qselect_16bit_<zmm_vector<_Float16>, _Float16>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
+            arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem + 1));
     }
 }
 