From 8fb5d2d92e7aa484a4dbed98c2a5ae4d57977f9e Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 12 Jun 2023 15:06:08 -0700 Subject: [PATCH 1/2] Revert "Rename function" This reverts commit b3cd2621d7460019a91b4be33fd996fcfca030d1. --- src/avx512-16bit-qsort.hpp | 4 ++-- src/avx512-32bit-qsort.hpp | 4 ++-- src/avx512-64bit-qsort.hpp | 4 ++-- src/avx512-common-qsort.h | 2 +- src/avx512fp16-16bit-qsort.hpp | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 03e00e4f..9133eb15 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -403,7 +403,7 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize) void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_16bit_, uint16_t>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -430,7 +430,7 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize) void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_16bit_, uint16_t>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index c8923899..4cc8e2ec 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -710,7 +710,7 @@ void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(float *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_32bit_, float>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -738,7 +738,7 @@ void avx512_qsort(uint32_t *arr, int64_t arrsize) template <> void avx512_qsort(float *arr, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_32bit_, float>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 046288d3..c0d2cdba 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -804,7 +804,7 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(double *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_64bit_, double>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -832,7 +832,7 @@ void avx512_qsort(uint64_t *arr, int64_t arrsize) template <> void avx512_qsort(double *arr, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_64bit_, double>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 60637ce8..3969e39a 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -127,7 +127,7 @@ bool is_a_nan(T elem) * in the array which is not a nan */ template -int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize) +int64_t put_nans_at_end_of_array(T* arr, int64_t arrsize) { int64_t jj = arrsize - 1; int64_t ii = 0; diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 54fdff79..9a00a820 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -125,7 +125,7 @@ bool is_a_nan<_Float16>(_Float16 elem) template <> void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_16bit_, _Float16>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -135,7 +135,7 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize) template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) { - int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_16bit_, _Float16>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); From 0b2d89e52302c68322d13f687359228ca56dccd7 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 12 Jun 2023 15:09:55 -0700 Subject: [PATCH 2/2] Revert "Preserve NANs in the array for QSORT" This reverts commit 5385d3c480b0c922578a06726302ecd849ddc304. --- src/avx512-16bit-qsort.hpp | 48 +++++++++++++++++++++++++--------- src/avx512-32bit-qsort.hpp | 39 ++++++++++++++++++++++----- src/avx512-64bit-qsort.hpp | 14 +++++----- src/avx512-common-qsort.h | 29 -------------------- src/avx512fp16-16bit-qsort.hpp | 46 ++++++++++++++++++++++++-------- 5 files changed, 112 insertions(+), 64 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 9133eb15..606f8706 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -349,12 +349,6 @@ struct zmm_vector { } }; -template <> -bool is_a_nan(uint16_t elem) -{ - return (elem & 0x7c00) == 0x7c00; -} - template <> bool comparison_func>(const uint16_t &a, const uint16_t &b) { @@ -383,6 +377,34 @@ bool comparison_func>(const uint16_t &a, const uint16_t &b) //return npy_half_to_float(a) < npy_half_to_float(b); } +X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr, + int64_t arrsize) +{ + int64_t nan_count = 0; + __mmask16 loadmask = 0xFFFF; + while (arrsize > 0) { + if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; } + __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr); + __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm); + __mmask16 nanmask = _mm512_cmp_ps_mask( + in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ); + nan_count += _mm_popcnt_u32((int32_t)nanmask); + _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF); + arr += 16; + arrsize -= 16; + } + return nan_count; +} + +X86_SIMD_SORT_INLINE void +replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count) +{ + for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + arr[ii] = 0xFFFF; + nan_count -= 1; + } +} + template <> void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize) { @@ -403,10 +425,11 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize) void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem >= k) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qselect_16bit_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } @@ -430,10 +453,11 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize) void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem > 0) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qsort_16bit_, uint16_t>( - arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 4cc8e2ec..c4061ddf 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -689,6 +689,31 @@ static void qselect_32bit_(type_t *arr, qselect_32bit_(arr, pos, pivot_index, right, max_iters - 1); } +X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize) +{ + int64_t nan_count = 0; + __mmask16 loadmask = 0xFFFF; + while (arrsize > 0) { + if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; } + __m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr); + __mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); + nan_count += _mm_popcnt_u32((int32_t)nanmask); + _mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT); + arr += 16; + arrsize -= 16; + } + return nan_count; +} + +X86_SIMD_SORT_INLINE void +replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) +{ + for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + arr[ii] = std::nanf("1"); + nan_count -= 1; + } +} + template <> void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize) { @@ -710,10 +735,11 @@ void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(float *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem >= k) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qselect_32bit_, float>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } @@ -738,10 +764,11 @@ void avx512_qsort(uint32_t *arr, int64_t arrsize) template <> void avx512_qsort(float *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem > 0) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qsort_32bit_, float>( - arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index c0d2cdba..1cbcd388 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -804,10 +804,11 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(double *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem >= k) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qselect_64bit_, double>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } @@ -832,10 +833,11 @@ void avx512_qsort(uint64_t *arr, int64_t arrsize) template <> void avx512_qsort(double *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem > 0) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qsort_64bit_, double>( - arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } #endif // AVX512_QSORT_64BIT diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 3969e39a..959352e6 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -116,35 +116,6 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize) template void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); -template -bool is_a_nan(T elem) -{ - return std::isnan(elem); -} - -/* - * Sort all the NAN's to end of the array and return the index of the last elem - * in the array which is not a nan - */ -template -int64_t put_nans_at_end_of_array(T* arr, int64_t arrsize) -{ - int64_t jj = arrsize - 1; - int64_t ii = 0; - int64_t count = 0; - while (ii <= jj) { - if (is_a_nan(arr[ii])) { - std::swap(arr[ii], arr[jj]); - jj -= 1; - count++; - } - else { - ii += 1; - } - } - return arrsize-count-1; -} - template bool comparison_func(const T &a, const T &b) { diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 9a00a820..8a9a49ed 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -114,31 +114,55 @@ struct zmm_vector<_Float16> { } }; -template <> -bool is_a_nan<_Float16>(_Float16 elem) +X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr, + int64_t arrsize) +{ + int64_t nan_count = 0; + __mmask32 loadmask = 0xFFFFFFFF; + __m512h in_zmm; + while (arrsize > 0) { + if (arrsize < 32) { + loadmask = (0x00000001 << arrsize) - 0x00000001; + in_zmm = _mm512_castsi512_ph( + _mm512_maskz_loadu_epi16(loadmask, arr)); + } + else { + in_zmm = _mm512_loadu_ph(arr); + } + __mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); + nan_count += _mm_popcnt_u32((int32_t)nanmask); + _mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF); + arr += 32; + arrsize -= 32; + } + return nan_count; +} + +X86_SIMD_SORT_INLINE void +replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count) { - Fp16Bits temp; - temp.f_ = elem; - return (temp.i_ & 0x7c00) == 0x7c00; + memset(arr + arrsize - nan_count, 0xFF, nan_count * 2); } template <> void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem >= k) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qselect_16bit_, _Float16>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); - if (indx_last_elem > 0) { + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qsort_16bit_, _Float16>( - arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } #endif // AVX512FP16_QSORT_16BIT