From 5385d3c480b0c922578a06726302ecd849ddc304 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 25 May 2023 09:44:35 -0700 Subject: [PATCH 1/3] Preserve NANs in the array for QSORT --- src/avx512-16bit-qsort.hpp | 48 +++++++++------------------------- src/avx512-32bit-qsort.hpp | 39 +++++---------------------- src/avx512-64bit-qsort.hpp | 14 +++++----- src/avx512-common-qsort.h | 27 +++++++++++++++++++ src/avx512fp16-16bit-qsort.hpp | 46 ++++++++------------------------ 5 files changed, 62 insertions(+), 112 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 606f8706..9133eb15 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -349,6 +349,12 @@ struct zmm_vector { } }; +template <> +bool is_a_nan(uint16_t elem) +{ + return (elem & 0x7c00) == 0x7c00; +} + template <> bool comparison_func>(const uint16_t &a, const uint16_t &b) { @@ -377,34 +383,6 @@ bool comparison_func>(const uint16_t &a, const uint16_t &b) //return npy_half_to_float(a) < npy_half_to_float(b); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr, - int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask16 loadmask = 0xFFFF; - while (arrsize > 0) { - if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; } - __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr); - __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm); - __mmask16 nanmask = _mm512_cmp_ps_mask( - in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF); - arr += 16; - arrsize -= 16; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = 0xFFFF; - nan_count -= 1; - } -} - template <> void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize) { @@ -425,11 +403,10 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize) void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_16bit_, uint16_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } @@ -453,11 +430,10 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize) void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_16bit_, uint16_t>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index c4061ddf..4cc8e2ec 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -689,31 +689,6 @@ static void qselect_32bit_(type_t *arr, qselect_32bit_(arr, pos, pivot_index, right, max_iters - 1); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask16 loadmask = 0xFFFF; - while (arrsize > 0) { - if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; } - __m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr); - __mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT); - arr += 16; - arrsize -= 16; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nanf("1"); - nan_count -= 1; - } -} - template <> void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize) { @@ -735,11 +710,10 @@ void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(float *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_32bit_, float>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } @@ -764,11 +738,10 @@ void avx512_qsort(uint32_t *arr, int64_t arrsize) template <> void avx512_qsort(float *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_32bit_, float>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 1cbcd388..c0d2cdba 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -804,11 +804,10 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(double *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_64bit_, double>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } @@ -833,11 +832,10 @@ void avx512_qsort(uint64_t *arr, int64_t arrsize) template <> void avx512_qsort(double *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_64bit_, double>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } #endif // AVX512_QSORT_64BIT diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 959352e6..bf0bb0e1 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -116,6 +116,33 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize) template void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); +template +bool is_a_nan(T elem) +{ + return std::isnan(elem); +} + +/* + * Sort all the NAN's to end of the array and return the index of the last elem + * in the array which is not a nan + */ +template +int64_t put_nans_at_end_of_array(T* arr, int64_t arrsize) +{ + int64_t jj = arrsize - 1; + int64_t ii = 0; + while (ii < jj) { + if (is_a_nan(arr[ii])) { + std::swap(arr[ii], arr[jj]); + jj -= 1; + } + else { + ii += 1; + } + } + return ii; +} + template bool comparison_func(const T &a, const T &b) { diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8a9a49ed..9a00a820 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -114,55 +114,31 @@ struct zmm_vector<_Float16> { } }; -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr, - int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask32 loadmask = 0xFFFFFFFF; - __m512h in_zmm; - while (arrsize > 0) { - if (arrsize < 32) { - loadmask = (0x00000001 << arrsize) - 0x00000001; - in_zmm = _mm512_castsi512_ph( - _mm512_maskz_loadu_epi16(loadmask, arr)); - } - else { - in_zmm = _mm512_loadu_ph(arr); - } - __mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF); - arr += 32; - arrsize -= 32; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count) +template <> +bool is_a_nan<_Float16>(_Float16 elem) { - memset(arr + arrsize - nan_count, 0xFF, nan_count * 2); + Fp16Bits temp; + temp.f_ = elem; + return (temp.i_ & 0x7c00) == 0x7c00; } template <> void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_16bit_, _Float16>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_16bit_, _Float16>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } #endif // AVX512FP16_QSORT_16BIT From b3cd2621d7460019a91b4be33fd996fcfca030d1 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 25 May 2023 09:46:50 -0700 Subject: [PATCH 2/3] Rename function --- src/avx512-16bit-qsort.hpp | 4 ++-- src/avx512-32bit-qsort.hpp | 4 ++-- src/avx512-64bit-qsort.hpp | 4 ++-- src/avx512-common-qsort.h | 2 +- src/avx512fp16-16bit-qsort.hpp | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 9133eb15..03e00e4f 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -403,7 +403,7 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize) void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_16bit_, uint16_t>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -430,7 +430,7 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize) void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_16bit_, uint16_t>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 4cc8e2ec..c8923899 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -710,7 +710,7 @@ void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(float *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_32bit_, float>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -738,7 +738,7 @@ void avx512_qsort(uint32_t *arr, int64_t arrsize) template <> void avx512_qsort(float *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_32bit_, float>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index c0d2cdba..046288d3 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -804,7 +804,7 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(double *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_64bit_, double>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -832,7 +832,7 @@ void avx512_qsort(uint64_t *arr, int64_t arrsize) template <> void avx512_qsort(double *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_64bit_, double>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index bf0bb0e1..b1b4ef70 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -127,7 +127,7 @@ bool is_a_nan(T elem) * in the array which is not a nan */ template -int64_t put_nans_at_end_of_array(T* arr, int64_t arrsize) +int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize) { int64_t jj = arrsize - 1; int64_t ii = 0; diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 9a00a820..54fdff79 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -125,7 +125,7 @@ bool is_a_nan<_Float16>(_Float16 elem) template <> void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem >= k) { qselect_16bit_, _Float16>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); @@ -135,7 +135,7 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize) template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) { - int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); if (indx_last_elem > 0) { qsort_16bit_, _Float16>( arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); From 880f5a2627316e3f59224fa95445e563736beffa Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 25 May 2023 13:29:25 -0700 Subject: [PATCH 3/3] Add tests and fix bug in move_nans_to_end_of_array --- src/avx512-common-qsort.h | 6 +++-- tests/test-qsort-fp.hpp | 47 +++++++++++++++++++++++++++++++++++++++ tests/test-qsort.cpp | 5 +++++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 tests/test-qsort-fp.hpp diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index b1b4ef70..60637ce8 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -131,16 +131,18 @@ int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize) { int64_t jj = arrsize - 1; int64_t ii = 0; - while (ii < jj) { + int64_t count = 0; + while (ii <= jj) { if (is_a_nan(arr[ii])) { std::swap(arr[ii], arr[jj]); jj -= 1; + count++; } else { ii += 1; } } - return ii; + return arrsize-count-1; } template diff --git a/tests/test-qsort-fp.hpp b/tests/test-qsort-fp.hpp new file mode 100644 index 00000000..9000fb38 --- /dev/null +++ b/tests/test-qsort-fp.hpp @@ -0,0 +1,47 @@ +/******************************************* + * * Copyright (C) 2022 Intel Corporation + * * SPDX-License-Identifier: BSD-3-Clause + * *******************************************/ + +#include "test-qsort-common.h" + +template +class avx512_sort_fp : public ::testing::Test { +}; +TYPED_TEST_SUITE_P(avx512_sort_fp); + +TYPED_TEST_P(avx512_sort_fp, test_random_nan) +{ + const int num_nans = 3; + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } + std::vector arrsizes; + for (int64_t ii = num_nans; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam)ii); + } + std::vector arr; + std::vector sortedarr; + for (auto &size : arrsizes) { + /* Random array */ + arr = get_uniform_rand_array(size); + for (auto ii = 1; ii <= num_nans; ++ii) { + arr[size-ii] = std::numeric_limits::quiet_NaN(); + } + sortedarr = arr; + std::sort(sortedarr.begin(), sortedarr.end()-3); + std::random_shuffle(arr.begin(), arr.end()); + avx512_qsort(arr.data(), arr.size()); + for (auto ii = 1; ii <= num_nans; ++ii) { + if (!std::isnan(arr[size-ii])) { + ASSERT_TRUE(false) << "NAN's aren't sorted to the end. Arr size = " << size; + } + } + if (!std::is_sorted(arr.begin(), arr.end() - num_nans)) { + ASSERT_TRUE(true) << "Array isn't sorted"; + } + arr.clear(); + sortedarr.clear(); + } +} +REGISTER_TYPED_TEST_SUITE_P(avx512_sort_fp, test_random_nan); diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 053f5367..eb3d5f77 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -1,5 +1,6 @@ #include "test-partial-qsort.hpp" #include "test-qselect.hpp" +#include "test-qsort-fp.hpp" #include "test-qsort.hpp" using QSortTestTypes = testing::Types; + +using QSortTestFPTypes = testing::Types; + INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_sort, QSortTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_sort_fp, QSortTestFPTypes); INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_select, QSortTestTypes); INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_partial_sort, QSortTestTypes);