diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 606f8706..03e00e4f 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -349,6 +349,12 @@ struct zmm_vector { } }; +template <> +bool is_a_nan(uint16_t elem) +{ + return (elem & 0x7c00) == 0x7c00; +} + template <> bool comparison_func>(const uint16_t &a, const uint16_t &b) { @@ -377,34 +383,6 @@ bool comparison_func>(const uint16_t &a, const uint16_t &b) //return npy_half_to_float(a) < npy_half_to_float(b); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr, - int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask16 loadmask = 0xFFFF; - while (arrsize > 0) { - if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; } - __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr); - __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm); - __mmask16 nanmask = _mm512_cmp_ps_mask( - in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF); - arr += 16; - arrsize -= 16; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = 0xFFFF; - nan_count -= 1; - } -} - template <> void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize) { @@ -425,11 +403,10 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize) void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_16bit_, uint16_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } @@ -453,11 +430,10 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize) void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_16bit_, uint16_t>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index c4061ddf..c8923899 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -689,31 +689,6 @@ static void qselect_32bit_(type_t *arr, qselect_32bit_(arr, pos, pivot_index, right, max_iters - 1); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask16 loadmask = 0xFFFF; - while (arrsize > 0) { - if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; } - __m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr); - __mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT); - arr += 16; - arrsize -= 16; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nanf("1"); - nan_count -= 1; - } -} - template <> void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize) { @@ -735,11 +710,10 @@ void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(float *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_32bit_, float>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } @@ -764,11 +738,10 @@ void avx512_qsort(uint32_t *arr, int64_t arrsize) template <> void avx512_qsort(float *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_32bit_, float>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 1cbcd388..046288d3 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -804,11 +804,10 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize) template <> void avx512_qselect(double *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_64bit_, double>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } @@ -833,11 +832,10 @@ void avx512_qsort(uint64_t *arr, int64_t arrsize) template <> void avx512_qsort(double *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_64bit_, double>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } #endif // AVX512_QSORT_64BIT diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 959352e6..60637ce8 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -116,6 +116,35 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize) template void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); +template +bool is_a_nan(T elem) +{ + return std::isnan(elem); +} + +/* + * Sort all the NAN's to end of the array and return the index of the last elem + * in the array which is not a nan + */ +template +int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize) +{ + int64_t jj = arrsize - 1; + int64_t ii = 0; + int64_t count = 0; + while (ii <= jj) { + if (is_a_nan(arr[ii])) { + std::swap(arr[ii], arr[jj]); + jj -= 1; + count++; + } + else { + ii += 1; + } + } + return arrsize-count-1; +} + template bool comparison_func(const T &a, const T &b) { diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8a9a49ed..54fdff79 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -114,55 +114,31 @@ struct zmm_vector<_Float16> { } }; -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr, - int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask32 loadmask = 0xFFFFFFFF; - __m512h in_zmm; - while (arrsize > 0) { - if (arrsize < 32) { - loadmask = (0x00000001 << arrsize) - 0x00000001; - in_zmm = _mm512_castsi512_ph( - _mm512_maskz_loadu_epi16(loadmask, arr)); - } - else { - in_zmm = _mm512_loadu_ph(arr); - } - __mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF); - arr += 32; - arrsize -= 32; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count) +template <> +bool is_a_nan<_Float16>(_Float16 elem) { - memset(arr + arrsize - nan_count, 0xFF, nan_count * 2); + Fp16Bits temp; + temp.f_ = elem; + return (temp.i_ & 0x7c00) == 0x7c00; } template <> void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem >= k) { qselect_16bit_, _Float16>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (indx_last_elem > 0) { qsort_16bit_, _Float16>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } #endif // AVX512FP16_QSORT_16BIT diff --git a/tests/test-qsort-fp.hpp b/tests/test-qsort-fp.hpp new file mode 100644 index 00000000..9000fb38 --- /dev/null +++ b/tests/test-qsort-fp.hpp @@ -0,0 +1,47 @@ +/******************************************* + * * Copyright (C) 2022 Intel Corporation + * * SPDX-License-Identifier: BSD-3-Clause + * *******************************************/ + +#include "test-qsort-common.h" + +template +class avx512_sort_fp : public ::testing::Test { +}; +TYPED_TEST_SUITE_P(avx512_sort_fp); + +TYPED_TEST_P(avx512_sort_fp, test_random_nan) +{ + const int num_nans = 3; + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } + std::vector arrsizes; + for (int64_t ii = num_nans; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam)ii); + } + std::vector arr; + std::vector sortedarr; + for (auto &size : arrsizes) { + /* Random array */ + arr = get_uniform_rand_array(size); + for (auto ii = 1; ii <= num_nans; ++ii) { + arr[size-ii] = std::numeric_limits::quiet_NaN(); + } + sortedarr = arr; + std::sort(sortedarr.begin(), sortedarr.end()-3); + std::random_shuffle(arr.begin(), arr.end()); + avx512_qsort(arr.data(), arr.size()); + for (auto ii = 1; ii <= num_nans; ++ii) { + if (!std::isnan(arr[size-ii])) { + ASSERT_TRUE(false) << "NAN's aren't sorted to the end. Arr size = " << size; + } + } + if (!std::is_sorted(arr.begin(), arr.end() - num_nans)) { + ASSERT_TRUE(true) << "Array isn't sorted"; + } + arr.clear(); + sortedarr.clear(); + } +} +REGISTER_TYPED_TEST_SUITE_P(avx512_sort_fp, test_random_nan); diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 053f5367..eb3d5f77 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -1,5 +1,6 @@ #include "test-partial-qsort.hpp" #include "test-qselect.hpp" +#include "test-qsort-fp.hpp" #include "test-qsort.hpp" using QSortTestTypes = testing::Types; + +using QSortTestFPTypes = testing::Types; + INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_sort, QSortTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_sort_fp, QSortTestFPTypes); INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_select, QSortTestTypes); INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_partial_sort, QSortTestTypes);